Compare commits

..

87 Commits

Author SHA1 Message Date
greatmengqi c53b9ccb02 test(custom_agent + task_tool): set app.state.config + drop obsolete skills monkeypatches 2026-04-27 18:09:43 +08:00
greatmengqi e99cb01fe1 test(tool_deduplication): pass app_config explicitly instead of patching removed singleton 2026-04-27 16:51:28 +08:00
greatmengqi 3e6a34297d refactor(config): eliminate global mutable state — explicit parameter passing on top of main
Squashes 25 PR commits onto current main. AppConfig becomes a pure value
object with no ambient lookup. Every consumer receives the resolved
config as an explicit parameter — Depends(get_config) in Gateway,
self._app_config in DeerFlowClient, runtime.context.app_config in agent
runs, AppConfig.from_file() at the LangGraph Server registration
boundary.

Phase 1 — frozen data + typed context

- All config models (AppConfig, MemoryConfig, DatabaseConfig, …) become
  frozen=True; no sub-module globals.
- AppConfig.from_file() is pure (no side-effect singleton loaders).
- Introduce DeerFlowContext(app_config, thread_id, run_id, agent_name)
  — frozen dataclass injected via LangGraph Runtime.
- Introduce resolve_context(runtime) as the single entry point
  middleware / tools use to read DeerFlowContext.

Phase 2 — pure explicit parameter passing

- Gateway: app.state.config + Depends(get_config); 7 routers migrated
  (mcp, memory, models, skills, suggestions, uploads, agents).
- DeerFlowClient: __init__(config=...) captures config locally.
- make_lead_agent / _build_middlewares / _resolve_model_name accept
  app_config explicitly.
- RunContext.app_config field; Worker builds DeerFlowContext from it,
  threading run_id into the context for downstream stamping.
- Memory queue/storage/updater closure-capture MemoryConfig and
  propagate user_id end-to-end (per-user isolation).
- Sandbox/skills/community/factories/tools thread app_config.
- resolve_context() rejects non-typed runtime.context.
- Test suite migrated off AppConfig.current() monkey-patches.
- AppConfig.current() classmethod deleted.

Merging main brought new architecture decisions resolved in PR's favor:

- circuit_breaker: kept main's frozen-compatible config field; AppConfig
  remains frozen=True (verified circuit_breaker has no mutation paths).
- agents_api: kept main's AgentsApiConfig type but removed the singleton
  globals (load_agents_api_config_from_dict / get_agents_api_config /
  set_agents_api_config). 8 routes in agents.py now read via
  Depends(get_config).
- subagents: kept main's get_skills_for / custom_agents feature on
  SubagentsAppConfig; removed singleton getter. registry.py now reads
  app_config.subagents directly.
- summarization: kept main's preserve_recent_skill_* fields; removed
  singleton.
- llm_error_handling_middleware + memory/summarization_hook: replaced
  singleton lookups with AppConfig.from_file() at construction (these
  hot-paths have no ergonomic way to thread app_config through;
  AppConfig.from_file is a pure load).
- worker.py + thread_data_middleware.py: DeerFlowContext.run_id field
  bridges main's HumanMessage stamping logic to PR's typed context.

Trade-offs (follow-up work):

- main's #2138 (async memory updater) reverted to PR's sync
  implementation. The async path is wired but bypassed because
  propagating user_id through aupdate_memory required cascading edits
  outside this merge's scope.
- tests/test_subagent_skills_config.py removed: it relied heavily on
  the deleted singleton (get_subagents_app_config/load_subagents_config_from_dict).
  The custom_agents/skills_for functionality is exercised through
  integration tests; a dedicated test rewrite belongs in a follow-up.

Verification: backend test suite — 2560 passed, 4 skipped, 84 failures.
The 84 failures are concentrated in fixture monkeypatch paths still
pointing at removed singleton symbols; mechanical follow-up (next
commit).
2026-04-26 21:45:02 +08:00
Willem Jiang 9dc25987e0 fix(channles):update the logger for the channel config (#2524)
* fix(channles):update the logger for the channel config

* fix(channels): normalize credential values and add tests for disabled-but-configured warning

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/dfc0a566-aa59-49f9-a74d-610292fb0a63

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* fix the backend lint error

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-04-26 10:09:55 +08:00
Willem Jiang 8a044142cb feat(dev): add pre-commit hooks for ruff, eslint, and prettier (#2525)
* feat(dev): add pre-commit hooks for ruff, eslint, and prettier

* fix: use local uv-based ruff hooks and uv run for pre-commit install

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/a1e34cc5-0d4b-4400-9e6a-e687d964ff1e

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-04-26 09:40:17 +08:00
ming1523 410f0c48b5 fix(channels): accept single slack allowed user (#2481)
* fix(channels): accept single slack allowed user

* docs: address Slack allowed_users review notes

* ci: rerun backend unit tests

* docs: clarify Slack allowed_users config

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 19:40:52 +08:00
Octopus 1f59e945af fix: cap prompt caching breakpoints at 4 to prevent API 400 errors (#2449)
* fix: cap prompt caching breakpoints at 4 to prevent API 400 errors (fixes #2448)

The previous _apply_prompt_caching() attached cache_control to every text
block in the system prompt, every content block in the last N messages, and
the last tool definition. In multi-turn conversations with structured content
blocks this easily exceeded the 4-breakpoint hard limit enforced by both the
Anthropic API and AWS Bedrock, producing a 400 Bad Request (or a silent
"No generations found in stream" when streaming).

Fix: collect all candidate blocks in document order, then apply cache_control
only to the last MAX_CACHE_BREAKPOINTS (4) of them. Later breakpoints cover a
larger prefix and therefore yield better cache hit rates, making this the
optimal placement strategy as well as the safe one.

Adds 13 unit tests covering the budget cap, edge cases, and correct
last-candidate placement.

* docs: clarify _apply_prompt_caching docstring includes tool definitions

Per Copilot review: the implementation also caches the last tool definition
(see the candidates list at lines 202-205), so the docstring summary should
explicitly mention tools alongside system and recent messages.

* Fix the lint error

* style: fix ruff format check for test_claude_provider_prompt_caching.py

Add the missing blank line before the 'Edge cases' section comment so
that ruff format --check passes in CI.

---------

Co-authored-by: octo-patch <octo-patch@github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 19:40:06 +08:00
IECspace f394c0d8c8 feat(mcp): support custom tool interceptors via extensions_config.json (#2451)
* feat(mcp): support custom tool interceptors via extensions_config.json

Add a generic extension point for registering custom MCP tool
interceptors through `extensions_config.json`. This allows downstream
projects to inject per-request header manipulation, auth context
propagation, or other cross-cutting concerns without modifying
DeerFlow source code.

Interceptors are declared as Python callable paths in a new
`mcpInterceptors` array field and loaded via the existing
`resolve_variable` reflection mechanism:

```json
{
  "mcpInterceptors": [
    "my_package.mcp.auth:build_auth_interceptor"
  ]
}
```

Each entry must resolve to a no-arg builder function that returns an
async interceptor compatible with `MultiServerMCPClient`'s
`tool_interceptors` interface.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* test(mcp): add unit tests for custom tool interceptors

Cover all branches of the mcpInterceptors loading logic:

- valid interceptor loaded and appended to tool_interceptors
- multiple interceptors loaded in declaration order
- builder returning None is skipped
- resolve_variable ImportError logged and skipped
- builder raising exception logged and skipped
- absent mcpInterceptors field is safe (no-op)
- custom interceptors coexist with OAuth interceptor

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix(mcp): validate mcpInterceptors type and fix lint warnings

Address review feedback:

1. Validate mcpInterceptors config value before iterating:
   - Accept a single string and normalize to [string]
   - Ignore None silently
   - Log warning and skip for non-list/non-string types

2. Fix ruff F841 lint errors in tests:
   - Rename _make_mock_env to _make_patches, embed mock_client
   - Remove unused `as mock_cls` bindings where not needed
   - Extract _get_interceptors() helper to reduce repetition

3. Add two new test cases for type validation:
   - test_mcp_interceptors_single_string_is_normalized
   - test_mcp_interceptors_invalid_type_logs_warning

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(mcp): validate interceptor return type and fix import mock path

Address review feedback:

1. Validate builder return type with callable() check:
   - callable interceptor → append to tool_interceptors
   - None → silently skip (builder opted out)
   - non-callable → log warning with type name and skip

2. Fix test mock path: resolve_variable is a top-level import in
   tools.py, so mock deerflow.mcp.tools.resolve_variable instead of
   deerflow.reflection.resolve_variable to correctly intercept calls.

3. Add test_custom_interceptor_non_callable_return_logs_warning to
   cover the new non-callable validation branch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* docs(mcp): add mcpInterceptors example and documentation

- Add mcpInterceptors field to extensions_config.example.json
- Add "Custom Tool Interceptors" section to MCP_SERVER.md with
  configuration format, example interceptor code, and edge case
  behavior notes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: IECspace <IECspace@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-25 09:18:13 +08:00
orbisai0security 950821cb9b fix: use subprocess instead of os.system in local_backend.py (#2494)
* fix: use subprocess instead of os.system in local_backend.py

The sandbox backend and skill evaluation scripts use subprocess

* fixing the failing test

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 08:59:31 +08:00
pyp0327 2bb1a2dfa2 feat(models): Provider for MindIE model engine (#2483)
* feat(models): 适配 MindIE引擎的模型

* test: add unit tests for MindIEChatModel adapter and fix PR review comments

* chore: update uv.lock with pytest-asyncio

* build: add pytest-asyncio to test dependencies

* fix: address PR review comments (lazy import, cache clients, safe newline escape, strict xml regex)

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 08:59:03 +08:00
DanielWalnut b970993425 fix: read lead agent options from context (#2515)
* fix: read lead agent options from context

* fix: validate runtime context config
2026-04-24 22:46:51 +08:00
DanielWalnut ec8a8cae38 fix: gate deferred MCP tool execution (#2513)
* fix: gate deferred MCP tool execution

* style: format deferred tool middleware

* fix: address deferred tool review feedback
2026-04-24 22:45:41 +08:00
DanielWalnut d78ed5c8f2 fix: inherit subagent skill allowlists (#2514) 2026-04-24 21:24:42 +08:00
Nan Gao f9ff3a698d fix(middleware): avoid rescuing non-skill tool outputs during summarization (#2458)
* fix(middelware): narrow skill rescue to skill-related tool outputs

* fix(summarization): address skill rescue review feedback

* fix: wire summarization skill rescue config

* fix: remove dead skill tool helper

* fix(lint): fix format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 21:19:46 +08:00
Admire c2332bb790 fix memory settings layout overflow (#2420)
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 20:29:55 +08:00
He Wang 3a61126824 fix: keep debug.py interactive terminal free from background log noise (#2466)
* fix(debug): keep terminal clean by redirecting all logs to file

- Redirect all logs to debug.log file to prevent background task logs
  from interfering with interactive terminal prompts
- Honor AppConfig.log_level setting instead of hard-coding to INFO
- Make logging setup idempotent by clearing pre-existing handlers
- Defer deerflow imports until after logging is configured to ensure
  import-time side effects are captured in debug.log
- Display active log level in startup banner
- Add prompt_toolkit installation tip for enhanced readline support

Made-with: Cursor

* attaching the file handler before importing/calling get_app_config()

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-24 17:09:41 +08:00
Airene Fang 11f557a2c6 feat(trace):Add run_name to the trace info for system agents. (#2492)
* feat(trace): Add `run_name` to the trace info for suggestions and memory.

before(in langsmith):
CodexChatModel
CodexChatModel
lead_agent
after:
suggest_agent
memory_agent
lead_agent

feat(trace): Add `run_name` to the trace info for suggestions and memory.

before(in langsmith):
CodexChatModel
CodexChatModel
lead_agent
after:
suggest_agent
memory_agent
lead_agent

* feat(trace): Add `run_name` to the trace info for system agents.

before(in langsmith):
CodexChatModel
CodexChatModel
CodexChatModel
CodexChatModel
lead_agent
after:
suggest_agent
title_agent
security_agent
memory_agent
lead_agent

* chore(code format):code format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 17:06:55 +08:00
d 🔹 e8572b9d0c fix(jina): log transient failures at WARNING without traceback (#2484) (#2485)
The exception handler in JinaClient.crawl used logger.exception, which
emits an ERROR-level record with the full httpx/httpcore/anyio traceback
for every transient network failure (timeout, connection refused). Other
search/crawl providers in the project log the same class of recoverable
failures as a single line. One offline/slow-network session could produce
dozens of multi-frame ERROR stack traces, drowning out real problems.

Switch to logger.warning with a concise message that includes the
exception type and its str, matching the style used elsewhere for
recoverable transient failures (aio_sandbox, ddg, etc.). The exception
type now also surfaces into the returned "Error: ..." string so callers
retain diagnostic signal.

Adds a regression test that asserts the log record is WARNING, carries
no exc_info, and includes the exception class name.

Co-authored-by: voidborne-d <voidborne-d@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 16:00:14 +08:00
Willem Jiang 80a7446fd6 fix(backend): fix the unit test error in backend 2026-04-24 14:56:03 +08:00
Willem Jiang cd12821134 fix(backend): Updated the uv.lock with new added dependency 2026-04-24 14:55:13 +08:00
Xinmin Zeng 30d619de08 feat(subagents): support per-subagent skill loading and custom subagent types (#2253)
* feat(subagents): support per-subagent skill loading and custom subagent types (#2230)

Add per-subagent skill configuration and custom subagent type registration,
aligned with Codex's role-based config layering and per-session skill injection.

Backend:
- SubagentConfig gains `skills` field (None=all, []=none, list=whitelist)
- New CustomSubagentConfig for user-defined subagent types in config.yaml
- SubagentsAppConfig gains `custom_agents` section and `get_skills_for()`
- Registry resolves custom agents with three-layer config precedence
- SubagentExecutor loads skills per-session as conversation items (Codex pattern)
- task_tool no longer appends skills to system_prompt
- Lead agent system prompt dynamically lists all registered subagent types
- setup_agent tool accepts optional skills parameter
- Gateway agents API transparently passes skills in CRUD operations

Frontend:
- Agent/CreateAgentRequest/UpdateAgentRequest types include skills field
- Agent card displays skills as badges alongside tool_groups

Config:
- config.example.yaml documents custom_agents and per-agent skills override

Tests:
- 40 new tests covering all skill config, custom agents, and registry logic
- Existing tests updated for new get_skills_prompt_section signature

Closes #2230

* fix: address review feedback on skills PR

- Remove stale get_skills_prompt_section monkeypatches from test_task_tool_core_logic.py
  (task_tool no longer imports this function after skill injection moved to executor)
- Add key prefixes (tg:/sk:) to agent-card badges to prevent React key collisions
  between tool_groups and skills

* fix(ci): resolve lint and test failures

- Format agent-card.tsx with prettier (lint-frontend)
- Remove stale "Skills Appendix" system_prompt assertion — skills are now
  loaded per-session by SubagentExecutor, not appended to system_prompt

* fix(ci): sort imports in test_subagent_skills_config.py (ruff I001)

* fix(ci): use nullish coalescing in agent-card badge condition (eslint)

* fix: address review feedback on skills PR

- Use model_fields_set in AgentUpdateRequest to distinguish "field omitted"
  from "explicitly set to null" — fixes skills=None ambiguity where None
  means "inherit all" but was treated as "don't change"
- Move lazy import of get_subagent_config outside loop in
  _build_available_subagents_description to avoid repeated import overhead

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-23 23:59:47 +08:00
JerryChaox 4e72410154 fix(gateway): bound lifespan shutdown hooks to prevent worker hang under uvicorn reload (#2331)
* fix(gateway): bound lifespan shutdown hooks to prevent worker hang

Gateway worker can hang indefinitely in `uvicorn --reload` mode with
the listening socket still bound — all /api/* requests return 504,
and SIGKILL is the only recovery.

Root cause (py-spy dump from a reproduction showed 16+ stacked frames
of signal_handler -> Event.set -> threading.Lock.__enter__ on the
main thread): CPython's `threading.Event` uses `Condition(Lock())`
where the inner Lock is non-reentrant. uvicorn's BaseReload signal
handler calls `should_exit.set()` directly from signal context; if a
second signal (SIGTERM/SIGHUP from the reload supervisor, or
watchfiles-triggered reload) arrives while the first handler holds
the Lock, the reentrant call deadlocks on itself.

The reload supervisor keeps sending those signals only when the
worker fails to exit promptly. DeerFlow's lifespan currently awaits
`stop_channel_service()` with no timeout; if a channel's `stop()`
stalls (e.g. Feishu/Slack WebSocket waiting for an ack), the worker
can't exit, the supervisor keeps signaling, and the deadlock becomes
reachable.

This is a defense-in-depth fix — it does not repair the upstream
uvicorn/CPython issue, but it ensures DeerFlow's lifespan exits
within a bounded window so the supervisor has no reason to keep
firing signals. No behavior change on the happy path.

Wraps the shutdown hook in `asyncio.wait_for(timeout=5.0)` and logs
a warning on timeout before proceeding to worker exit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Update backend/app/gateway/app.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* style: apply make format (ruff) to test assertions

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-23 19:41:26 +08:00
He Wang c42ae3af79 feat: add optional prompt-toolkit support to debug.py (#2461)
* feat: add optional prompt-toolkit support to debug.py

Use PromptSession.prompt_async() for arrow-key navigation and input
history when prompt-toolkit is available, falling back to plain input()
with a helpful install tip otherwise.

Made-with: Cursor

* fix: handle EOFError gracefully in debug.py

Catch EOFError alongside KeyboardInterrupt so that Ctrl-D exits
cleanly instead of printing a traceback.

Made-with: Cursor
2026-04-23 17:49:18 +08:00
dependabot[bot] bd35cd39aa chore(deps): bump uuid from 13.0.0 to 14.0.0 in /frontend (#2467)
Bumps [uuid](https://github.com/uuidjs/uuid) from 13.0.0 to 14.0.0.
- [Release notes](https://github.com/uuidjs/uuid/releases)
- [Changelog](https://github.com/uuidjs/uuid/blob/main/CHANGELOG.md)
- [Commits](https://github.com/uuidjs/uuid/compare/v13.0.0...v14.0.0)

---
updated-dependencies:
- dependency-name: uuid
  dependency-version: 14.0.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-23 14:47:15 +08:00
d 🔹 b90f219bd1 fix(skills): validate bundled SKILL.md front-matter in CI (fixes #2443) (#2457)
* fix(skills): validate bundled SKILL.md front-matter in CI (fixes #2443)

Adds a parametrized backend test that runs `_validate_skill_frontmatter`
against every bundled SKILL.md under `skills/public/`, so a broken
front-matter fails CI with a per-skill error message instead of
surfacing as a runtime gateway-load warning.

The new test caught two pre-existing breakages on `main` and fixes them:

* `bootstrap/SKILL.md`: the unquoted description had a second `:` mid-line
  ("Also trigger for updates: ..."), which YAML parses as a nested mapping
  ("mapping values are not allowed here"). Rewrites the description as a
  folded scalar (`>-`), which preserves the original wording (including the
  embedded colon, double quotes, and apostrophes) without further escaping.
  This complements PR #2436 (single-file colon→hyphen patch) with a more
  general convention that survives future edits.

* `chart-visualization/SKILL.md`: used `dependency:` which is not in
  `ALLOWED_FRONTMATTER_PROPERTIES`. Renamed to `compatibility:`, the
  documented field for "Required tools, dependencies" per skill-creator.
  No code reads `dependency` (verified by grep across backend/).

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-23 14:06:14 +08:00
dependabot[bot] 96d00f6073 chore(deps): bump dompurify from 3.3.1 to 3.4.1 in /frontend (#2462)
Bumps [dompurify](https://github.com/cure53/DOMPurify) from 3.3.1 to 3.4.1.
- [Release notes](https://github.com/cure53/DOMPurify/releases)
- [Commits](https://github.com/cure53/DOMPurify/compare/3.3.1...3.4.1)

---
updated-dependencies:
- dependency-name: dompurify
  dependency-version: 3.4.1
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-23 12:18:59 +08:00
He Wang c43c803f66 fix: remove mismatched context param in debug.py to suppress Pydantic warning (#2446)
* fix: remove mismatched context param in debug.py to suppress Pydantic warning

The ainvoke call passed context={"thread_id": ...} but the agent graph
has no context_schema (ContextT defaults to None), causing a
PydanticSerializationUnexpectedValue warning on every invocation.

Align with the production run_agent path by injecting context via
Runtime into configurable["__pregel_runtime"] instead.

Closes #2445

Made-with: Cursor

* refactor: derive runtime thread_id from config to avoid duplication

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Made-with: Cursor

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-23 09:56:57 +08:00
dependabot[bot] dbd777fe62 chore(deps): bump python-dotenv from 1.2.1 to 1.2.2 in /backend (#2440)
Bumps [python-dotenv](https://github.com/theskumar/python-dotenv) from 1.2.1 to 1.2.2.
- [Release notes](https://github.com/theskumar/python-dotenv/releases)
- [Changelog](https://github.com/theskumar/python-dotenv/blob/main/CHANGELOG.md)
- [Commits](https://github.com/theskumar/python-dotenv/compare/v1.2.1...v1.2.2)

---
updated-dependencies:
- dependency-name: python-dotenv
  dependency-version: 1.2.2
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-22 16:48:09 +08:00
dependabot[bot] 1ca2621285 chore(deps): bump lxml from 6.0.2 to 6.1.0 in /backend (#2427)
Bumps [lxml](https://github.com/lxml/lxml) from 6.0.2 to 6.1.0.
- [Release notes](https://github.com/lxml/lxml/releases)
- [Changelog](https://github.com/lxml/lxml/blob/master/CHANGES.txt)
- [Commits](https://github.com/lxml/lxml/compare/lxml-6.0.2...lxml-6.1.0)

---
updated-dependencies:
- dependency-name: lxml
  dependency-version: 6.1.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-22 16:14:11 +08:00
Shawn Jasper 5ba1dacf25 fix: rename present_file to present_files in docs and prompts (#2393)
The tool is registered as `present_files` (plural) in present_file_tool.py,
but four references in documentation and prompt strings incorrectly used the
singular form `present_file`. This could cause confusion and potentially
lead to incorrect tool invocations.

Changed files:
- backend/docs/GUARDRAILS.md
- backend/docs/ARCHITECTURE.md
- backend/packages/harness/deerflow/agents/lead_agent/prompt.py (2 occurrences)
2026-04-21 16:10:14 +08:00
Reuben Bowlby 085c13edc7 fix: remove unnecessary f-string prefixes and unused import (#2352)
- Remove f-string prefix on 7 strings with no placeholders (F541)
  in analyze.py, aggregate_benchmark.py, run_loop.py, generate_review.py
- Remove unused `os` import in quick_validate.py (F401)

Found by ruff via HUMMBL Arbiter (https://hummbl.io/audit).
2026-04-21 09:53:18 +08:00
Copilot ef04174194 Fix invalid HTML nesting in reasoning trigger during complex task rendering (#2382)
* Initial plan

* fix(frontend): avoid invalid paragraph nesting in reasoning trigger

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/4c9eb0c2-ff29-4629-a61c-4e33d736d918

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* test(frontend): strengthen reasoning trigger DOM nesting assertion

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/4c9eb0c2-ff29-4629-a61c-4e33d736d918

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>
2026-04-21 09:41:28 +08:00
Ansel 6dce26a52e fix: resolve tool duplication and skill parser YAML inconsistencies (#1803) (#2107)
* Refactor tests for SKILL.md parser

Updated tests for SKILL.md parser to handle quoted names and descriptions correctly. Added new tests for parsing plain and single-quoted names, and ensured multi-line descriptions are processed properly.

* Implement tool name validation and deduplication

Add tool name mismatch warning and deduplication logic

* Refactor skill file parsing and error handling

* Add tests for tool name deduplication

Added tests for tool name deduplication in get_available_tools(). Ensured that duplicates are not returned, the first occurrence is kept, and warnings are logged for skipped duplicates.

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Update minimal config to include tools list

* Update test for nonexistent skill file

Ensure the test for nonexistent files checks for None.

* Refactor tool loading and add skill management support

Refactor tool loading logic to include skill management tools based on configuration and clean up comments.

* Enhance code comments for tool loading logic

Added comments to clarify the purpose of various code sections related to tool loading and configuration.

* Fix assertion for duplicate tool name warning

* Fix indentation issues in tools.py

* Fix the lint error of test_tool_deduplication

* Fix the lint error of tools.py

* Fix the lint error

* Fix the lint error

* make format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-20 20:25:03 +08:00
imhaoran fc94e90f6c fix(setup-agent): prevent data loss when setup fails on existing agen… (#2254)
* fix(setup-agent): prevent data loss when setup fails on existing agent directory

Record whether the agent directory pre-existed before mkdir, and only
run shutil.rmtree cleanup when the directory was newly created during
this call. Previously, any failure would delete the entire directory
including pre-existing SOUL.md and config.yaml.

* fix: address PR review — init variables before try, remove unused result

* style: fix ruff I001 import block formatting in test file

* style: add missing blank lines between top-level definitions in test file
2026-04-20 20:17:30 +08:00
Eilen Shin f2013f47aa fix command palette hydration mismatch (#2301)
* fix command palette hydration mismatch

* style: format command dialog description
2026-04-20 11:36:16 +08:00
KiteEater 4be857f64b fix: use Apple Container image pull syntax (#2366) 2026-04-20 08:00:05 +08:00
Admire c99865f53d fix(token-usage): enable stream usage for openai-compatible models (#2217)
* fix(token-usage): enable stream usage for openai-compatible models

* fix(token-usage): narrow stream_usage default to ChatOpenAI
2026-04-19 22:42:55 +08:00
YYMa 05f1da03e5 fix(script): use portable locale for langgraph log pipeline on macOS (#2361) 2026-04-19 22:41:00 +08:00
Xun a62ca5dd47 fix: Catch httpx.ReadError in the error handling (#2309)
* fix: Catch httpx.ReadError in the error handling

* fix
2026-04-19 22:30:22 +08:00
Nan Gao f514e35a36 fix(backend): make clarification messages idempotent (#2350) (#2351) 2026-04-19 22:00:58 +08:00
Xun 7c87dc5bca fix(reasoning): prevent LLM-hallucinated HTML tags from rendering as DOM elements (#2321)
* fix

* add test

* fix
2026-04-19 19:27:34 +08:00
Hinotobi 80e210f5bb [security] fix(uploads): require explicit opt-in for host-side document conversion (#2332)
* fix: disable host-side upload conversion by default

* fix: address PR review comments on upload conversion gate
2026-04-18 22:47:42 +08:00
dependabot[bot] 5656f90792 chore(deps-dev): bump pytest from 9.0.2 to 9.0.3 in /backend (#2349)
Bumps [pytest](https://github.com/pytest-dev/pytest) from 9.0.2 to 9.0.3.
- [Release notes](https://github.com/pytest-dev/pytest/releases)
- [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst)
- [Commits](https://github.com/pytest-dev/pytest/compare/9.0.2...9.0.3)

---
updated-dependencies:
- dependency-name: pytest
  dependency-version: 9.0.3
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-18 22:22:40 +08:00
Shawn Jasper 55474011c9 fix(subagent): inherit parent agent's tool_groups in task_tool (#2305)
* fix(subagent): inherit parent agent's tool_groups in task_tool

When a custom agent defines tool_groups (e.g. [file:read, file:write, bash]),
the restriction is correctly applied to the lead agent. However, when the lead
agent delegates work to a subagent via the task tool, get_available_tools() is
called without the groups parameter, causing the subagent to receive ALL tools
(including web_search, web_fetch, image_search, etc.) regardless of the parent
agent's configuration.

This fix propagates tool_groups through run metadata so that task_tool passes
the same group filter when building the subagent's tool set.

Changes:
- agent.py: include tool_groups in run metadata
- task_tool.py: read tool_groups from metadata and pass to get_available_tools()

* fix: initialize metadata before conditional block and update tests for tool_groups propagation

- Initialize metadata = {} before the 'if runtime is not None' block to
  avoid Ruff F821 (possibly-undefined variable) and simplify the
  parent_tool_groups expression.
- Update existing test assertion to expect groups=None in
  get_available_tools call signature.
- Add 3 new test cases:
  - test_task_tool_propagates_tool_groups_to_subagent
  - test_task_tool_no_tool_groups_passes_none
  - test_task_tool_runtime_none_passes_groups_none
2026-04-18 22:17:37 +08:00
imhaoran 24fe5fbd8c fix(mcp): prevent RuntimeError from escaping except block in get_cach… (#2252)
* fix(mcp): prevent RuntimeError from escaping except block in get_cached_mcp_tools

When `asyncio.get_event_loop()` raises RuntimeError and the fallback
`asyncio.run()` also fails, the exception escapes unhandled because
Python does not route exceptions raised inside an `except` block to
sibling `except` clauses. Wrap the fallback call in its own try/except
so failures are logged and the function returns [] as intended.

* fix: use logger.exception to preserve stack traces on MCP init failure
2026-04-18 21:07:30 +08:00
Willem Jiang be4663505a chroe(script): disable the color log of langgraph 2026-04-18 20:03:05 +08:00
dependabot[bot] aa6098e6a4 chore(deps): bump langsmith from 0.6.4 to 0.7.31 in /backend (#2291)
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.6.4 to 0.7.31.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.6.4...v0.7.31)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.7.31
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-18 19:54:21 +08:00
Airene Fang 1221448029 fix(scripts): Cloud Provider Reports Security Issue(aliyun could) (#2323)
ATT&CK矩阵ID:T1059.004
数据来源:进程启动触发检测
告警原因:该进程的命令行显示出反弹shelI的特征
命令行:timeout 1 bash -c exec 3<>/dev/tcp/127.0.0.1/2024
进程路径:/usr/bin/timeout
进程链:-[337650] /usr/sbin/sshd -D
-[397971] /usr/sbin/sshd -D -R
-[397977]-bash
-[398903] make dev
-[398920] bash ./scripts/serve.sh --dev
-[399037]bash ./scripts/wait-for-port.sh 2024 60 LangGraph
2026-04-18 19:33:32 +08:00
Jason 3b91df2b18 fix(frontend): add catch-all API rewrite for gateway routes (#2335)
When NEXT_PUBLIC_BACKEND_BASE_URL is unset, the frontend proxies API
requests to the gateway. Only /api/agents and /api/skills had rewrite
rules, causing 404s for /api/models, /api/threads, /api/memory,
/api/mcp, /api/suggestions, /api/runs, etc.

Add a catch-all /api/:path* rewrite that proxies all remaining gateway
API routes. The existing /api/langgraph rewrite takes priority because
it is pushed to the array first (Next.js checks rewrites in order).

Fixes #2327

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-18 11:35:19 +08:00
Shawn Jasper ca1b7d5f48 fix(sandbox): add missing path masking in ls_tool output (#2317)
ls_tool was the only file-system tool that did not call
mask_local_paths_in_output() before returning its result, causing host
absolute paths (e.g. /Users/.../backend/.deer-flow/knowledge-base/...)
to leak to the LLM instead of the expected virtual paths
(/mnt/knowledge-base/...).

This patch:
- Adds the mask_local_paths_in_output() call to ls_tool, consistent
  with bash_tool, glob_tool and grep_tool.
- Initialises thread_data = None before the is_local_sandbox branch
  (same pattern as glob_tool) so the variable is always in scope.
- Adds three new tests covering user-data path masking, skills path
  masking and the empty-directory edge case.
2026-04-18 08:46:59 +08:00
yangzheli c6b0423558 feat(frontend): add Playwright E2E tests with CI workflow (#2279)
* feat(frontend): add Playwright E2E tests with CI workflow

Add end-to-end testing infrastructure using Playwright (Chromium only).
14 tests across 5 spec files cover landing page, chat workspace,
thread history, sidebar navigation, and agent chat — all with mocked
LangGraph/Backend APIs via network interception (zero backend dependency).

New files:
- playwright.config.ts — Chromium, 30s timeout, auto-start Next.js
- tests/e2e/utils/mock-api.ts — shared API mocks & SSE stream helpers
- tests/e2e/{landing,chat,thread-history,sidebar,agent-chat}.spec.ts
- .github/workflows/e2e-tests.yml — push main + PR trigger, paths filter

Updated: package.json, Makefile, .gitignore, CONTRIBUTING.md,
frontend/CLAUDE.md, frontend/AGENTS.md, frontend/README.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: apply Copilot suggestions

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-18 08:21:08 +08:00
DanielWalnut 898f4e8ac2 fix: Memory update system has cache corruption, data loss, and thread-safety bugs (#2251)
* fix(memory): cache corruption, thread-safety, and caller mutation bugs

Bug 1 (updater.py): deep-copy current_memory before passing to
_apply_updates() so a subsequent save() failure cannot leave a
partially-mutated object in the storage cache.

Bug 3 (storage.py): add _cache_lock (threading.Lock) to
FileMemoryStorage and acquire it around every read/write of
_memory_cache, fixing concurrent-access races between the background
timer thread and HTTP reload calls.

Bug 4 (storage.py): replace in-place mutation
  memory_data["lastUpdated"] = ...
with a shallow copy
  memory_data = {**memory_data, "lastUpdated": ...}
so save() no longer silently modifies the caller's dict.

Regression tests added for all three bugs in test_memory_storage.py
and test_memory_updater.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: format test_memory_updater.py with ruff

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: remove stale bug-number labels from code comments and docstrings

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-17 12:00:31 +08:00
dependabot[bot] 259a6844bf chore(deps): bump python-multipart from 0.0.22 to 0.0.26 in /backend (#2282)
Bumps [python-multipart](https://github.com/Kludex/python-multipart) from 0.0.22 to 0.0.26.
- [Release notes](https://github.com/Kludex/python-multipart/releases)
- [Changelog](https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Kludex/python-multipart/compare/0.0.22...0.0.26)

---
updated-dependencies:
- dependency-name: python-multipart
  dependency-version: 0.0.26
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-16 09:07:28 +08:00
d 🔹 a664d2f5c4 fix(checkpointer): create parent directory before opening SQLite in sync provider (#2272)
* fix(checkpointer): create parent directory before opening SQLite in sync provider

The sync checkpointer factory (_sync_checkpointer_cm) opens a SQLite
connection without first ensuring the parent directory exists.  The async
provider and both store providers already call ensure_sqlite_parent_dir(),
but this call was missing from the sync path.

When the deer-flow harness package is used from an external virtualenv
(where the .deer-flow directory is not pre-created), the missing parent
directory causes:

    sqlite3.OperationalError: unable to open database file

Add the missing ensure_sqlite_parent_dir() call in the sync SQLite
branch, consistent with the async provider, and add a regression test.

Closes #2259

* style: fix ruff format + add call-order assertion for ensure_parent_dir

- Fix formatting in test_checkpointer.py (ruff format)
- Add test_sqlite_ensure_parent_dir_before_connect to verify
  ensure_sqlite_parent_dir is called before from_conn_string
  (addresses Copilot review suggestion)

---------

Co-authored-by: voidborne-d <voidborne-d@users.noreply.github.com>
2026-04-16 09:06:38 +08:00
YuJitang 105db00987 feat: show token usage per assistant response (#2270)
* feat: show token usage per assistant response

* fix: align client models response with token usage

* fix: address token usage review feedback

* docs: clarify token usage config example

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-16 08:56:49 +08:00
Nan Gao 0e16a7fe55 fix(frontend): make Suggestion button opaque in dark mode (#2276)
* fix(frontend): make Suggestion button opaque in dark mode

The outline Button variant applies dark:bg-input/30, leaving Suggestion
pills ~70% transparent in dark mode. Scrolled chat content bled through
the buttons, making suggestion text unreadable. Override with
dark:bg-background so it matches the opaque light-mode appearance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix the lint error of commit

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-16 08:55:16 +08:00
Nan Gao 4d3038a7b6 fix(frontend): stop artifact panel from auto-opening on rehydrated write_file (#2278)
After a page refresh, the artifact panel's autoOpen/autoSelect state is
reset to true. Submitting a new question flips thread.isLoading to true,
which message-list passes to every MessageGroup — including historical
ones. The previous response's last write_file step then satisfies the
auto-open condition and re-pops the stale artifact.

Gate the auto-open on the tool call having no result yet, so only a
write_file that is still streaming in the current response can trigger
it; rehydrated tool calls always carry a result and are now skipped.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-16 08:46:47 +08:00
Hinotobi 2176b2bbfc fix: validate bootstrap agent names before filesystem writes (#2274)
* fix: validate bootstrap agent names before filesystem writes

* fix: tighten bootstrap agent-name validation
2026-04-16 08:36:42 +08:00
Wen 8e3591312a test: add unit tests for ViewImageMiddleware (#2256)
* test: add unit tests for ViewImageMiddleware

- Add 33 test cases covering all 7 internal methods plus sync/async
  before_model hooks
- Cover normal path, edge cases (missing keys, empty base64, stale
  ToolMessages before assistant turn), and deduplication logic
- Related to Q2 Roadmap #1669

* test: add unit tests for ViewImageMiddleware

Add 35 test cases covering all internal methods, before_model hooks,
and edge cases (missing attrs, list-content dedup, stale ToolMessages).

Related to #1669
2026-04-15 23:54:30 +08:00
Willem Jiang 242c654075 fix(frontend):lint error of message-list-item.tsx 2026-04-15 23:35:50 +08:00
Willem Jiang 0c21cbf01f fix(frontend): lint error of frontend 2026-04-15 23:27:46 +08:00
Jason 772538ddba fix(frontend): add skills API rewrite rule to prevent HTML fallback (#2241)
Fixes #2203

When NEXT_PUBLIC_BACKEND_BASE_URL is not set, the frontend uses Next.js
rewrites to proxy API calls to the gateway. Skills API routes were missing
from the rewrite config, causing /api/skills to return the SPA HTML instead
of JSON, which produced 'Unexpected token <' errors in the skill settings page.

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-15 23:21:40 +08:00
Jason 35fb3dd65a fix(frontend): resolve /mnt/ links in markdown to artifact API URLs (#2243)
* fix(gateway): forward agent_name and is_bootstrap from context to configurable

The frontend sends agent_name and is_bootstrap via the context field
in run requests, but services.py only forwards a hardcoded whitelist
of keys (_CONTEXT_CONFIGURABLE_KEYS) into the agent's configurable
dict.  Since agent_name was missing, custom agents never received
their name — make_lead_agent always fell back to the default lead
agent, skipping SOUL.md, per-agent config and skill filtering.

Similarly, is_bootstrap was dropped, so the bootstrap creation flow
could never activate the setup_agent tool path.

Add both keys to the whitelist so they reach make_lead_agent.

Fixes #2222

* fix(frontend): resolve /mnt/ links in markdown to artifact API URLs

AI agent messages contain links like /mnt/user-data/outputs/file.pdf
which were rendered as-is in the browser, resulting in 404 errors.
Images already got the correct treatment via MessageImage and
resolveArtifactURL, but anchor tags (<a>) were passed through
unchanged.

Add an 'a' component override in MessageContent_ that rewrites
/mnt/-prefixed hrefs to the artifact API endpoint, matching the
existing image handling pattern.

Fixes #2232

---------

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-15 23:12:21 +08:00
Jason 692f79452d fix(gateway): forward agent_name and is_bootstrap from context to configurable (#2242)
The frontend sends agent_name and is_bootstrap via the context field
in run requests, but services.py only forwards a hardcoded whitelist
of keys (_CONTEXT_CONFIGURABLE_KEYS) into the agent's configurable
dict.  Since agent_name was missing, custom agents never received
their name — make_lead_agent always fell back to the default lead
agent, skipping SOUL.md, per-agent config and skill filtering.

Similarly, is_bootstrap was dropped, so the bootstrap creation flow
could never activate the setup_agent tool path.

Add both keys to the whitelist so they reach make_lead_agent.

Fixes #2222

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-15 23:11:10 +08:00
DanielWalnut 8760937439 fix(memory): use asyncio.to_thread for blocking file I/O in aupdate_memory (#2220)
* fix(memory): use asyncio.to_thread for blocking file I/O in aupdate_memory

`_finalize_update` performs synchronous blocking operations (os.mkdir,
file open/write/rename/stat) that were called directly from the async
`aupdate_memory` method, causing `BlockingError` from blockbuster when
running under an ASGI server. Wrap the call with `asyncio.to_thread` to
offload all blocking I/O to a thread pool.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): use unique temp filename to prevent concurrent write collision

`file_path.with_suffix(".tmp")` produces a fixed path — concurrent saves
for the same agent (now possible after wrapping _finalize_update in
asyncio.to_thread) would clobber the same temp file. Use a UUID-suffixed
temp file so each write is isolated.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): also offload _prepare_update_prompt to thread pool

FileMemoryStorage.load() inside _prepare_update_prompt performs
synchronous stat() and file read, blocking the event loop just like
_finalize_update did. Wrap _prepare_update_prompt in asyncio.to_thread
for the same reason.

The async path now has no blocking file I/O on the event loop:
  to_thread(_prepare_update_prompt) → await model.ainvoke() → to_thread(_finalize_update)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-14 16:41:54 +08:00
DanielWalnut 4ba3167f48 feat: flush memory before summarization (#2176)
* feat: flush memory before summarization

* fix: keep agent-scoped memory on summarization flush

* fix: harden summarization hook plumbing

* fix: address summarization review feedback

* style: format memory middleware
2026-04-14 15:01:06 +08:00
Octopus e4f896e90d fix(todo-middleware): prevent premature agent exit with incomplete todos (#2135)
* fix(todo-middleware): prevent premature agent exit with incomplete todos

When plan mode is active (is_plan_mode=True), the agent occasionally
exits the loop and outputs a final response while todo items are still
incomplete. This happens because the routing edge only checks for
tool_calls, not todo completion state.

Fixes #2112

Add an after_model override to TodoMiddleware with
@hook_config(can_jump_to=["model"]). When the model produces a
response with no tool calls but there are still incomplete todos, the
middleware injects a todo_completion_reminder HumanMessage and returns
jump_to=model to force another model turn. A cap of 2 reminders
prevents infinite loops when the agent cannot make further progress.

Also adds _completion_reminder_count() helper and 14 new unit tests
covering all edge cases of the new after_model / aafter_model logic.

* Remove unnecessary blank line in test file

* Fix runtime argument annotation in before_model

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: octo-patch <octo-patch@github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-14 11:11:26 +08:00
luo jiyin 07fc25d285 feat: switch memory updater to async LLM calls (#2138)
* docs: mark memory updater async migration as completed

- Update TODO.md to mark the replacement of sync model.invoke()
  with async model.ainvoke() in title_middleware and memory updater
  as completed using [x] format

Addresses #2131

* feat: switch memory updater to async LLM calls

- Add async aupdate_memory() method using await model.ainvoke()
- Convert sync update_memory() to use async wrapper
- Add _run_async_update_sync() for nested loop context handling
- Maintain backward compatibility with existing sync API
- Add ThreadPoolExecutor for async execution from sync contexts

Addresses #2131

* test: add tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns

Addresses #2131

* fix: apply ruff formatting to memory updater

- Format multi-line expressions to single line
- Ensure code style consistency with project standards
- Fix lint issues caught by GitHub Actions

* test: add comprehensive tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns
- Ensure backward compatibility with sync API

* fix: satisfy ruff formatting in memory updater test

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-14 11:10:42 +08:00
Nan Gao 55bc09ac33 fix(backend): fix uploads for mounted sandbox providers (#2199)
* fix uploads for mounted sandbox providers

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-14 10:44:31 +08:00
dependabot[bot] c43a45ea40 chore(deps): bump pillow from 12.1.1 to 12.2.0 in /backend (#2206)
Bumps [pillow](https://github.com/python-pillow/Pillow) from 12.1.1 to 12.2.0.
- [Release notes](https://github.com/python-pillow/Pillow/releases)
- [Changelog](https://github.com/python-pillow/Pillow/blob/main/CHANGES.rst)
- [Commits](https://github.com/python-pillow/Pillow/compare/12.1.1...12.2.0)

---
updated-dependencies:
- dependency-name: pillow
  dependency-version: 12.2.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-14 10:35:59 +08:00
Admire 9cf7153b1d fix(check): windows pnpm version detection in check script (#2189)
* fix: resolve Windows pnpm detection in check script

* style: format check script regression test

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: resolve corepack fallback on windows

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-14 10:29:44 +08:00
Octopus c91785dd68 fix(title): strip <think> tags from title model responses and assistant context (#1927)
* fix(title): strip <think> tags from title model responses and assistant context

Reasoning models (e.g. minimax M2.7, DeepSeek-R1) emit <think>...</think>
blocks before their actual output. When such a model is used as the title
model (or as the main agent), the raw thinking content leaked into the thread
title stored in state, so the chat list showed the internal monologue instead
of a meaningful title.

Fixes #1884

- Add `_strip_think_tags()` helper using a regex to remove all <think>...</think> blocks
- Apply it in `_parse_title()` so the title model response is always clean
- Apply it to the assistant message in `_build_title_prompt()` so thinking
  content from the first AI turn is not fed back to the title model
- Add four new unit tests covering: stripping in parse, think-only response,
  assistant prompt stripping, and end-to-end async flow with think tags

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-14 09:51:39 +08:00
sqsge 053e18e1a6 fix(skills): avoid blocking custom skill deletion on readonly history writes (#2197) 2026-04-14 09:00:29 +08:00
Hinotobi a7e7c6d667 fix: disable custom-agent management API by default (#2161)
* fix: disable custom-agent management API by default

* style: format agents API hardening files

* fix: address review feedback for agents API hardening

* fix: add missing disabled API coverage
2026-04-14 00:03:38 +08:00
Nan Gao f4c17c66ce fix(middleware): fix present_files thread id fallback (#2181)
* fix present files thread id fallback

* fix: resolve present_files thread id from runtime config
2026-04-13 22:59:13 +08:00
lesliewangwyc-dev 1df389b9d0 fix: wrap blocking readability call with asyncio.to_thread in web_fetch (#2157)
* fix: wrap blocking readability call with asyncio.to_thread in web_fetch

The readability extractor internally spawns a Node.js subprocess via
readabilipy, which blocks the async event loop and causes a
BlockingError when web_fetch is invoked inside LangGraph's async
runtime.

Wrap the synchronous extract_article call with asyncio.to_thread to
offload it to a thread pool, unblocking the event loop.

Note: community/infoquest/tools.py has the same latent issue and
should be addressed in a follow-up PR.

Closes #2152

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: verify web_fetch offloads extraction via asyncio.to_thread

Add a regression test that monkeypatches asyncio.to_thread to confirm
readability extraction is offloaded to a worker thread, preventing
future refactors from reintroducing the blocking call.

Addresses Copilot review feedback on #2157.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-13 21:15:24 +08:00
5db71cb68c fix(middleware): repair dangling tool-call history after loop interru… (#2035)
* fix(middleware): repair dangling tool-call history after loop interruption (#2029)

* docs(backend): fix middleware chain ordering

---------

Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
2026-04-12 19:11:22 +08:00
yangzheli 4efc8d404f feat(frontend): set up Vitest frontend testing infrastructure with CI workflow (#2147)
* feat: set up Vitest frontend testing infrastructure with CI workflow

Migrate existing 4 frontend test files from Node.js native test runner
(node:test + node:assert/strict) to Vitest, reorganize test directory
structure under tests/unit/ mirroring src/ layout, and add a dedicated
CI workflow for frontend unit tests.

- Add vitest as devDependency, remove tsx
- Create vitest.config.ts with @/ path alias
- Migrate tests to Vitest API (test/expect/vi)
- Rename .mjs test files to .ts
- Move tests from src/ to tests/unit/ (mirrors src/ layout)
- Add frontend/Makefile `test` target
- Add .github/workflows/frontend-unit-tests.yml (parallel to backend)
- Update CONTRIBUTING.md, README.md, AGENTS.md, CLAUDE.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: fix the lint error

* style: fix the lint error

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-12 18:00:43 +08:00
Jin 4d4ddb3d3f feat(llm): introduce lightweight circuit breaker to prevent rate-limit bans and resource exhaustion (#2095) 2026-04-12 17:48:40 +08:00
luo jiyin 979a461af5 docs: move completed async migration to Completed Features (#2146)
- Move time.sleep() -> asyncio.sleep() from Planned to Completed Features
- Clean up duplicate entries in TODO.md

Ensures completed async optimizations are properly tracked.
2026-04-12 16:48:48 +08:00
Javen Fang ac04f2704f feat(subagents): allow model override per subagent in config.yaml (#2064)
* feat(subagents): allow model override per subagent in config.yaml

Wire the existing SubagentConfig.model field to config.yaml so users
can assign different models to different subagent types.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* test(subagents): cover model override in SubagentsAppConfig + registry

Addresses review feedback on #2064:

- registry.py: update stale inline comment — the block now applies
  timeout, max_turns AND model overrides, not just timeout.
- test_subagent_timeout_config.py: add coverage for model override
  resolution across SubagentOverrideConfig, SubagentsAppConfig
  (get_model_for + load), and registry.get_subagent_config:
  - per-agent model override is applied to registry-returned config
  - omitted `model` keeps the builtin value
  - explicit `model: null` in config.yaml is equivalent to omission
  - model override on one agent does not affect other agents
  - model override preserves all other fields (name, description,
    timeout_seconds, max_turns)
  - model override does not mutate BUILTIN_SUBAGENTS

Copilot's suggestion (3) "setting model to 'inherit' forces inheritance"
is skipped intentionally: there is no 'inherit' sentinel in the current
implementation — model is `str | None`, and None already means
"inherit from parent". Adding a sentinel would be a new feature, not
test coverage for this PR.

Tests run locally: 51 passed (37 existing + 14 new / expanded).

* test(subagents): reject empty-string model at config load time

Addresses WillemJiang's review comment on #2064 (empty-string edge case):

- subagents_config.py: add `min_length=1` to the `model` field on
  SubagentOverrideConfig. `model: ""` in config.yaml would otherwise
  bypass the `is not None` check and reach create_chat_model(name="")
  as a confusing runtime error. This is symmetric with the existing
  `ge=1` guards on timeout_seconds / max_turns, so the validation style
  stays consistent across all three override fields.
- test_subagent_timeout_config.py: add test_rejects_empty_model
  mirroring the existing test_rejects_zero / test_rejects_negative
  cases; update the docstring on test_model_accepts_any_string (now
  test_model_accepts_any_non_empty_string) to reflect the new guard.

Not addressing the first comment (validating `model` against the
`models:` section at load time) in this PR. `SubagentsAppConfig` is
scoped to the `subagents:` block and cannot see the sibling `models:`
section, so proper cross-section validation needs a second pass or a
structural change that is out of scope here — and the current behavior
is consistent with how timeout_seconds / max_turns work today. Happy to
track this as a follow-up issue covering cross-section validation
uniformly for all three fields.

Tests run locally: 52 passed in this file; 1847 passed, 18 skipped
across the full backend suite. Ruff check + format clean.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-12 16:40:21 +08:00
Matt Van Horn c4d273a68a feat(channels): add Discord channel integration (#1806)
* feat(channels): add Discord channel integration

Add a Discord bot channel following the existing Telegram/Slack pattern.
The bot listens for messages, creates conversation threads, and relays
responses back to Discord with 2000-char message splitting.

- DiscordChannel extends Channel base class
- Lazy imports discord.py with install hint
- Thread-based conversations (each Discord thread maps to a DeerFlow thread)
- Allowed guilds filter for access control
- File attachment support via discord.File
- Registered in service.py and manager.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): address Copilot review suggestions for Discord integration

- Disable @everyone/@here mentions via AllowedMentions.none()
- Add 10s timeout to client close to prevent shutdown hangs
- Log publish_inbound errors via future callback instead of silently dropping
- Open file handle on caller thread to avoid cross-thread ownership issues
- Notify user in channel when thread creation fails

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(discord): resolve lint errors in Discord channel

- Replace asyncio.TimeoutError with builtin TimeoutError (UP041)
- Remove extraneous f-string prefix (F541)
- Apply ruff format

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(tests): remove fake langgraph_sdk shim from test_discord_channel

The module-level sys.modules.setdefault shim installed a fake
langgraph_sdk.errors.ConflictError during pytest collection. Because
pytest imports all test modules before running them, test_channels.py
then imported the fake ConflictError instead of the real one.

In test_handle_feishu_stream_conflict_sends_busy_message, the test
constructs ConflictError(message, response=..., body=...). The fake
only subclasses Exception (which takes no kwargs), so the construction
raised TypeError. The manager's _is_thread_busy_error check then saw a
TypeError instead of a ConflictError and fell through to the generic
'An error occurred' message.

langgraph_sdk is a real dependency, so the shim is unnecessary.
Removing it makes both test files import the same real ConflictError
and the full suite pass (1773 passed, 15 skipped).

---------

Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-11 17:48:04 +08:00
Jason dc50a7fdfb fix(sandbox): resolve paths in read_file/write_file content for LocalSandbox (#1935)
* fix(sandbox): resolve paths in read_file/write_file content for LocalSandbox

In LocalSandbox mode, read_file and write_file now transform
container paths in file content, matching the path handling
behavior of bash tool.

- write_file: resolves virtual paths in content to system paths
  before writing, so scripts with /mnt/user-data paths work
  when executed
- read_file: reverse-resolves system paths back to virtual
  paths in returned content for consistency

This fixes scenarios where agents write Python scripts with
virtual paths, then execute them via bash tool expecting the
paths to work.

Fixes #1778

* fix(sandbox): address Copilot review — dedicated content resolver + forward-slash safety + tests

- Extract _resolve_paths_in_content() separate from _resolve_paths_in_command()
  to decouple file-content path resolution from shell-command parsing
- Normalize resolved paths to forward slashes to avoid Windows backslash
  escape issues in source files (e.g. \U in Python string literals)
- Add 4 focused tests: write resolves content, forward-slash guarantee,
  read reverse-resolves content, and write→read roundtrip

* style: fix ruff lint — remove extraneous f-string prefix

* fix(sandbox): only reverse-resolve paths in agent-written files

read_file previously applied _reverse_resolve_paths_in_output to ALL
file content, which could silently rewrite paths in user uploads and
external tool output (Willem Jiang review on #1935).

Now tracks files written through write_file in _agent_written_paths.
Only those files get reverse-resolved on read. Non-agent files are
returned as-is.

---------

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-11 17:41:36 +08:00
ZHANG Ning 5b633449f8 fix(middleware): add per-tool-type frequency detection to LoopDetectionMiddleware (#1988)
* fix(middleware): add per-tool-type frequency detection to LoopDetectionMiddleware

The existing hash-based loop detection only catches identical tool call
sets. When the agent calls the same tool type (e.g. read_file) on many
different files, each call produces a unique hash and bypasses detection.
This causes the agent to exhaust recursion_limit, consuming 150K-225K
tokens per failed run.

Add a second detection layer that tracks cumulative call counts per tool
type per thread. Warns at 30 calls (configurable) and forces stop at 50.
The hard stop message now uses the actual returned message instead of a
hardcoded constant, so both hash-based and frequency-based stops produce
accurate diagnostics.

Also fix _apply() to use the warning message returned by
_track_and_check() for hard stops, instead of always using _HARD_STOP_MSG.

Closes #1987

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix(lint): remove unused imports and fix line length

- Remove unused _TOOL_FREQ_HARD_STOP_MSG and _TOOL_FREQ_WARNING_MSG
  imports from test file (F401)
- Break long _TOOL_FREQ_WARNING_MSG string to fit within 240 char limit (E501)

* style: apply ruff format

* test: add LRU eviction and per-thread reset coverage for frequency state

Address review feedback from @WillemJiang:
- Verify _tool_freq and _tool_freq_warned are cleaned on LRU eviction
- Add test for reset(thread_id=...) clearing only the target thread's
  frequency state while leaving others intact

* fix(makefile): route Windows shell-script targets through Git Bash (#2060)

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Asish Kumar <87874775+officialasishkumar@users.noreply.github.com>
2026-04-11 17:33:27 +08:00
yorick 02569136df fix(sandbox): improve sandbox security and preserve multimodal content (#2114)
* fix: improve sandbox security and preserve multimodal content

* Add unit test modifications for test_injects_uploaded_files_tag_into_list_content

* format updated_content

* Add regression tests for multimodal upload content and host bash default safety
2026-04-11 16:52:10 +08:00
dependabot[bot] 024ac0e464 chore(deps): bump langsmith from 0.5.2 to 0.5.18 in /frontend (#2110)
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.5.2 to 0.5.18.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/commits)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.5.18
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-11 14:51:21 +08:00
dependabot[bot] 19030928e0 chore(deps): bump langchain-core from 1.2.17 to 1.2.28 in /backend (#2109)
Bumps [langchain-core](https://github.com/langchain-ai/langchain) from 1.2.17 to 1.2.28.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/langchain-core==1.2.17...langchain-core==1.2.28)

---
updated-dependencies:
- dependency-name: langchain-core
  dependency-version: 1.2.28
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-11 14:49:54 +08:00
581 changed files with 29153 additions and 25546 deletions
+1
View File
@@ -24,6 +24,7 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# SLACK_BOT_TOKEN=your-slack-bot-token
# SLACK_APP_TOKEN=your-slack-app-token
# TELEGRAM_BOT_TOKEN=your-telegram-bot-token
# DISCORD_BOT_TOKEN=your-discord-bot-token
# Enable LangSmith to monitor and debug your LLM calls, agent runs, and tool executions.
# LANGSMITH_TRACING=true
+63
View File
@@ -0,0 +1,63 @@
name: E2E Tests
on:
push:
branches: [ 'main' ]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
concurrency:
group: e2e-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
e2e-tests:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }}
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Install Playwright Chromium
working-directory: frontend
run: npx playwright install chromium --with-deps
- name: Run E2E tests
working-directory: frontend
run: pnpm exec playwright test
env:
SKIP_ENV_VALIDATION: '1'
- name: Upload Playwright report
uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report
path: frontend/playwright-report/
retention-days: 7
+43
View File
@@ -0,0 +1,43 @@
name: Frontend Unit Tests
on:
push:
branches: [ 'main' ]
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
concurrency:
group: frontend-unit-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
frontend-unit-tests:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Run unit tests of frontend
working-directory: frontend
run: make test
+3
View File
@@ -40,6 +40,7 @@ coverage/
skills/custom/*
logs/
log/
debug.log
# Local git hooks (keep only on this machine, do not push)
.githooks/
@@ -55,5 +56,7 @@ web/
backend/Dockerfile.langgraph
config.yaml.bak
.playwright-mcp
/frontend/test-results/
/frontend/playwright-report/
.gstack/
.worktrees
+33
View File
@@ -0,0 +1,33 @@
repos:
# Backend: ruff lint + format via uv (uses the same ruff version as backend deps)
- repo: local
hooks:
- id: ruff
name: ruff lint
entry: bash -c 'cd backend && uv run ruff check --fix "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
- id: ruff-format
name: ruff format
entry: bash -c 'cd backend && uv run ruff format "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
# Frontend: eslint + prettier (must run from frontend/ for node_modules resolution)
- repo: local
hooks:
- id: frontend-eslint
name: eslint (frontend)
entry: bash -c 'cd frontend && npx eslint --fix "${@/#frontend\//}"' --
language: system
types_or: [javascript, tsx, ts]
files: ^frontend/
- id: frontend-prettier
name: prettier (frontend)
entry: bash -c 'cd frontend && npx prettier --write "${@/#frontend\//}"' --
language: system
files: ^frontend/
types_or: [javascript, tsx, ts, json, css]
+12 -7
View File
@@ -166,7 +166,7 @@ Required tools:
1. **Configure the application** (same as Docker setup above)
2. **Install dependencies**:
2. **Install dependencies** (this also sets up pre-commit hooks):
```bash
make install
```
@@ -298,19 +298,24 @@ Nginx (port 2026) ← Unified entry point
```bash
# Backend tests
cd backend
uv run pytest
make test
# Frontend checks
# Frontend unit tests
cd frontend
pnpm check
make test
# Frontend E2E tests (requires Chromium; builds and auto-starts the Next.js production server)
cd frontend
make test-e2e
```
### PR Regression Checks
Every pull request runs the backend regression workflow at [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml), including:
Every pull request triggers the following CI workflows:
- `tests/test_provisioner_kubeconfig.py`
- `tests/test_docker_sandbox_mode_detection.py`
- **Backend unit tests** — [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml)
- **Frontend unit tests** — [.github/workflows/frontend-unit-tests.yml](.github/workflows/frontend-unit-tests.yml)
- **Frontend E2E tests** — [.github/workflows/e2e-tests.yml](.github/workflows/e2e-tests.yml) (triggered only when `frontend/` files change)
## Code Style
+4 -2
View File
@@ -23,7 +23,7 @@ help:
@echo " make config - Generate local config files (aborts if config already exists)"
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
@echo " make check - Check if all required tools are installed"
@echo " make install - Install all dependencies (frontend + backend)"
@echo " make install - Install all dependencies (frontend + backend + pre-commit hooks)"
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
@echo " make dev - Start all services in development mode (with hot-reloading)"
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
@@ -73,6 +73,8 @@ install:
@cd backend && uv sync
@echo "Installing frontend dependencies..."
@cd frontend && pnpm install
@echo "Installing pre-commit hooks..."
@$(BACKEND_UV_RUN) --with pre-commit pre-commit install
@echo "✓ All dependencies installed"
@echo ""
@echo "=========================================="
@@ -99,7 +101,7 @@ setup-sandbox:
echo ""; \
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
echo "Detected Apple Container on macOS, pulling image..."; \
container pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
fi; \
if command -v docker >/dev/null 2>&1; then \
echo "Pulling image using Docker..."; \
+3 -1
View File
@@ -264,7 +264,7 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
2. **Install dependencies**:
```bash
make install # Install backend + frontend dependencies
make install # Install backend + frontend dependencies + pre-commit hooks
```
3. **(Optional) Pre-pull sandbox image**:
@@ -658,6 +658,8 @@ This is the difference between a chatbot with tool access and an agent with an a
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
### Long-Term Memory
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
+27 -12
View File
@@ -130,7 +130,7 @@ from app.gateway.app import app
from app.channels.service import start_channel_service
# App → Harness (allowed)
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
# Harness → App (FORBIDDEN — enforced by test_harness_boundary.py)
# from app.gateway.routers.uploads import ... # ← will fail CI
@@ -156,20 +156,26 @@ from deerflow.config import get_app_config
### Middleware Chain
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled)
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
### Configuration System
@@ -179,7 +185,16 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`.
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
**Config Lifecycle**: All config models are `frozen=True` (immutable after construction). `AppConfig.from_file()` is a pure function — no side effects, no process-global state. The resolved `AppConfig` is passed as an explicit parameter down every consumer lane:
- **Gateway**: `app.state.config` populated in lifespan; routers receive it via `Depends(get_config)` from `app/gateway/deps.py`.
- **Client**: `DeerFlowClient._app_config` captured in the constructor; every method reads `self._app_config`.
- **Agent run**: wrapped in `DeerFlowContext(app_config=…)` and injected via LangGraph `Runtime[DeerFlowContext].context`. Middleware and tools read `runtime.context.app_config` directly or via `resolve_context(runtime)`.
- **LangGraph Server bootstrap**: `make_lead_agent` (registered in `langgraph.json`) calls `AppConfig.from_file()` itself — the only place in production that loads from disk at agent-build time.
To update config at runtime (Gateway API mutations for MCP/Skills), write the new file and call `AppConfig.from_file()` to build a fresh snapshot, then swap `app.state.config`. No mtime detection, no auto-reload, no ambient ContextVar lookup (`AppConfig.current()` has been removed).
**DeerFlowContext**: Per-invocation typed context for the agent execution path, injected via LangGraph `Runtime[DeerFlowContext]`. Holds `app_config: AppConfig`, `thread_id: str`, `agent_name: str | None`. Gateway runtime and `DeerFlowClient` construct full `DeerFlowContext` at invoke time; the LangGraph Server boundary builds one inside `make_lead_agent`. Middleware and tools access context through `resolve_context(runtime)` which returns the typed `DeerFlowContext` — legacy dict/None shapes are rejected. Mutable runtime state (`sandbox_id`) flows through `ThreadState.sandbox`, not context.
Configuration priority:
1. Explicit `config_path` argument
+1 -1
View File
@@ -8,7 +8,7 @@ gateway:
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
test:
PYTHONPATH=. uv run pytest tests/unittest -v
PYTHONPATH=. uv run pytest tests/ -v
lint:
uvx ruff check .
+273
View File
@@ -0,0 +1,273 @@
"""Discord channel integration using discord.py."""
from __future__ import annotations
import asyncio
import logging
import threading
from typing import Any
from app.channels.base import Channel
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
_DISCORD_MAX_MESSAGE_LEN = 2000
class DiscordChannel(Channel):
"""Discord bot channel.
Configuration keys (in ``config.yaml`` under ``channels.discord``):
- ``bot_token``: Discord Bot token.
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="discord", bus=bus, config=config)
self._bot_token = str(config.get("bot_token", "")).strip()
self._allowed_guilds: set[int] = set()
for guild_id in config.get("allowed_guilds", []):
try:
self._allowed_guilds.add(int(guild_id))
except (TypeError, ValueError):
continue
self._client = None
self._thread: threading.Thread | None = None
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
async def start(self) -> None:
if self._running:
return
try:
import discord
except ImportError:
logger.error("discord.py is not installed. Install it with: uv add discord.py")
return
if not self._bot_token:
logger.error("Discord channel requires bot_token")
return
intents = discord.Intents.default()
intents.messages = True
intents.guilds = True
intents.message_content = True
client = discord.Client(
intents=intents,
allowed_mentions=discord.AllowedMentions.none(),
)
self._client = client
self._discord_module = discord
self._main_loop = asyncio.get_event_loop()
@client.event
async def on_message(message) -> None:
await self._on_message(message)
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
self._thread = threading.Thread(target=self._run_client, daemon=True)
self._thread.start()
logger.info("Discord channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._client and self._discord_loop and self._discord_loop.is_running():
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
try:
await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
except TimeoutError:
logger.warning("[Discord] client close timed out after 10s")
except Exception:
logger.exception("[Discord] error while closing client")
if self._thread:
self._thread.join(timeout=10)
self._thread = None
self._client = None
self._discord_loop = None
self._discord_module = None
logger.info("Discord channel stopped")
async def send(self, msg: OutboundMessage) -> None:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return
text = msg.text or ""
for chunk in self._split_text(text):
send_future = asyncio.run_coroutine_threadsafe(target.send(chunk), self._discord_loop)
await asyncio.wrap_future(send_future)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return False
if self._discord_module is None:
return False
try:
fp = open(str(attachment.actual_path), "rb") # noqa: SIM115
file = self._discord_module.File(fp, filename=attachment.filename)
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
await asyncio.wrap_future(send_future)
logger.info("[Discord] file uploaded: %s", attachment.filename)
return True
except Exception:
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
return False
async def _on_message(self, message) -> None:
if not self._running or not self._client:
return
if message.author.bot:
return
if self._client.user and message.author.id == self._client.user.id:
return
guild = message.guild
if self._allowed_guilds:
if guild is None or guild.id not in self._allowed_guilds:
return
text = (message.content or "").strip()
if not text:
return
if self._discord_module is None:
return
if isinstance(message.channel, self._discord_module.Thread):
chat_id = str(message.channel.parent_id or message.channel.id)
thread_id = str(message.channel.id)
else:
thread = await self._create_thread(message)
if thread is None:
return
chat_id = str(message.channel.id)
thread_id = str(thread.id)
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
text=text,
msg_type=msg_type,
thread_ts=thread_id,
metadata={
"guild_id": str(guild.id) if guild else None,
"channel_id": str(message.channel.id),
"message_id": str(message.id),
},
)
inbound.topic_id = thread_id
if self._main_loop and self._main_loop.is_running():
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
try:
self._discord_loop.run_until_complete(self._client.start(self._bot_token))
except Exception:
if self._running:
logger.exception("Discord client error")
finally:
try:
if self._client and not self._client.is_closed():
self._discord_loop.run_until_complete(self._client.close())
except Exception:
logger.exception("Error during Discord shutdown")
async def _create_thread(self, message):
try:
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
return await message.create_thread(name=thread_name)
except Exception:
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
try:
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
except Exception:
pass
return None
async def _resolve_target(self, msg: OutboundMessage):
if not self._client or not self._discord_loop:
return None
target_ids: list[str] = []
if msg.thread_ts:
target_ids.append(msg.thread_ts)
if msg.chat_id and msg.chat_id not in target_ids:
target_ids.append(msg.chat_id)
for raw_id in target_ids:
target = await self._get_channel_or_thread(raw_id)
if target is not None:
return target
return None
async def _get_channel_or_thread(self, raw_id: str):
if not self._client or not self._discord_loop:
return None
try:
target_id = int(raw_id)
except (TypeError, ValueError):
return None
get_future = asyncio.run_coroutine_threadsafe(self._fetch_channel(target_id), self._discord_loop)
try:
return await asyncio.wrap_future(get_future)
except Exception:
logger.exception("[Discord] failed to resolve target id=%s", raw_id)
return None
async def _fetch_channel(self, target_id: int):
if not self._client:
return None
channel = self._client.get_channel(target_id)
if channel is not None:
return channel
try:
return await self._client.fetch_channel(target_id)
except Exception:
return None
@staticmethod
def _split_text(text: str) -> list[str]:
if not text:
return [""]
chunks: list[str] = []
remaining = text
while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
if split_at <= 0:
split_at = _DISCORD_MAX_MESSAGE_LEN
chunks.append(remaining[:split_at])
remaining = remaining[split_at:].lstrip("\n")
if remaining:
chunks.append(remaining)
return chunks
+43 -63
View File
@@ -9,12 +9,11 @@ import re
import threading
from typing import Any, Literal
from app.plugins.auth.security.actor_context import bind_user_actor_context
from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.actor_context import get_effective_user_id
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
logger = logging.getLogger(__name__)
@@ -299,35 +298,15 @@ class FeishuChannel(Channel):
text = msg.text
for file in files:
if file.get("image_key"):
virtual_path = await self._receive_single_file(
msg.thread_ts,
file["image_key"],
"image",
thread_id,
user_id=msg.user_id,
)
virtual_path = await self._receive_single_file(msg.thread_ts, file["image_key"], "image", thread_id)
text = text.replace("[image]", virtual_path, 1)
elif file.get("file_key"):
virtual_path = await self._receive_single_file(
msg.thread_ts,
file["file_key"],
"file",
thread_id,
user_id=msg.user_id,
)
virtual_path = await self._receive_single_file(msg.thread_ts, file["file_key"], "file", thread_id)
text = text.replace("[file]", virtual_path, 1)
msg.text = text
return msg
async def _receive_single_file(
self,
message_id: str,
file_key: str,
type: Literal["image", "file"],
thread_id: str,
*,
user_id: str | None = None,
) -> str:
async def _receive_single_file(self, message_id: str, file_key: str, type: Literal["image", "file"], thread_id: str) -> str:
request = self._GetMessageResourceRequest.builder().message_id(message_id).file_key(file_key).type(type).build()
def inner():
@@ -366,51 +345,52 @@ class FeishuChannel(Channel):
return f"Failed to obtain the [{type}]"
paths = get_paths()
with bind_user_actor_context(user_id):
effective_user_id = get_effective_user_id()
paths.ensure_thread_dirs(thread_id, user_id=effective_user_id)
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=effective_user_id).resolve()
user_id = get_effective_user_id()
paths.ensure_thread_dirs(thread_id, user_id=user_id)
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
ext = "png" if type == "image" else "bin"
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
ext = "png" if type == "image" else "bin"
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
# Sanitize filename: preserve extension, replace path chars in name part
if "." in raw_filename:
name_part, ext = raw_filename.rsplit(".", 1)
name_part = re.sub(r"[./\\]", "_", name_part)
filename = f"{name_part}.{ext}"
else:
filename = re.sub(r"[./\\]", "_", raw_filename)
resolved_target = uploads_dir / filename
# Sanitize filename: preserve extension, replace path chars in name part
if "." in raw_filename:
name_part, ext = raw_filename.rsplit(".", 1)
name_part = re.sub(r"[./\\]", "_", name_part)
filename = f"{name_part}.{ext}"
else:
filename = re.sub(r"[./\\]", "_", raw_filename)
resolved_target = uploads_dir / filename
def down_load():
# use thread_lock to avoid filename conflicts when writing
with self._thread_lock:
resolved_target.write_bytes(content)
def down_load():
# use thread_lock to avoid filename conflicts when writing
with self._thread_lock:
resolved_target.write_bytes(content)
try:
await asyncio.to_thread(down_load)
except Exception:
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
return f"Failed to obtain the [{type}]"
try:
await asyncio.to_thread(down_load)
except Exception:
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
return f"Failed to obtain the [{type}]"
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
try:
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
if sandbox_id != "local":
sandbox = sandbox_provider.get(sandbox_id)
if sandbox is None:
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
return f"Failed to obtain the [{type}]"
sandbox.update_file(virtual_path, content)
except Exception:
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
return f"Failed to obtain the [{type}]"
try:
from deerflow.config.app_config import AppConfig
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
return virtual_path
sandbox_provider = get_sandbox_provider(AppConfig.from_file())
sandbox_id = sandbox_provider.acquire(thread_id)
if sandbox_id != "local":
sandbox = sandbox_provider.get(sandbox_id)
if sandbox is None:
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
return f"Failed to obtain the [{type}]"
sandbox.update_file(virtual_path, content)
except Exception:
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
return f"Failed to obtain the [{type}]"
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
return virtual_path
# -- message formatting ------------------------------------------------
+38 -52
View File
@@ -14,11 +14,10 @@ from typing import Any
import httpx
from langgraph_sdk.errors import ConflictError
from app.plugins.auth.security.actor_context import bind_user_actor_context
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
from deerflow.runtime.actor_context import get_effective_user_id
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@@ -37,6 +36,7 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
CHANNEL_CAPABILITIES = {
"discord": {"supports_streaming": False},
"feishu": {"supports_streaming": True},
"slack": {"supports_streaming": False},
"telegram": {"supports_streaming": False},
@@ -329,7 +329,7 @@ def _format_artifact_text(artifacts: list[str]) -> str:
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
def _resolve_attachments(thread_id: str, artifacts: list[str], *, user_id: str | None = None) -> list[ResolvedAttachment]:
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
"""Resolve virtual artifact paths to host filesystem paths with metadata.
Only paths under ``/mnt/user-data/outputs/`` are accepted; any other
@@ -343,40 +343,39 @@ def _resolve_attachments(thread_id: str, artifacts: list[str], *, user_id: str |
attachments: list[ResolvedAttachment] = []
paths = get_paths()
with bind_user_actor_context(user_id):
effective_user_id = get_effective_user_id()
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=effective_user_id).resolve()
for virtual_path in artifacts:
# Security: only allow files from the agent outputs directory
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
continue
user_id = get_effective_user_id()
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
for virtual_path in artifacts:
# Security: only allow files from the agent outputs directory
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
continue
try:
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
# Verify the resolved path is actually under the outputs directory
# (guards against path-traversal even after prefix check)
try:
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=effective_user_id)
# Verify the resolved path is actually under the outputs directory
# (guards against path-traversal even after prefix check)
try:
actual.resolve().relative_to(outputs_dir)
except ValueError:
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
continue
if not actual.is_file():
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
continue
mime, _ = mimetypes.guess_type(str(actual))
mime = mime or "application/octet-stream"
attachments.append(
ResolvedAttachment(
virtual_path=virtual_path,
actual_path=actual,
filename=actual.name,
mime_type=mime,
size=actual.stat().st_size,
is_image=mime.startswith("image/"),
)
actual.resolve().relative_to(outputs_dir)
except ValueError:
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
continue
if not actual.is_file():
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
continue
mime, _ = mimetypes.guess_type(str(actual))
mime = mime or "application/octet-stream"
attachments.append(
ResolvedAttachment(
virtual_path=virtual_path,
actual_path=actual,
filename=actual.name,
mime_type=mime,
size=actual.stat().st_size,
is_image=mime.startswith("image/"),
)
except (ValueError, OSError) as exc:
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
)
except (ValueError, OSError) as exc:
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
return attachments
@@ -384,15 +383,13 @@ def _prepare_artifact_delivery(
thread_id: str,
response_text: str,
artifacts: list[str],
*,
user_id: str | None = None,
) -> tuple[str, list[ResolvedAttachment]]:
"""Resolve attachments and append filename fallbacks to the text response."""
attachments: list[ResolvedAttachment] = []
if not artifacts:
return response_text, attachments
attachments = _resolve_attachments(thread_id, artifacts, user_id=user_id)
attachments = _resolve_attachments(thread_id, artifacts)
resolved_virtuals = {attachment.virtual_path for attachment in attachments}
unresolved = [path for path in artifacts if path not in resolved_virtuals]
@@ -415,8 +412,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
with bind_user_actor_context(msg.user_id):
uploads_dir = ensure_uploads_dir(thread_id)
uploads_dir = ensure_uploads_dir(thread_id)
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
created: list[dict[str, Any]] = []
@@ -750,12 +746,7 @@ class ChannelManager:
len(artifacts),
)
response_text, attachments = _prepare_artifact_delivery(
thread_id,
response_text,
artifacts,
user_id=msg.user_id,
)
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
if not response_text:
if attachments:
@@ -846,12 +837,7 @@ class ChannelManager:
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
response_text = _extract_response_text(result)
artifacts = _extract_artifacts(result)
response_text, attachments = _prepare_artifact_delivery(
thread_id,
response_text,
artifacts,
user_id=msg.user_id,
)
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
if not response_text:
if attachments:
+30 -10
View File
@@ -4,17 +4,21 @@ from __future__ import annotations
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus
from app.channels.store import ChannelStore
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
# Channel name → import path for lazy loading
_CHANNEL_REGISTRY: dict[str, str] = {
"discord": "app.channels.discord:DiscordChannel",
"feishu": "app.channels.feishu:FeishuChannel",
"slack": "app.channels.slack:SlackChannel",
"telegram": "app.channels.telegram:TelegramChannel",
@@ -22,6 +26,16 @@ _CHANNEL_REGISTRY: dict[str, str] = {
"wecom": "app.channels.wecom:WeComChannel",
}
# Keys that indicate a user has configured credentials for a channel.
_CHANNEL_CREDENTIAL_KEYS: dict[str, list[str]] = {
"discord": ["bot_token"],
"feishu": ["app_id", "app_secret"],
"slack": ["bot_token", "app_token"],
"telegram": ["bot_token"],
"wecom": ["bot_id", "bot_secret"],
"wechat": ["bot_token"],
}
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
@@ -64,14 +78,11 @@ class ChannelService:
self._running = False
@classmethod
def from_app_config(cls) -> ChannelService:
"""Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config
config = get_app_config()
def from_app_config(cls, app_config: AppConfig) -> ChannelService:
"""Create a ChannelService from an explicit application config."""
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {}
extra = app_config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
@@ -87,7 +98,16 @@ class ChannelService:
if not isinstance(channel_config, dict):
continue
if not channel_config.get("enabled", False):
logger.info("Channel %s is disabled, skipping", name)
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
name,
name,
)
else:
logger.info("Channel %s is disabled, skipping", name)
continue
await self._start_channel(name, channel_config)
@@ -181,12 +201,12 @@ def get_channel_service() -> ChannelService | None:
return _channel_service
async def start_channel_service() -> ChannelService:
async def start_channel_service(app_config: AppConfig) -> ChannelService:
"""Create and start the global ChannelService from app config."""
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config()
_channel_service = ChannelService.from_app_config(app_config)
await _channel_service.start()
return _channel_service
+20 -2
View File
@@ -16,13 +16,31 @@ logger = logging.getLogger(__name__)
_slack_md_converter = SlackMarkdownConverter()
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
if allowed_users is None:
return set()
if isinstance(allowed_users, str):
values = [allowed_users]
elif isinstance(allowed_users, list | tuple | set):
values = allowed_users
else:
logger.warning(
"Slack allowed_users should be a list of Slack user IDs or a single Slack user ID string; treating %s as one string value",
type(allowed_users).__name__,
)
values = [allowed_users]
return {str(user_id) for user_id in values if str(user_id)}
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
Configuration keys (in ``config.yaml`` under ``channels.slack``):
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
- ``allowed_users``: (optional) List of allowed Slack user IDs. Empty = allow all.
- ``allowed_users``: (optional) List of allowed Slack user IDs, or a
single Slack user ID string as shorthand. Empty = allow all. Other
scalar values are treated as a single string with a warning.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
@@ -30,7 +48,7 @@ class SlackChannel(Channel):
self._socket_client = None
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
async def start(self) -> None:
if self._running:
+3 -22
View File
@@ -1,23 +1,4 @@
from __future__ import annotations
from .app import app, create_app
from .config import GatewayConfig, get_gateway_config
__all__ = ["GatewayConfig", "app", "get_gateway_config", "register_app"]
def __getattr__(name: str):
if name == "app":
from .app import app
return app
if name == "GatewayConfig":
from .config import GatewayConfig
return GatewayConfig
if name == "get_gateway_config":
from .config import get_gateway_config
return get_gateway_config
if name == "register_app":
from .registrar import register_app
return register_app
raise AttributeError(name)
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]
+375 -4
View File
@@ -1,8 +1,379 @@
from app.gateway.registrar import register_app
import asyncio
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
artifacts,
assistants_compat,
auth,
channels,
feedback,
mcp,
memory,
models,
runs,
skills,
suggestions,
thread_runs,
threads,
uploads,
)
from deerflow.config.app_config import AppConfig
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
# Upper bound (seconds) each lifespan shutdown hook is allowed to run.
# Bounds worker exit time so uvicorn's reload supervisor does not keep
# firing signals into a worker that is stuck waiting for shutdown cleanup.
_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5.0
def create_app():
return register_app()
async def _ensure_admin_user(app: FastAPI) -> None:
"""Startup hook: handle first boot and migrate orphan threads otherwise.
After admin creation, migrate orphan threads from the LangGraph
store (metadata.user_id unset) to the admin account. This is the
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
authentication have existing LangGraph thread data that needs an
owner assigned.
First boot (no admin exists):
- Does NOT create any user accounts automatically.
- The operator must visit ``/setup`` to create the first admin.
Subsequent boots (admin already exists):
- Runs the one-time "no-auth → with-auth" orphan thread migration for
existing LangGraph thread metadata that has no owner_id.
No SQL persistence migration is needed: the four user_id columns
(threads_meta, runs, run_events, feedback) only come into existence
alongside the auth module via create_all, so freshly created tables
never contain NULL-owner rows.
"""
from sqlalchemy import select
from app.gateway.deps import get_local_provider
from deerflow.persistence.engine import get_session_factory
from deerflow.persistence.user.model import UserRow
provider = get_local_provider()
admin_count = await provider.count_admin_users()
if admin_count == 0:
logger.info("=" * 60)
logger.info(" First boot detected — no admin account exists.")
logger.info(" Visit /setup to complete admin account creation.")
logger.info("=" * 60)
return
# Admin already exists — run orphan thread migration for any
# LangGraph thread metadata that pre-dates the auth module.
sf = get_session_factory()
if sf is None:
return
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
return # Should not happen (admin_count > 0 above), but be safe.
admin_id = str(row.id)
# LangGraph store orphan migration — non-fatal.
# This covers the "no-auth → with-auth" upgrade path for users
# whose existing LangGraph thread metadata has no user_id set.
store = getattr(app.state, "store", None)
if store is not None:
try:
migrated = await _migrate_orphaned_threads(store, admin_id)
if migrated:
logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
except Exception:
logger.exception("LangGraph thread migration failed (non-fatal)")
app = register_app()
async def _iter_store_items(store, namespace, *, page_size: int = 500):
"""Paginated async iterator over a LangGraph store namespace.
Replaces the old hardcoded ``limit=1000`` call with a cursor-style
loop so that environments with more than one page of orphans do
not silently lose data. Terminates when a page is empty OR when a
short page arrives (indicating the last page).
"""
offset = 0
while True:
batch = await store.asearch(namespace, limit=page_size, offset=offset)
if not batch:
return
for item in batch:
yield item
if len(batch) < page_size:
return
offset += page_size
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
"""Migrate LangGraph store threads with no user_id to the given admin.
Uses cursor pagination so all orphans are migrated regardless of
count. Returns the number of rows migrated.
"""
migrated = 0
async for item in _iter_store_items(store, ("threads",)):
metadata = item.value.get("metadata", {})
if not metadata.get("user_id"):
metadata["user_id"] = admin_user_id
item.value["metadata"] = metadata
await store.aput(("threads",), item.key, item.value)
migrated += 1
return migrated
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
try:
# ``app.state.config`` is the sole source of truth for
# ``Depends(get_config)``. Consumers that want AppConfig must receive
# it as an explicit parameter; there is no ambient singleton.
app.state.config = AppConfig.from_file()
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
logger.exception(error_msg)
raise RuntimeError(error_msg) from e
config = get_gateway_config()
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Ensure admin user exists (auto-create on first boot)
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
await _ensure_admin_user(app)
# Start IM channel service if any channels are configured
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service(app.state.config)
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")
yield
# Stop channel service on shutdown (bounded to prevent worker hang)
try:
from app.channels.service import stop_channel_service
await asyncio.wait_for(
stop_channel_service(),
timeout=_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
except TimeoutError:
logger.warning(
"Channel service shutdown exceeded %.1fs; proceeding with worker exit.",
_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
except Exception:
logger.exception("Failed to stop channel service")
logger.info("Shutting down API Gateway")
def create_app() -> FastAPI:
"""Create and configure the FastAPI application.
Returns:
Configured FastAPI application instance.
"""
app = FastAPI(
title="DeerFlow API Gateway",
description="""
## DeerFlow API Gateway
API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.
### Features
- **Models Management**: Query and retrieve available AI models
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
- **Memory Management**: Access and manage global memory data for personalized conversations
- **Skills Management**: Query and manage skills and their enabled status
- **Artifacts**: Access thread artifacts and generated files
- **Health Monitoring**: System health check endpoints
### Architecture
LangGraph requests are handled by nginx reverse proxy.
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
""",
version="0.1.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
openapi_tags=[
{
"name": "models",
"description": "Operations for querying available AI models and their configurations",
},
{
"name": "mcp",
"description": "Manage Model Context Protocol (MCP) server configurations",
},
{
"name": "memory",
"description": "Access and manage global memory data for personalized conversations",
},
{
"name": "skills",
"description": "Manage skills and their configurations",
},
{
"name": "artifacts",
"description": "Access and download thread artifacts and generated files",
},
{
"name": "uploads",
"description": "Upload and manage user files for threads",
},
{
"name": "threads",
"description": "Manage DeerFlow thread-local filesystem data",
},
{
"name": "agents",
"description": "Create and manage custom agents with per-agent config and prompts",
},
{
"name": "suggestions",
"description": "Generate follow-up question suggestions for conversations",
},
{
"name": "channels",
"description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
},
{
"name": "assistants-compat",
"description": "LangGraph Platform-compatible assistants API (stub)",
},
{
"name": "runs",
"description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
},
{
"name": "health",
"description": "Health check and system status endpoints",
},
],
)
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
app.add_middleware(AuthMiddleware)
# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
# In production, nginx handles CORS and no middleware is needed.
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
# Validate: wildcard origin with credentials is a security misconfiguration
for origin in cors_origins:
if origin == "*":
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
cors_origins = [o for o in cors_origins if o != "*"]
break
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
# Models API is mounted at /api/models
app.include_router(models.router)
# MCP API is mounted at /api/mcp
app.include_router(mcp.router)
# Memory API is mounted at /api/memory
app.include_router(memory.router)
# Skills API is mounted at /api/skills
app.include_router(skills.router)
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
app.include_router(artifacts.router)
# Uploads API is mounted at /api/threads/{thread_id}/uploads
app.include_router(uploads.router)
# Thread cleanup API is mounted at /api/threads/{thread_id}
app.include_router(threads.router)
# Agents API is mounted at /api/agents
app.include_router(agents.router)
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router)
# Channels API is mounted at /api/channels
app.include_router(channels.router)
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Auth API is mounted at /api/v1/auth
app.include_router(auth.router)
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
app.include_router(feedback.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)
# Stateless Runs API (stream/wait without a pre-existing thread)
app.include_router(runs.router)
@app.get("/health", tags=["health"])
async def health_check() -> dict:
"""Health check endpoint.
Returns:
Service health status information.
"""
return {"status": "healthy", "service": "deer-flow-gateway"}
return app
# Create app instance for uvicorn
app = create_app()
+42
View File
@@ -0,0 +1,42 @@
"""Authentication module for DeerFlow.
This module provides:
- JWT-based authentication
- Provider Factory pattern for extensible auth methods
- UserRepository interface for storage backends (SQLite)
"""
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.models import User, UserResponse
from app.gateway.auth.password import hash_password, verify_password
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
__all__ = [
# Config
"AuthConfig",
"get_auth_config",
"set_auth_config",
# Errors
"AuthErrorCode",
"AuthErrorResponse",
"TokenError",
# JWT
"TokenPayload",
"create_access_token",
"decode_token",
# Password
"hash_password",
"verify_password",
# Models
"User",
"UserResponse",
# Providers
"AuthProvider",
"LocalAuthProvider",
# Repository
"UserRepository",
]
+57
View File
@@ -0,0 +1,57 @@
"""Authentication configuration for DeerFlow."""
import logging
import os
import secrets
from dotenv import load_dotenv
from pydantic import BaseModel, Field
load_dotenv()
logger = logging.getLogger(__name__)
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup.
Note: the ``users`` table now lives in the shared persistence
database managed by ``deerflow.persistence.engine``. The old
``users_db_path`` config key has been removed — user storage is
configured through ``config.database`` like every other table.
"""
jwt_secret: str = Field(
...,
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
)
token_expiry_days: int = Field(default=7, ge=1, le=30)
oauth_github_client_id: str | None = Field(default=None)
oauth_github_client_secret: str | None = Field(default=None)
_auth_config: AuthConfig | None = None
def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
if _auth_config is None:
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)
os.environ["AUTH_JWT_SECRET"] = jwt_secret
logger.warning(
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
"Sessions will be invalidated on restart. "
"For production, add AUTH_JWT_SECRET to your .env file: "
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
_auth_config = AuthConfig(jwt_secret=jwt_secret)
return _auth_config
def set_auth_config(config: AuthConfig) -> None:
"""Set the global AuthConfig instance (for testing)."""
global _auth_config
_auth_config = config
@@ -0,0 +1,48 @@
"""Write initial admin credentials to a restricted file instead of logs.
Logging secrets to stdout/stderr is a well-known CodeQL finding
(py/clear-text-logging-sensitive-data) — in production those logs
get collected into ELK/Splunk/etc and become a secret sprawl
source. This helper writes the credential to a 0600 file that only
the process user can read, and returns the path so the caller can
log **the path** (not the password) for the operator to pick up.
"""
from __future__ import annotations
import os
from pathlib import Path
from deerflow.config.paths import get_paths
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
"""Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``.
The file is created **atomically** with mode 0600 via ``os.open``
so the password is never world-readable, even for the single syscall
window between ``write_text`` and ``chmod``.
``label`` distinguishes "initial" (fresh creation) from "reset"
(password reset) in the file header so an operator picking up the
file after a restart can tell which event produced it.
Returns the absolute :class:`Path` to the file.
"""
target = get_paths().base_dir / _CREDENTIAL_FILENAME
target.parent.mkdir(parents=True, exist_ok=True)
content = (
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
)
# Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the
# reset-password path can rewrite an existing file without a
# separate unlink-then-create dance.
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as fh:
fh.write(content)
return target.resolve()
@@ -1,4 +1,9 @@
"""Typed error definitions for auth plugin."""
"""Typed error definitions for auth module.
AuthErrorCode: exhaustive enum of all auth failure conditions.
TokenError: exhaustive enum of JWT decode failures.
AuthErrorResponse: structured error payload for HTTP responses.
"""
from enum import StrEnum
@@ -6,6 +11,8 @@ from pydantic import BaseModel
class AuthErrorCode(StrEnum):
"""Exhaustive list of auth error conditions."""
INVALID_CREDENTIALS = "invalid_credentials"
TOKEN_EXPIRED = "token_expired"
TOKEN_INVALID = "token_invalid"
@@ -17,17 +24,22 @@ class AuthErrorCode(StrEnum):
class TokenError(StrEnum):
"""Exhaustive list of JWT decode failure reasons."""
EXPIRED = "expired"
INVALID_SIGNATURE = "invalid_signature"
MALFORMED = "malformed"
class AuthErrorResponse(BaseModel):
"""Structured error response — replaces bare `detail` strings."""
code: AuthErrorCode
message: str
def token_error_to_code(err: TokenError) -> AuthErrorCode:
"""Map TokenError to AuthErrorCode — single source of truth."""
if err == TokenError.EXPIRED:
return AuthErrorCode.TOKEN_EXPIRED
return AuthErrorCode.TOKEN_INVALID
@@ -5,26 +5,44 @@ from datetime import UTC, datetime, timedelta
import jwt
from pydantic import BaseModel
from app.plugins.auth.domain.errors import TokenError
from app.plugins.auth.runtime.config_state import get_auth_config
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import TokenError
class TokenPayload(BaseModel):
sub: str
"""JWT token payload."""
sub: str # user_id
exp: datetime
iat: datetime | None = None
ver: int = 0
ver: int = 0 # token_version — must match User.token_version
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
"""Create a JWT access token.
Args:
user_id: The user's UUID as string
expires_delta: Optional custom expiry, defaults to 7 days
token_version: User's current token_version for invalidation
Returns:
Encoded JWT string
"""
config = get_auth_config()
expiry = expires_delta or timedelta(days=config.token_expiry_days)
now = datetime.now(UTC)
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
def decode_token(token: str) -> TokenPayload | TokenError:
"""Decode and validate a JWT token.
Returns:
TokenPayload if valid, or a specific TokenError variant.
"""
config = get_auth_config()
try:
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
@@ -0,0 +1,91 @@
"""Local email/password authentication provider."""
from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
class LocalAuthProvider(AuthProvider):
"""Email/password authentication provider using local database."""
def __init__(self, repository: UserRepository):
"""Initialize with a UserRepository.
Args:
repository: UserRepository implementation (SQLite)
"""
self._repo = repository
async def authenticate(self, credentials: dict) -> User | None:
"""Authenticate with email and password.
Args:
credentials: dict with 'email' and 'password' keys
Returns:
User if authentication succeeds, None otherwise
"""
email = credentials.get("email")
password = credentials.get("password")
if not email or not password:
return None
user = await self._repo.get_user_by_email(email)
if user is None:
return None
if user.password_hash is None:
# OAuth user without local password
return None
if not await verify_password_async(password, user.password_hash):
return None
return user
async def get_user(self, user_id: str) -> User | None:
"""Get user by ID."""
return await self._repo.get_user_by_id(user_id)
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
"""Create a new local user.
Args:
email: User email address
password: Plain text password (will be hashed)
system_role: Role to assign ("admin" or "user")
needs_setup: If True, user must complete setup on first login
Returns:
Created User instance
"""
password_hash = await hash_password_async(password) if password else None
user = User(
email=email,
password_hash=password_hash,
system_role=system_role,
needs_setup=needs_setup,
)
return await self._repo.create_user(user)
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID."""
return await self._repo.get_user_by_oauth(provider, oauth_id)
async def count_users(self) -> int:
"""Return total number of registered users."""
return await self._repo.count_users()
async def count_admin_users(self) -> int:
"""Return number of admin users."""
return await self._repo.count_admin_users()
async def update_user(self, user: User) -> User:
"""Update an existing user."""
return await self._repo.update_user(user)
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email."""
return await self._repo.get_user_by_email(email)
@@ -1,4 +1,4 @@
"""User Pydantic models for the auth plugin."""
"""User Pydantic models for authentication."""
from datetime import UTC, datetime
from typing import Literal
@@ -8,10 +8,13 @@ from pydantic import BaseModel, ConfigDict, EmailStr, Field
def _utc_now() -> datetime:
"""Return current UTC time (timezone-aware)."""
return datetime.now(UTC)
class User(BaseModel):
"""Internal user representation."""
model_config = ConfigDict(from_attributes=True)
id: UUID = Field(default_factory=uuid4, description="Primary key")
@@ -19,13 +22,19 @@ class User(BaseModel):
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
system_role: Literal["admin", "user"] = Field(default="user")
created_at: datetime = Field(default_factory=_utc_now)
# OAuth linkage (optional)
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
# Auth lifecycle
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
class UserResponse(BaseModel):
"""Response model for user info endpoint."""
id: str
email: str
system_role: Literal["admin", "user"]
@@ -6,16 +6,28 @@ import bcrypt
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
async def hash_password_async(password: str) -> str:
"""Hash a password using bcrypt (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password hashing.
"""
return await asyncio.to_thread(hash_password, password)
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password verification.
"""
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
+24
View File
@@ -0,0 +1,24 @@
"""Auth provider abstraction."""
from abc import ABC, abstractmethod
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def authenticate(self, credentials: dict) -> "User | None":
"""Authenticate user with given credentials.
Returns User if authentication succeeds, None otherwise.
"""
...
@abstractmethod
async def get_user(self, user_id: str) -> "User | None":
"""Retrieve user by ID."""
...
# Import User at runtime to avoid circular imports
from app.gateway.auth.models import User # noqa: E402
@@ -0,0 +1,102 @@
"""User repository interface for abstracting database operations."""
from abc import ABC, abstractmethod
from app.gateway.auth.models import User
class UserNotFoundError(LookupError):
"""Raised when a user repository operation targets a non-existent row.
Subclass of :class:`LookupError` so callers that already catch
``LookupError`` for "missing entity" can keep working unchanged,
while specific call sites can pin to this class to distinguish
"concurrent delete during update" from other lookups.
"""
class UserRepository(ABC):
"""Abstract interface for user data storage.
Implement this interface to support different storage backends
(SQLite)
"""
@abstractmethod
async def create_user(self, user: User) -> User:
"""Create a new user.
Args:
user: User object to create
Returns:
Created User with ID assigned
Raises:
ValueError: If email already exists
"""
...
@abstractmethod
async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID.
Args:
user_id: User UUID as string
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email.
Args:
email: User email address
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def update_user(self, user: User) -> User:
"""Update an existing user.
Args:
user: User object with updated fields
Returns:
Updated User
Raises:
UserNotFoundError: If no row exists for ``user.id``. This is
a hard failure (not a no-op) so callers cannot mistake a
concurrent-delete race for a successful update.
"""
...
@abstractmethod
async def count_users(self) -> int:
"""Return total number of registered users."""
...
@abstractmethod
async def count_admin_users(self) -> int:
"""Return number of users with system_role == 'admin'."""
...
@abstractmethod
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID.
Args:
provider: OAuth provider name (e.g. 'github', 'google')
oauth_id: User ID from the OAuth provider
Returns:
User if found, None otherwise
"""
...
@@ -0,0 +1,127 @@
"""SQLAlchemy-backed UserRepository implementation.
Uses the shared async session factory from
``deerflow.persistence.engine`` — the ``users`` table lives in the
same database as ``threads_meta``, ``runs``, ``run_events``, and
``feedback``.
Constructor takes the session factory directly (same pattern as the
other four repositories in ``deerflow.persistence.*``). Callers
construct this after ``init_engine_from_config()`` has run.
"""
from __future__ import annotations
from datetime import UTC
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
from deerflow.persistence.user.model import UserRow
class SQLiteUserRepository(UserRepository):
"""Async user repository backed by the shared SQLAlchemy engine."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
# ── Converters ────────────────────────────────────────────────────
@staticmethod
def _row_to_user(row: UserRow) -> User:
return User(
id=UUID(row.id),
email=row.email,
password_hash=row.password_hash,
system_role=row.system_role, # type: ignore[arg-type]
# SQLite loses tzinfo on read; reattach UTC so downstream
# code can compare timestamps reliably.
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
oauth_provider=row.oauth_provider,
oauth_id=row.oauth_id,
needs_setup=row.needs_setup,
token_version=row.token_version,
)
@staticmethod
def _user_to_row(user: User) -> UserRow:
return UserRow(
id=str(user.id),
email=user.email,
password_hash=user.password_hash,
system_role=user.system_role,
created_at=user.created_at,
oauth_provider=user.oauth_provider,
oauth_id=user.oauth_id,
needs_setup=user.needs_setup,
token_version=user.token_version,
)
# ── CRUD ──────────────────────────────────────────────────────────
async def create_user(self, user: User) -> User:
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
row = self._user_to_row(user)
async with self._sf() as session:
session.add(row)
try:
await session.commit()
except IntegrityError as exc:
await session.rollback()
raise ValueError(f"Email already registered: {user.email}") from exc
return user
async def get_user_by_id(self, user_id: str) -> User | None:
async with self._sf() as session:
row = await session.get(UserRow, user_id)
return self._row_to_user(row) if row is not None else None
async def get_user_by_email(self, email: str) -> User | None:
stmt = select(UserRow).where(UserRow.email == email)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
async def update_user(self, user: User) -> User:
async with self._sf() as session:
row = await session.get(UserRow, str(user.id))
if row is None:
# Hard fail on concurrent delete: callers (reset_admin,
# password change handlers, _ensure_admin_user) all
# fetched the user just before this call, so a missing
# row here means the row vanished underneath us. Silent
# success would let the caller log "password reset" for
# a row that no longer exists.
raise UserNotFoundError(f"User {user.id} no longer exists")
row.email = user.email
row.password_hash = user.password_hash
row.system_role = user.system_role
row.oauth_provider = user.oauth_provider
row.oauth_id = user.oauth_id
row.needs_setup = user.needs_setup
row.token_version = user.token_version
await session.commit()
return user
async def count_users(self) -> int:
stmt = select(func.count()).select_from(UserRow)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def count_admin_users(self) -> int:
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
+92
View File
@@ -0,0 +1,92 @@
"""CLI tool to reset an admin password.
Usage:
python -m app.gateway.auth.reset_admin
python -m app.gateway.auth.reset_admin --email admin@example.com
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
(mode 0600) instead of printing it, so CI / log aggregators never see
the cleartext secret.
"""
from __future__ import annotations
import argparse
import asyncio
import secrets
import sys
from sqlalchemy import select
from app.gateway.auth.credential_file import write_initial_credentials
from app.gateway.auth.password import hash_password
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.user.model import UserRow
async def _run(email: str | None) -> int:
from deerflow.config import AppConfig
from deerflow.persistence.engine import (
close_engine,
get_session_factory,
init_engine_from_config,
)
# CLI entry: load config explicitly at the top, pass down through the closure.
config = AppConfig.from_file()
await init_engine_from_config(config.database)
try:
sf = get_session_factory()
if sf is None:
print("Error: persistence engine not available (check config.database).", file=sys.stderr)
return 1
repo = SQLiteUserRepository(sf)
if email:
user = await repo.get_user_by_email(email)
else:
# Find first admin via direct SELECT — repository does not
# expose a "first admin" helper and we do not want to add
# one just for this CLI.
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
user = None
else:
user = await repo.get_user_by_id(row.id)
if user is None:
if email:
print(f"Error: user '{email}' not found.", file=sys.stderr)
else:
print("Error: no admin user found.", file=sys.stderr)
return 1
new_password = secrets.token_urlsafe(16)
user.password_hash = hash_password(new_password)
user.token_version += 1
user.needs_setup = True
await repo.update_user(user)
cred_path = write_initial_credentials(user.email, new_password, label="reset")
print(f"Password reset for: {user.email}")
print(f"Credentials written to: {cred_path} (mode 0600)")
print("Next login will require setup (new email + password).")
return 0
finally:
await close_engine()
def main() -> None:
parser = argparse.ArgumentParser(description="Reset admin password")
parser.add_argument("--email", help="Admin email (default: first admin found)")
args = parser.parse_args()
exit_code = asyncio.run(_run(args.email))
sys.exit(exit_code)
if __name__ == "__main__":
main()
+118
View File
@@ -0,0 +1,118 @@
"""Global authentication middleware — fail-closed safety net.
Rejects unauthenticated requests to non-public paths with 401. When a
request passes the cookie check, resolves the JWT payload to a real
``User`` object and stamps it into both ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filtering works automatically via the sentinel pattern.
Fine-grained permission checks remain in authz.py decorators.
"""
from collections.abc import Callable
from fastapi import HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
"/health",
"/docs",
"/redoc",
"/openapi.json",
)
# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
"/api/v1/auth/initialize",
}
)
def _is_public(path: str) -> bool:
stripped = path.rstrip("/")
if stripped in _PUBLIC_EXACT_PATHS:
return True
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
class AuthMiddleware(BaseHTTPMiddleware):
"""Strict auth gate: reject requests without a valid session.
Two-stage check for non-public paths:
1. Cookie presence — return 401 NOT_AUTHENTICATED if missing
2. JWT validation via ``get_optional_user_from_request`` — return 401
TOKEN_INVALID if the token is absent, malformed, expired, or the
signed user does not exist / is stale
On success, stamps ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filters work downstream without every route needing a
``@require_auth`` decorator. Routes that need per-resource
authorization (e.g. "user A cannot read user B's thread by guessing
the URL") should additionally use ``@require_permission(...,
owner_check=True)`` for explicit enforcement — but authentication
itself is fully handled here.
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if _is_public(request.url.path):
return await call_next(request)
# Non-public path: require session cookie
if not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
# Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
# without this, non-isolation routes like /api/models would
# accept any cookie-shaped string as authentication.
#
# We call the *strict* resolver so that fine-grained error
# codes (token_expired, token_invalid, user_not_found, …)
# propagate from AuthErrorCode, not get flattened into one
# generic code. BaseHTTPMiddleware doesn't let HTTPException
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is
# None" branch short-circuits instead of running the entire
# JWT-decode + DB-lookup pipeline a second time per request).
request.state.user = user
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
token = set_current_user(user)
try:
return await call_next(request)
finally:
reset_current_user(token)
+262
View File
@@ -0,0 +1,262 @@
"""Authorization decorators and context for DeerFlow.
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
**Usage:**
1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
3. The decorator chain processes from bottom to top
**Example:**
@router.get("/{thread_id}")
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
# User is authenticated and has threads:read permission
...
**Permission Model:**
- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""
from __future__ import annotations
import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request
if TYPE_CHECKING:
from app.gateway.auth.models import User
P = ParamSpec("P")
T = TypeVar("T")
# Permission constants
class Permissions:
"""Permission constants for resource:action format."""
# Threads
THREADS_READ = "threads:read"
THREADS_WRITE = "threads:write"
THREADS_DELETE = "threads:delete"
# Runs
RUNS_CREATE = "runs:create"
RUNS_READ = "runs:read"
RUNS_CANCEL = "runs:cancel"
class AuthContext:
"""Authentication context for the current request.
Stored in request.state.auth after require_auth decoration.
Attributes:
user: The authenticated user, or None if anonymous
permissions: List of permission strings (e.g., "threads:read")
"""
__slots__ = ("user", "permissions")
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
self.user = user
self.permissions = permissions or []
@property
def is_authenticated(self) -> bool:
"""Check if user is authenticated."""
return self.user is not None
def has_permission(self, resource: str, action: str) -> bool:
"""Check if context has permission for resource:action.
Args:
resource: Resource name (e.g., "threads")
action: Action name (e.g., "read")
Returns:
True if user has permission
"""
permission = f"{resource}:{action}"
return permission in self.permissions
def require_user(self) -> User:
"""Get user or raise 401.
Raises:
HTTPException 401 if not authenticated
"""
if not self.user:
raise HTTPException(status_code=401, detail="Authentication required")
return self.user
def get_auth_context(request: Request) -> AuthContext | None:
"""Get AuthContext from request state."""
return getattr(request.state, "auth", None)
_ALL_PERMISSIONS: list[str] = [
Permissions.THREADS_READ,
Permissions.THREADS_WRITE,
Permissions.THREADS_DELETE,
Permissions.RUNS_CREATE,
Permissions.RUNS_READ,
Permissions.RUNS_CANCEL,
]
async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext.
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
Returns AuthContext with user=None for anonymous requests.
"""
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
if user is None:
return AuthContext(user=None, permissions=[])
# In future, permissions could be stored in user record
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and sets AuthContext.
Must be placed ABOVE other decorators (executes after them).
Usage:
@router.get("/{thread_id}")
@require_auth # Bottom decorator (executes first after permission check)
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
auth: AuthContext = request.state.auth
...
Raises:
ValueError: If 'request' parameter is missing
"""
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_auth decorator requires 'request' parameter")
# Authenticate and set context
auth_context = await _authenticate(request)
request.state.auth = auth_context
return await func(*args, **kwargs)
return wrapper
def require_permission(
resource: str,
action: str,
owner_check: bool = False,
require_existing: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator that checks permission for resource:action.
Must be used AFTER @require_auth.
Args:
resource: Resource name (e.g., "threads", "runs")
action: Action name (e.g., "read", "write", "delete")
owner_check: If True, validates that the current user owns the resource.
Requires 'thread_id' path parameter and performs ownership check.
require_existing: Only meaningful with ``owner_check=True``. If True, a
missing ``threads_meta`` row counts as a denial (404)
instead of "untracked legacy thread, allow". Use on
**destructive / mutating** routes (DELETE, PATCH,
state-update) so a deleted thread can't be re-targeted
by another user via the missing-row code path.
Usage:
# Read-style: legacy untracked threads are allowed
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
...
# Destructive: thread row MUST exist and be owned by caller
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread(thread_id: str, request: Request):
...
Raises:
HTTPException 401: If authentication required but user is anonymous
HTTPException 403: If user lacks permission
HTTPException 404: If owner_check=True but user doesn't own the thread
ValueError: If owner_check=True but 'thread_id' parameter is missing
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_permission decorator requires 'request' parameter")
auth: AuthContext = getattr(request.state, "auth", None)
if auth is None:
auth = await _authenticate(request)
request.state.auth = auth
if not auth.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
# Check permission
if not auth.has_permission(resource, action):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {resource}:{action}",
)
# Owner check for thread-specific resources.
#
# 2.0-rc moved thread metadata into the SQL persistence layer
# (``threads_meta`` table). We verify ownership via
# ``ThreadMetaStore.check_access``: it returns True for
# missing rows (untracked legacy thread) and for rows whose
# ``user_id`` is NULL (shared / pre-auth data), so this is
# strict-deny rather than strict-allow — only an *existing*
# row with a *different* user_id triggers 404.
if owner_check:
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
allowed = await thread_store.check_access(
thread_id,
str(auth.user.id),
require_existing=require_existing,
)
if not allowed:
raise HTTPException(
status_code=404,
detail=f"Thread {thread_id} not found",
)
return await func(*args, **kwargs)
return wrapper
return decorator
-3
View File
@@ -1,3 +0,0 @@
from .lifespan import lifespan_manager
__all__ = ["lifespan_manager"]
-52
View File
@@ -1,52 +0,0 @@
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any
from fastapi import FastAPI
LifespanFunc = Callable[[FastAPI], AbstractAsyncContextManager[dict[str, Any] | None]]
class LifespanManager:
"""FastAPI lifespan manager"""
def __init__(self) -> None:
self._lifespans: list[LifespanFunc] = []
def register(self, func: LifespanFunc) -> LifespanFunc:
"""
Register a lifespan hook.
:param func: lifespan hook
:return:
"""
if func not in self._lifespans:
self._lifespans.append(func)
return func
def build(self) -> LifespanFunc:
"""
Build the combined lifespan hook.
:return:
"""
@asynccontextmanager
async def combined_lifespan(app: FastAPI): # noqa: ANN202
state: dict[str, Any] = {}
async with AsyncExitStack() as exit_stack:
for lifespan_fn in self._lifespans:
result = await exit_stack.enter_async_context(lifespan_fn(app))
if isinstance(result, dict):
state.update(result)
for key, value in state.items():
setattr(app.state, key, value)
yield state or None
return combined_lifespan
# Singleton lifespan_manager instance
lifespan_manager = LifespanManager()
@@ -1,4 +1,8 @@
"""CSRF protection middleware and helpers for cookie-based auth flows."""
"""CSRF protection middleware for FastAPI.
Per RFC-001:
State-changing operations require CSRF protection.
"""
import secrets
from collections.abc import Callable
@@ -24,11 +28,16 @@ def generate_csrf_token() -> str:
def should_check_csrf(request: Request) -> bool:
"""Determine if a request needs CSRF validation."""
"""Determine if a request needs CSRF validation.
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
"""
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me":
return False
return True
@@ -45,12 +54,15 @@ _AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
def is_auth_endpoint(request: Request) -> bool:
"""Check if the request is to an auth endpoint."""
"""Check if the request is to an auth endpoint.
Auth endpoints don't need CSRF validation on first call (no token).
"""
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
class CSRFMiddleware(BaseHTTPMiddleware):
"""Implement CSRF protection using the double-submit cookie pattern."""
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
@@ -76,13 +88,16 @@ class CSRFMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
# For auth endpoints that set up session, also set CSRF cookie
if _is_auth and request.method == "POST":
# Generate a new CSRF token for the session
csrf_token = generate_csrf_token()
is_https = is_secure_request(request)
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
httponly=False,
secure=is_secure_request(request),
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
secure=is_https,
samesite="strict",
)
@@ -90,17 +105,9 @@ class CSRFMiddleware(BaseHTTPMiddleware):
def get_csrf_token(request: Request) -> str | None:
"""Get the CSRF token from the current request's cookies."""
"""Get the CSRF token from the current request's cookies.
This is useful for server-side rendering where you need to embed
token in forms or headers.
"""
return request.cookies.get(CSRF_COOKIE_NAME)
__all__ = [
"CSRF_COOKIE_NAME",
"CSRF_HEADER_NAME",
"CSRFMiddleware",
"generate_csrf_token",
"get_csrf_token",
"is_auth_endpoint",
"is_secure_request",
"should_check_csrf",
]
@@ -1,59 +0,0 @@
from app.gateway.dependencies.checkpointer import (
CurrentCheckpointer,
get_checkpointer,
)
from app.plugins.auth.security.dependencies import (
CurrentAuthService,
CurrentUserRepository,
get_auth_service,
get_current_user_from_request,
get_current_user_id,
get_optional_user_from_request,
get_user_repository,
)
from app.gateway.dependencies.db import (
CurrentSession,
CurrentSessionTransaction,
get_db_session,
get_db_session_transaction,
)
from app.gateway.dependencies.repositories import (
CurrentFeedbackRepository,
CurrentRunRepository,
CurrentThreadMetaRepository,
CurrentThreadMetaStorage,
get_feedback_repository,
get_run_repository,
get_thread_meta_repository,
get_thread_meta_storage,
)
from app.gateway.dependencies.stream_bridge import (
CurrentStreamBridge,
get_stream_bridge,
)
__all__ = [
"CurrentCheckpointer",
"CurrentAuthService",
"CurrentFeedbackRepository",
"CurrentRunRepository",
"CurrentSession",
"CurrentSessionTransaction",
"CurrentStreamBridge",
"CurrentThreadMetaRepository",
"CurrentThreadMetaStorage",
"CurrentUserRepository",
"get_auth_service",
"get_checkpointer",
"get_current_user_from_request",
"get_current_user_id",
"get_db_session",
"get_db_session_transaction",
"get_feedback_repository",
"get_optional_user_from_request",
"get_run_repository",
"get_stream_bridge",
"get_thread_meta_repository",
"get_thread_meta_storage",
"get_user_repository",
]
@@ -1,20 +0,0 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from langgraph.types import Checkpointer
def get_checkpointer(request: Request) -> Checkpointer:
"""Get checkpointer from app.state.persistence."""
persistence = getattr(request.app.state, "persistence", None)
if persistence is None:
raise HTTPException(status_code=503, detail="Persistence not available")
checkpointer = getattr(persistence, "checkpointer", None)
if checkpointer is None:
raise HTTPException(status_code=503, detail="Checkpointer not available")
return checkpointer
CurrentCheckpointer = Annotated[Checkpointer, Depends(get_checkpointer)]
-37
View File
@@ -1,37 +0,0 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
def _get_session_factory(request: Request) -> async_sessionmaker[AsyncSession]:
factory = getattr(request.app.state.persistence, "session_factory", None)
if factory is None:
raise HTTPException(status_code=503, detail="Database session factory not available")
return factory
async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
"""Open a session without auto-commit. Use for read-only endpoints."""
session_factory = _get_session_factory(request)
async with session_factory() as session:
yield session
async def get_db_session_transaction(request: Request) -> AsyncIterator[AsyncSession]:
"""Open a session and commit on success, rollback on error."""
session_factory = _get_session_factory(request)
async with session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
CurrentSession = Annotated[AsyncSession, Depends(get_db_session)]
CurrentSessionTransaction = Annotated[AsyncSession, Depends(get_db_session_transaction)]
@@ -1,41 +0,0 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from app.infra.storage import ThreadMetaStorage
from store.repositories.contracts import (
FeedbackRepositoryProtocol,
RunRepositoryProtocol,
ThreadMetaRepositoryProtocol,
)
def _require_state(request: Request, attr: str, label: str):
value = getattr(request.app.state, attr, None)
if value is None:
raise HTTPException(status_code=503, detail=f"{label} not available")
return value
def get_run_repository(request: Request) -> RunRepositoryProtocol:
return _require_state(request, "run_store", "Run store")
def get_thread_meta_repository(request: Request) -> ThreadMetaRepositoryProtocol:
return _require_state(request, "thread_meta_repo", "Thread metadata store")
def get_thread_meta_storage(request: Request) -> ThreadMetaStorage:
return _require_state(request, "thread_meta_storage", "Thread metadata storage")
def get_feedback_repository(request: Request) -> FeedbackRepositoryProtocol:
return _require_state(request, "feedback_repo", "Feedback")
CurrentRunRepository = Annotated[RunRepositoryProtocol, Depends(get_run_repository)]
CurrentThreadMetaRepository = Annotated[ThreadMetaRepositoryProtocol, Depends(get_thread_meta_repository)]
CurrentThreadMetaStorage = Annotated[ThreadMetaStorage, Depends(get_thread_meta_storage)]
CurrentFeedbackRepository = Annotated[FeedbackRepositoryProtocol, Depends(get_feedback_repository)]
@@ -1,18 +0,0 @@
from __future__ import annotations
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from deerflow.runtime import StreamBridge
def get_stream_bridge(request: Request) -> StreamBridge:
"""Get stream bridge from app.state."""
bridge = getattr(request.app.state, "stream_bridge", None)
if bridge is None:
raise HTTPException(status_code=503, detail="Stream bridge not available")
return bridge
CurrentStreamBridge = Annotated[StreamBridge, Depends(get_stream_bridge)]
+247
View File
@@ -0,0 +1,247 @@
"""Centralized accessors for singleton objects stored on ``app.state``.
**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request
from deerflow.config.app_config import AppConfig
from deerflow.runtime import RunContext, RunManager
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.thread_meta.base import ThreadMetaStore
def get_config(request: Request) -> AppConfig:
"""FastAPI dependency returning the app-scoped ``AppConfig``.
Reads from ``request.app.state.config`` which is set at startup
(``app.py`` lifespan) and swapped on config reload (``routers/mcp.py``,
``routers/skills.py``).
"""
cfg = getattr(request.app.state, "config", None)
if cfg is None:
raise HTTPException(status_code=503, detail="Configuration not available")
return cfg
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons.
Usage in ``app.py``::
async with langgraph_runtime(app):
yield
"""
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack:
# app.state.config is populated earlier in lifespan(); thread it
# explicitly into every provider below.
config = app.state.config
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
# Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend).
await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
app.state.store = await stack.enter_async_context(make_store(config))
# Initialize repositories — one get_session_factory() call for all.
sf = get_session_factory()
if sf is not None:
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.persistence.run import RunRepository
app.state.run_store = RunRepository(sf)
app.state.feedback_repo = FeedbackRepository(sf)
else:
from deerflow.runtime.runs.store.memory import MemoryRunStore
app.state.run_store = MemoryRunStore()
app.state.feedback_repo = None
from deerflow.persistence.thread_meta import make_thread_store
app.state.thread_store = make_thread_store(sf, app.state.store)
# Run event store (has its own factory with config-driven backend selection)
run_events_config = getattr(config, "run_events", None)
app.state.run_event_store = make_run_event_store(run_events_config)
# RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store)
try:
yield
finally:
await close_engine()
# ---------------------------------------------------------------------------
# Getters called by routers per-request
# ---------------------------------------------------------------------------
def _require(attr: str, label: str):
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
def dep(request: Request):
val = getattr(request.app.state, attr, None)
if val is None:
raise HTTPException(status_code=503, detail=f"{label} not available")
return val
dep.__name__ = dep.__qualname__ = f"get_{attr}"
return dep
get_stream_bridge = _require("stream_bridge", "Stream bridge")
get_run_manager = _require("run_manager", "Run manager")
get_checkpointer = _require("checkpointer", "Checkpointer")
get_run_event_store = _require("run_event_store", "Run event store")
get_feedback_repo = _require("feedback_repo", "Feedback")
get_run_store = _require("run_store", "Run store")
def get_store(request: Request):
"""Return the global store (may be ``None`` if not configured)."""
return getattr(request.app.state, "store", None)
def get_thread_store(request: Request) -> ThreadMetaStore:
"""Return the thread metadata store (SQL or memory-backed)."""
val = getattr(request.app.state, "thread_store", None)
if val is None:
raise HTTPException(status_code=503, detail="Thread metadata store not available")
return val
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies. Callers that
need per-run fields (e.g. ``follow_up_to_run_id``) should use
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
to :func:`run_agent`.
"""
config = get_config(request)
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(config, "run_events", None),
thread_store=get_thread_store(request),
app_config=config,
)
# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware)
# ---------------------------------------------------------------------------
# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton.
Must be called after ``init_engine_from_config()`` — the shared
session factory is required to construct the user repository.
"""
global _cached_local_provider, _cached_repo
if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
_cached_repo = SQLiteUserRepository(sf)
if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
return _cached_local_provider
async def get_current_user_from_request(request: Request):
"""Get the current authenticated user from the request cookie.
Raises HTTPException 401 if not authenticated.
"""
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
access_token = request.cookies.get("access_token")
if not access_token:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
)
payload = decode_token(access_token)
if isinstance(payload, TokenError):
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
)
provider = get_local_provider()
user = await provider.get_user(payload.sub)
if user is None:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
)
# Token version mismatch → password was changed, token is stale
if user.token_version != payload.ver:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
)
return user
async def get_optional_user_from_request(request: Request):
"""Get optional authenticated user from request.
Returns None if not authenticated.
"""
try:
return await get_current_user_from_request(request)
except HTTPException:
return None
async def get_current_user(request: Request) -> str | None:
"""Extract user_id from request cookie, or None if not authenticated.
Thin adapter that returns the string id for callers that only need
identification (e.g., ``feedback.py``). Full-user callers should use
``get_current_user_from_request`` or ``get_optional_user_from_request``.
"""
user = await get_optional_user_from_request(request)
return str(user.id) if user else None
+106
View File
@@ -0,0 +1,106 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.
Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.
Two layers:
1. @auth.authenticate — validates JWT cookie, extracts user_id,
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on — returns metadata filter so each user only sees own threads
"""
import secrets
from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider
auth = Auth()
# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
def _check_csrf(request) -> None:
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
proxied directly by nginx have the same CSRF protection.
"""
method = getattr(request, "method", "") or ""
if method.upper() not in _CSRF_METHODS:
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
if not cookie_token or not header_token:
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token missing. Include X-CSRF-Token header.",
)
if not secrets.compare_digest(cookie_token, header_token):
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token mismatch.",
)
@auth.authenticate
async def authenticate(request):
"""Validate the session cookie, decode JWT, and check token_version.
Same validation chain as Gateway's get_current_user_from_request:
cookie → decode JWT → DB lookup → token_version match
Also enforces CSRF on state-changing methods.
"""
# CSRF check before authentication so forged cross-site requests
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Not authenticated",
)
payload = decode_token(token)
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail=f"Token error: {payload.value}",
)
user = await get_local_provider().get_user(payload.sub)
if user is None:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="User not found",
)
if user.token_version != payload.ver:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Token revoked (password changed)",
)
return payload.sub
@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
"""Inject user_id metadata on writes; filter by user_id on reads.
Gateway stores thread ownership as ``metadata.user_id``.
This handler ensures LangGraph Server enforces the same isolation.
"""
# On create/update: stamp user_id into metadata
metadata = value.setdefault("metadata", {})
metadata["user_id"] = ctx.user.identity
# Return filter dict — LangGraph applies it to search/read/delete
return {"user_id": ctx.user.identity}
+3 -5
View File
@@ -5,17 +5,16 @@ from pathlib import Path
from fastapi import HTTPException
from deerflow.config.paths import get_paths
from deerflow.runtime.actor_context import get_effective_user_id
from deerflow.runtime.user_context import get_effective_user_id
def resolve_thread_virtual_path(thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
"""Resolve a virtual path to the actual filesystem path under thread user-data.
Args:
thread_id: The thread ID.
virtual_path: The virtual path as seen inside the sandbox
(e.g., /mnt/user-data/outputs/file.txt).
user_id: Explicit user id override. Falls back to the current actor context.
Returns:
The resolved filesystem path.
@@ -24,8 +23,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str, *, user_id: s
HTTPException: If the path is invalid or outside allowed directories.
"""
try:
resolved_user_id = get_effective_user_id() if user_id is None else user_id
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=resolved_user_id)
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
except ValueError as e:
status = 403 if "traversal" in str(e) else 400
raise HTTPException(status_code=status, detail=str(e))
-132
View File
@@ -1,132 +0,0 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from scalar_fastapi import AgentScalarConfig, get_scalar_api_reference
from starlette.middleware.cors import CORSMiddleware
from store.persistence import create_persistence
from app.gateway.common import lifespan_manager
from app.gateway.router import router as gateway_router
from app.infra.run_events import build_run_event_store
from app.infra.storage import FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter
from app.plugins.auth.authorization.hooks import build_authz_hooks
from app.plugins.auth.injection import install_route_guards, load_route_policy_registry, validate_route_policy_registry
from app.plugins.auth.security import AuthMiddleware, CSRFMiddleware
STATIC_DIR = Path(__file__).resolve().parents[1] / "static"
STATIC_MOUNT = "/api/static"
SCALAR_JS_URL = f"{STATIC_MOUNT}/scalar.js"
@lifespan_manager.register
@asynccontextmanager
async def init_persistence(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
"""Initialize persistence layer (DB, checkpointer, store)."""
app_persistence = await create_persistence()
await app_persistence.setup()
run_store = RunStoreAdapter(app_persistence.session_factory)
thread_meta_store = ThreadMetaStoreAdapter(app_persistence.session_factory)
feedback_store = FeedbackStoreAdapter(app_persistence.session_factory)
try:
yield {
"persistence": app_persistence,
"checkpointer": app_persistence.checkpointer,
"store": None,
"session_factory": app_persistence.session_factory,
"run_store": run_store,
"run_read_repo": run_store,
"run_write_repo": run_store,
"run_delete_repo": run_store,
"feedback_repo": feedback_store,
"thread_meta_repo": thread_meta_store,
"thread_meta_storage": ThreadMetaStorage(thread_meta_store),
"run_event_store": build_run_event_store(app_persistence.session_factory),
}
finally:
await app_persistence.aclose()
@lifespan_manager.register
@asynccontextmanager
async def init_runtime(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
"""Initialize StreamBridge for LangGraph-compatible runtime endpoints."""
from app.infra.stream_bridge import build_stream_bridge
async with build_stream_bridge() as stream_bridge:
yield {
"stream_bridge": stream_bridge,
}
def register_app() -> FastAPI:
app = FastAPI(
title="DeerFlow API Gateway",
version="0.1.0",
docs_url=None,
redoc_url=None,
lifespan=lifespan_manager.build(),
openapi_tags=[
{
"name": "threads",
"description": "Endpoints for managing threads, which are conversations between a human and an assistant. A thread can have multiple runs as the conversation progresses."
}
]
)
app.state.authz_hooks = build_authz_hooks()
_register_static(app)
_register_routes(app)
_register_scalar(app)
_register_auth_route_policies(app)
_register_middlewares(app)
return app
def _register_static(app: FastAPI) -> None:
app.mount(STATIC_MOUNT, StaticFiles(directory=STATIC_DIR), name="static")
def _register_routes(app: FastAPI) -> None:
app.include_router(gateway_router)
def _register_auth_route_policies(app: FastAPI) -> None:
registry = load_route_policy_registry()
validate_route_policy_registry(app, registry)
app.state.auth_route_policy_registry = registry
install_route_guards(app)
def _register_middlewares(app: FastAPI) -> None:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)
app.add_middleware(CSRFMiddleware)
app.add_middleware(AuthMiddleware)
def _register_scalar(app: FastAPI) -> None:
@app.get("/docs", include_in_schema=False)
def scalar_docs() -> HTMLResponse:
return get_scalar_api_reference(
openapi_url=app.openapi_url,
title=app.title,
scalar_js_url=SCALAR_JS_URL,
agent=AgentScalarConfig(disabled=True),
hide_client_button=True,
overrides={"mcp": {"disabled": True}},
)
-22
View File
@@ -1,22 +0,0 @@
from fastapi import APIRouter
from app.plugins.auth.api.router import router as auth_router
from .routers import artifacts, channels, mcp, models, skills, uploads
from .routers.agents import router as agents_router
from .routers.langgraph import feedback_router, runs_router, suggestion_router, threads_router
router = APIRouter()
router.include_router(auth_router)
router.include_router(threads_router, prefix="/api/threads")
router.include_router(runs_router, prefix="/api/threads")
router.include_router(feedback_router, prefix="/api/threads")
router.include_router(suggestion_router)
router.include_router(agents_router)
router.include_router(channels.router)
router.include_router(artifacts.router)
router.include_router(mcp.router)
router.include_router(models.router)
router.include_router(skills.router)
router.include_router(uploads.router)
+2 -2
View File
@@ -1,3 +1,3 @@
from . import artifacts, mcp, models, skills, suggestions, uploads
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
+51 -12
View File
@@ -5,10 +5,12 @@ import re
import shutil
import yaml
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
@@ -24,6 +26,7 @@ class AgentResponse(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="SOUL.md content")
@@ -40,6 +43,7 @@ class AgentCreateRequest(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all enabled, []=none)")
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
@@ -49,6 +53,7 @@ class AgentUpdateRequest(BaseModel):
description: str | None = Field(default=None, description="Updated description")
model: str | None = Field(default=None, description="Updated model override")
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
skills: list[str] | None = Field(default=None, description="Updated skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="Updated SOUL.md content")
@@ -73,6 +78,15 @@ def _normalize_agent_name(name: str) -> str:
return name.lower()
def _require_agents_api_enabled(app_config: AppConfig) -> None:
"""Reject access unless the custom-agent management API is explicitly enabled."""
if not app_config.agents_api.enabled:
raise HTTPException(
status_code=403,
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
@@ -84,6 +98,7 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
description=agent_cfg.description,
model=agent_cfg.model,
tool_groups=agent_cfg.tool_groups,
skills=agent_cfg.skills,
soul=soul,
)
@@ -94,12 +109,14 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
summary="List Custom Agents",
description="List all custom agents available in the agents directory, including their soul content.",
)
async def list_agents() -> AgentsListResponse:
async def list_agents(app_config: AppConfig = Depends(get_config)) -> AgentsListResponse:
"""List all custom agents.
Returns:
List of all custom agents with their metadata and soul content.
"""
_require_agents_api_enabled(app_config)
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
@@ -125,6 +142,7 @@ async def check_agent_name(name: str) -> dict:
Raises:
HTTPException: 422 if the name is invalid.
"""
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
@@ -137,7 +155,7 @@ async def check_agent_name(name: str) -> dict:
summary="Get Custom Agent",
description="Retrieve details and SOUL.md content for a specific custom agent.",
)
async def get_agent(name: str) -> AgentResponse:
async def get_agent(name: str, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Get a specific custom agent by name.
Args:
@@ -149,6 +167,7 @@ async def get_agent(name: str) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -169,7 +188,7 @@ async def get_agent(name: str) -> AgentResponse:
summary="Create Custom Agent",
description="Create a new custom agent with its config and SOUL.md.",
)
async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
async def create_agent_endpoint(request: AgentCreateRequest, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Create a new custom agent.
Args:
@@ -181,6 +200,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid.
"""
_require_agents_api_enabled(app_config)
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
@@ -200,6 +220,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
@@ -230,7 +252,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
summary="Update Custom Agent",
description="Update an existing custom agent's config and/or SOUL.md.",
)
async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
async def update_agent(name: str, request: AgentUpdateRequest, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Update an existing custom agent.
Args:
@@ -243,6 +265,7 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -255,21 +278,32 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
try:
# Update config if any config fields changed
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
# Use model_fields_set to distinguish "field omitted" from "explicitly set to null".
# This is critical for skills where None means "inherit all" (not "don't change").
fields_set = request.model_fields_set
config_changed = bool(fields_set & {"description", "model", "tool_groups", "skills"})
if config_changed:
updated: dict = {
"name": agent_cfg.name,
"description": request.description if request.description is not None else agent_cfg.description,
"description": request.description if "description" in fields_set else agent_cfg.description,
}
new_model = request.model if request.model is not None else agent_cfg.model
new_model = request.model if "model" in fields_set else agent_cfg.model
if new_model is not None:
updated["model"] = new_model
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
new_tool_groups = request.tool_groups if "tool_groups" in fields_set else agent_cfg.tool_groups
if new_tool_groups is not None:
updated["tool_groups"] = new_tool_groups
# skills: None = inherit all, [] = no skills, ["a","b"] = whitelist
if "skills" in fields_set:
new_skills = request.skills
else:
new_skills = agent_cfg.skills
if new_skills is not None:
updated["skills"] = new_skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
@@ -309,12 +343,14 @@ class UserProfileUpdateRequest(BaseModel):
summary="Get User Profile",
description="Read the global USER.md file that is injected into all custom agents.",
)
async def get_user_profile() -> UserProfileResponse:
async def get_user_profile(app_config: AppConfig = Depends(get_config)) -> UserProfileResponse:
"""Return the current USER.md content.
Returns:
UserProfileResponse with content=None if USER.md does not exist yet.
"""
_require_agents_api_enabled(app_config)
try:
user_md_path = get_paths().user_md_file
if not user_md_path.exists():
@@ -332,7 +368,7 @@ async def get_user_profile() -> UserProfileResponse:
summary="Update User Profile",
description="Write the global USER.md file that is injected into all custom agents.",
)
async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileResponse:
async def update_user_profile(request: UserProfileUpdateRequest, app_config: AppConfig = Depends(get_config)) -> UserProfileResponse:
"""Create or overwrite the global USER.md.
Args:
@@ -341,6 +377,8 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
Returns:
UserProfileResponse with the saved content.
"""
_require_agents_api_enabled(app_config)
try:
paths = get_paths()
paths.base_dir.mkdir(parents=True, exist_ok=True)
@@ -358,7 +396,7 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
summary="Delete Custom Agent",
description="Delete a custom agent and all its files (config, SOUL.md, memory).",
)
async def delete_agent(name: str) -> None:
async def delete_agent(name: str, app_config: AppConfig = Depends(get_config)) -> None:
"""Delete a custom agent.
Args:
@@ -367,6 +405,7 @@ async def delete_agent(name: str) -> None:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
name = _normalize_agent_name(name)
+2
View File
@@ -7,6 +7,7 @@ from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.authz import require_permission
from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__)
@@ -81,6 +82,7 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
)
@require_permission("threads", "read", owner_check=True)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path.
@@ -0,0 +1,149 @@
"""Assistants compatibility endpoints.
Provides LangGraph Platform-compatible assistants API backed by the
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
This is a minimal stub that satisfies the ``useStream`` React hook's
initialization requirements (``assistants.search()`` and ``assistants.get()``).
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
class AssistantResponse(BaseModel):
assistant_id: str
graph_id: str
name: str
config: dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
description: str | None = None
created_at: str = ""
updated_at: str = ""
version: int = 1
class AssistantSearchRequest(BaseModel):
graph_id: str | None = None
name: str | None = None
metadata: dict[str, Any] | None = None
limit: int = 10
offset: int = 0
def _get_default_assistant() -> AssistantResponse:
"""Return the default lead_agent assistant."""
now = datetime.now(UTC).isoformat()
return AssistantResponse(
assistant_id="lead_agent",
graph_id="lead_agent",
name="lead_agent",
config={},
metadata={"created_by": "system"},
description="DeerFlow lead agent",
created_at=now,
updated_at=now,
version=1,
)
def _list_assistants() -> list[AssistantResponse]:
"""List all available assistants from config."""
assistants = [_get_default_assistant()]
# Also include custom agents from config.yaml agents directory
try:
from deerflow.config.agents_config import list_custom_agents
for agent_cfg in list_custom_agents():
now = datetime.now(UTC).isoformat()
assistants.append(
AssistantResponse(
assistant_id=agent_cfg.name,
graph_id="lead_agent", # All agents use the same graph
name=agent_cfg.name,
config={},
metadata={"created_by": "user"},
description=agent_cfg.description or "",
created_at=now,
updated_at=now,
version=1,
)
)
except Exception:
logger.debug("Could not load custom agents for assistants list")
return assistants
@router.post("/search", response_model=list[AssistantResponse])
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
"""Search assistants.
Returns all registered assistants (lead_agent + custom agents from config).
"""
assistants = _list_assistants()
if body and body.graph_id:
assistants = [a for a in assistants if a.graph_id == body.graph_id]
if body and body.name:
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
offset = body.offset if body else 0
limit = body.limit if body else 10
return assistants[offset : offset + limit]
@router.get("/{assistant_id}", response_model=AssistantResponse)
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
"""Get an assistant by ID."""
for a in _list_assistants():
if a.assistant_id == assistant_id:
return a
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
@router.get("/{assistant_id}/graph")
async def get_assistant_graph(assistant_id: str) -> dict:
"""Get the graph structure for an assistant.
Returns a minimal graph description. Full graph introspection is
not supported in the Gateway — this stub satisfies SDK validation.
"""
found = any(a.assistant_id == assistant_id for a in _list_assistants())
if not found:
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
return {
"graph_id": "lead_agent",
"nodes": [],
"edges": [],
}
@router.get("/{assistant_id}/schemas")
async def get_assistant_schemas(assistant_id: str) -> dict:
"""Get JSON schemas for an assistant's input/output/state.
Returns empty schemas — full introspection not supported in Gateway.
"""
found = any(a.assistant_id == assistant_id for a in _list_assistants())
if not found:
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
return {
"graph_id": "lead_agent",
"input_schema": {},
"output_schema": {},
"state_schema": {},
"config_schema": {},
}
+459
View File
@@ -0,0 +1,459 @@
"""Authentication endpoints."""
import logging
import os
import time
from ipaddress import ip_address, ip_network
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field, field_validator
from app.gateway.auth import (
UserResponse,
create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ── Request/Response Models ──────────────────────────────────────────────
class LoginResponse(BaseModel):
"""Response model for login — token only lives in HttpOnly cookie."""
expires_in: int # seconds
needs_setup: bool = False
# Top common-password blocklist. Drawn from the public SecLists "10k worst
# passwords" set, lowercased + length>=8 only (shorter ones already fail
# the min_length check). Kept tight on purpose: this is the **lower bound**
# defense, not a full HIBP / passlib check, and runs in-process per request.
_COMMON_PASSWORDS: frozenset[str] = frozenset(
{
"password",
"password1",
"password12",
"password123",
"password1234",
"12345678",
"123456789",
"1234567890",
"qwerty12",
"qwertyui",
"qwerty123",
"abc12345",
"abcd1234",
"iloveyou",
"letmein1",
"welcome1",
"welcome123",
"admin123",
"administrator",
"passw0rd",
"p@ssw0rd",
"monkey12",
"trustno1",
"sunshine",
"princess",
"football",
"baseball",
"superman",
"batman123",
"starwars",
"dragon123",
"master123",
"shadow12",
"michael1",
"jennifer",
"computer",
}
)
def _password_is_common(password: str) -> bool:
"""Case-insensitive blocklist check.
Lowercases the input so trivial mutations like ``Password`` /
``PASSWORD`` are also rejected. Does not normalize digit substitutions
(``p@ssw0rd`` is included as a literal entry instead) — keeping the
rule cheap and predictable.
"""
return password.lower() in _COMMON_PASSWORDS
def _validate_strong_password(value: str) -> str:
"""Pydantic field-validator body shared by Register + ChangePassword.
Constraint = function, not type-level mixin. The two request models
have no "is-a" relationship; they only share the password-strength
rule. Lifting it into a free function lets each model bind it via
``@field_validator(field_name)`` without inheritance gymnastics.
"""
if _password_is_common(value):
raise ValueError("Password is too common; choose a stronger password.")
return value
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class ChangePasswordRequest(BaseModel):
"""Request model for password change (also handles setup flow)."""
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
"""Set the access_token HttpOnly cookie on the response."""
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
)
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}
def _trusted_proxies() -> list:
"""Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects.
Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is
trusted (direct mode). Invalid entries are skipped with a logger warning.
Read live so env-var overrides take effect immediately and tests can
``monkeypatch.setenv`` without poking a module-level cache.
"""
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
if not raw:
return []
nets = []
for entry in raw.split(","):
entry = entry.strip()
if not entry:
continue
try:
nets.append(ip_network(entry, strict=False))
except ValueError:
logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry)
return nets
def _get_client_ip(request: Request) -> str:
"""Extract the real client IP for rate limiting.
Trust model:
- The TCP peer (``request.client.host``) is always the baseline. It is
whatever the kernel reports as the connecting socket — unforgeable
by the client itself.
- ``X-Real-IP`` is **only** honored if the TCP peer is in the
``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated
CIDR or single IPs). When set, the gateway is assumed to be behind a
reverse proxy (nginx, Cloudflare, ALB, …) that overwrites
``X-Real-IP`` with the original client address.
- With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently
ignored — closing the bypass where any client could rotate the
header to dodge per-IP rate limits in dev / direct-gateway mode.
``X-Forwarded-For`` is intentionally NOT used because it is naturally
client-controlled at the *first* hop and the trust chain is harder to
audit per-request.
"""
peer_host = request.client.host if request.client else None
trusted = _trusted_proxies()
if trusted and peer_host:
try:
peer_ip = ip_address(peer_host)
if any(peer_ip in net for net in trusted):
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
except ValueError:
# peer_host wasn't a parseable IP (e.g. "unknown") — fall through
pass
return peer_host or "unknown"
def _check_rate_limit(ip: str) -> None:
"""Raise 429 if the IP is currently locked out."""
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(
status_code=429,
detail="Too many login attempts. Try again later.",
)
del _login_attempts[ip]
_MAX_TRACKED_IPS = 10000
def _record_login_failure(ip: str) -> None:
"""Record a failed login attempt for the given IP."""
# Evict expired lockouts when dict grows too large
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for k in expired:
del _login_attempts[k]
# If still too large, evict cheapest-to-lose half: below-threshold
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for k, _ in by_time[: len(by_time) // 2]:
del _login_attempts[k]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
"""Clear failure counter for the given IP on successful login."""
_login_attempts.pop(ip, None)
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""Local email/password login."""
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
if user is None:
_record_login_failure(client_ip)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
)
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
"""Logout current user by clearing the cookie."""
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
"""Change password for the currently authenticated user.
Also handles the first-boot setup flow:
- If new_email is provided, updates email (checks uniqueness)
- If user.needs_setup is True and new_email is given, clears needs_setup
- Always increments token_version to invalidate old sessions
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
user = await get_current_user_from_request(request)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
if not await verify_password_async(body.current_password, user.password_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
provider = get_local_provider()
# Update email if provided
if body.new_email is not None:
existing = await provider.get_user_by_email(body.new_email)
if existing and str(existing.id) != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
user.email = body.new_email
# Update password + bump version
user.password_hash = await hash_password_async(body.new_password)
user.token_version += 1
# Clear setup flag if this is the setup flow
if user.needs_setup and body.new_email is not None:
user.needs_setup = False
await provider.update_user(user)
# Re-issue cookie with new token_version
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
"""Get current authenticated user info."""
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
@router.get("/setup-status")
async def setup_status():
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
admin_count = await get_local_provider().count_admin_users()
return {"needs_setup": admin_count == 0}
class InitializeAdminRequest(BaseModel):
"""Request model for first-boot admin account creation."""
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
"""Create the first admin account on initial system setup.
Only callable when no admin exists. Returns 409 Conflict if an admin
already exists.
On success, the admin account is created with ``needs_setup=False`` and
the session cookie is set.
"""
admin_count = await get_local_provider().count_admin_users()
if admin_count > 0:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
)
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
except ValueError:
# DB unique-constraint race: another concurrent request beat us.
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
"""Initiate OAuth login flow.
Redirects to the OAuth provider's authorization URL.
Currently a placeholder - requires OAuth provider implementation.
"""
if provider not in ["github", "google"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider}",
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth login not yet implemented",
)
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
"""OAuth callback endpoint.
Handles the OAuth provider's callback after user authorization.
Currently a placeholder.
"""
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth callback not yet implemented",
)
@@ -1,4 +1,8 @@
"""LangGraph-compatible run feedback endpoints."""
"""Feedback endpoints — create, list, stats, delete.
Allows users to submit thumbs-up/down feedback on runs,
optionally scoped to a specific message.
"""
from __future__ import annotations
@@ -8,12 +12,16 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.dependencies import get_feedback_repository, get_run_repository
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
from app.plugins.auth.security.dependencies import get_current_user_id
from app.gateway.authz import require_permission
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
logger = logging.getLogger(__name__)
router = APIRouter(tags=["feedback"])
router = APIRouter(prefix="/api/threads", tags=["feedback"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class FeedbackCreateRequest(BaseModel):
@@ -22,11 +30,16 @@ class FeedbackCreateRequest(BaseModel):
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
class FeedbackUpsertRequest(BaseModel):
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
comment: str | None = Field(default=None, description="Optional text feedback")
class FeedbackResponse(BaseModel):
feedback_id: str
run_id: str
thread_id: str
owner_id: str | None = None
user_id: str | None = None
message_id: str | None = None
rating: int
comment: str | None = None
@@ -40,36 +53,85 @@ class FeedbackStatsResponse(BaseModel):
negative: int = 0
async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None:
run_store = get_run_repository(request)
if resolve_request_user_id(request) is None:
run = await run_store.get(run_id, user_id=None)
else:
with bind_request_actor_context(request):
run = await run_store.get(run_id)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upsert_feedback(
thread_id: str,
run_id: str,
body: FeedbackUpsertRequest,
request: Request,
) -> dict[str, Any]:
"""Create or update feedback for a run (idempotent)."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
async def _get_current_user(request: Request) -> str | None:
"""Extract current user id from auth dependencies when available."""
return await get_current_user_id(request)
feedback_repo = get_feedback_repo(request)
return await feedback_repo.upsert(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
comment=body.comment,
)
async def _create_feedback(
@router.delete("/{thread_id}/runs/{run_id}/feedback")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_run_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete the current user's feedback for a run."""
user_id = await get_current_user(request)
feedback_repo = get_feedback_repo(request)
deleted = await feedback_repo.delete_by_run(
thread_id=thread_id,
run_id=run_id,
user_id=user_id,
)
if not deleted:
raise HTTPException(status_code=404, detail="No feedback found for this run")
return {"success": True}
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def create_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Submit feedback (thumbs-up/down) for a run."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
await _validate_run_scope(thread_id, run_id, request)
user_id = await _get_current_user(request)
feedback_repo = get_feedback_repository(request)
user_id = await get_current_user(request)
# Validate run exists and belongs to thread
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
return await feedback_repo.create(
run_id=run_id,
thread_id=thread_id,
@@ -80,94 +142,41 @@ async def _create_feedback(
)
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
async def upsert_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Create or replace the run-level feedback record."""
feedback_repo = get_feedback_repository(request)
user_id = await _get_current_user(request)
if user_id is not None:
return await feedback_repo.upsert(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
comment=body.comment,
)
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
for item in existing:
feedback_id = item.get("feedback_id")
if isinstance(feedback_id, str):
await feedback_repo.delete(feedback_id)
return await _create_feedback(thread_id, run_id, body, request)
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
async def create_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Submit feedback for a run."""
return await _create_feedback(thread_id, run_id, body, request)
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
@require_permission("threads", "read", owner_check=True)
async def list_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> list[dict[str, Any]]:
"""List all feedback for a run."""
feedback_repo = get_feedback_repository(request)
user_id = await _get_current_user(request)
return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id)
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(thread_id, run_id)
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
@require_permission("threads", "read", owner_check=True)
async def feedback_stats(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, Any]:
"""Get aggregated feedback stats for a run."""
feedback_repo = get_feedback_repository(request)
"""Get aggregated feedback stats (positive/negative counts) for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.aggregate_by_run(thread_id, run_id)
@router.delete("/{thread_id}/runs/{run_id}/feedback")
async def delete_run_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete all feedback records for a run."""
feedback_repo = get_feedback_repository(request)
user_id = await _get_current_user(request)
if user_id is not None:
return {"success": await feedback_repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)}
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
for item in existing:
feedback_id = item.get("feedback_id")
if isinstance(feedback_id, str):
await feedback_repo.delete(feedback_id)
return {"success": True}
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_feedback(
thread_id: str,
run_id: str,
feedback_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete a single feedback record."""
feedback_repo = get_feedback_repository(request)
"""Delete a feedback record."""
feedback_repo = get_feedback_repo(request)
# Verify feedback belongs to the specified thread/run before deleting
existing = await feedback_repo.get(feedback_id)
if existing is None:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
@@ -1,6 +0,0 @@
from .feedback import router as feedback_router
from .runs import router as runs_router
from .suggestions import router as suggestion_router
from .threads import router as threads_router
__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"]
@@ -1,501 +0,0 @@
"""LangGraph-compatible runs endpoints backed by RunsFacade."""
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from typing import Literal
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.plugins.auth.security.actor_context import bind_request_actor_context
from app.gateway.services.runs.facade_factory import build_runs_facade_from_request
from app.gateway.services.runs.input import (
AdaptedRunRequest,
RunSpecBuilder,
UnsupportedRunFeatureError,
adapt_create_run_request,
adapt_create_stream_request,
adapt_create_wait_request,
adapt_join_stream_request,
adapt_join_wait_request,
)
from deerflow.runtime.runs.types import RunRecord, RunSpec
from deerflow.runtime.stream_bridge import JSONValue, StreamEvent
router = APIRouter(tags=["runs"])
class RunCreateRequest(BaseModel):
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run")
input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command")
metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata")
config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides")
context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
webhook: str | None = Field(default=None, description="Completion callback URL")
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object")
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
class RunResponse(BaseModel):
run_id: str
thread_id: str
assistant_id: str | None = None
status: str
metadata: dict[str, JSONValue] = Field(default_factory=dict)
multitask_strategy: str = "reject"
created_at: str = ""
updated_at: str = ""
class RunDeleteResponse(BaseModel):
deleted: bool
class RunMessageResponse(BaseModel):
run_id: str
content: JSONValue
metadata: dict[str, JSONValue] = Field(default_factory=dict)
created_at: str
seq: int
class RunMessagesResponse(BaseModel):
data: list[RunMessageResponse]
hasMore: bool = False
def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str:
"""Format a single SSE frame."""
payload = json.dumps(data, default=str, ensure_ascii=False)
parts = [f"event: {event}", f"data: {payload}"]
if event_id:
parts.append(f"id: {event_id}")
parts.append("")
parts.append("")
return "\n".join(parts)
def _record_to_response(record: RunRecord) -> RunResponse:
return RunResponse(
run_id=record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status,
metadata=record.metadata,
multitask_strategy=record.multitask_strategy,
created_at=record.created_at,
updated_at=record.updated_at,
)
def _trim_paginated_rows(
rows: list[dict],
*,
limit: int,
after_seq: int | None,
) -> tuple[list[dict], bool]:
has_more = len(rows) > limit
if not has_more:
return rows, False
if after_seq is not None:
return rows[:limit], True
return rows[-limit:], True
def _event_to_run_message(event: dict) -> RunMessageResponse:
return RunMessageResponse(
run_id=str(event["run_id"]),
content=event.get("content"),
metadata=dict(event.get("metadata") or {}),
created_at=str(event.get("created_at") or ""),
seq=int(event["seq"]),
)
async def _sse_consumer(
stream: AsyncIterator[StreamEvent],
request: Request,
*,
cancel_on_disconnect: bool,
cancel_run,
run_id: str,
) -> AsyncIterator[str]:
try:
async for event in stream:
if await request.is_disconnected():
break
if event.event == "__heartbeat__":
yield ": heartbeat\n\n"
continue
if event.event == "__end__":
yield format_sse("end", None, event_id=event.id or None)
return
if event.event == "__cancelled__":
yield format_sse("cancel", None, event_id=event.id or None)
return
yield format_sse(event.event, event.data, event_id=event.id or None)
finally:
if cancel_on_disconnect:
await cancel_run(run_id)
def _get_run_event_store(request: Request):
event_store = getattr(request.app.state, "run_event_store", None)
if event_store is None:
raise HTTPException(status_code=503, detail="Run event store not available")
return event_store
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
async def list_runs(
thread_id: str,
request: Request,
limit: int = 100,
offset: int = 0,
status: str | None = None,
) -> list[RunResponse]:
# Accepted for API compatibility; field projection is not implemented yet.
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
records = await facade.list_runs(thread_id)
if status is not None:
records = [record for record in records if record.status == status]
records = records[offset : offset + limit]
return [_record_to_response(record) for record in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return _record_to_response(record)
@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse)
async def run_messages(
thread_id: str,
run_id: str,
request: Request,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> RunMessagesResponse:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
event_store = _get_run_event_store(request)
with bind_request_actor_context(request):
rows = await event_store.list_messages_by_run(
thread_id,
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
page, has_more = _trim_paginated_rows(rows, limit=limit, after_seq=after_seq)
return RunMessagesResponse(data=[_event_to_run_message(row) for row in page], hasMore=has_more)
def _build_spec(
*,
adapted: AdaptedRunRequest,
) -> RunSpec:
try:
return RunSpecBuilder().build(adapted)
except UnsupportedRunFeatureError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
@router.post("/{thread_id}/runs", response_model=RunResponse)
async def create_run(
thread_id: str,
body: RunCreateRequest,
request: Request,
) -> Response:
adapted = adapt_create_run_request(
thread_id=thread_id,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.create_background(spec)
return Response(
content=_record_to_response(record).model_dump_json(),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
)
@router.post("/{thread_id}/runs/stream")
async def stream_run(
thread_id: str,
body: RunCreateRequest,
request: Request,
) -> StreamingResponse:
adapted = adapt_create_stream_request(
thread_id=thread_id,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, stream = await facade.create_and_stream(spec)
return StreamingResponse(
_sse_consumer(
stream,
request,
cancel_on_disconnect=spec.on_disconnect == "cancel",
cancel_run=facade.cancel,
run_id=record.run_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/{thread_id}/runs/wait")
async def wait_run(
thread_id: str,
body: RunCreateRequest,
request: Request,
) -> Response:
adapted = adapt_create_wait_request(
thread_id=thread_id,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, result = await facade.create_and_wait(spec)
return Response(
content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
)
@router.post("/runs", response_model=RunResponse)
async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response:
adapted = adapt_create_run_request(
thread_id=None,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.create_background(spec)
return Response(
content=_record_to_response(record).model_dump_json(),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
)
@router.post("/runs/stream")
async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse:
adapted = adapt_create_stream_request(
thread_id=None,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, stream = await facade.create_and_stream(spec)
return StreamingResponse(
_sse_consumer(
stream,
request,
cancel_on_disconnect=spec.on_disconnect == "cancel",
cancel_run=facade.cancel,
run_id=record.run_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}",
},
)
@router.post("/runs/wait")
async def wait_stateless_run(body: RunCreateRequest, request: Request) -> Response:
adapted = adapt_create_wait_request(
thread_id=None,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, result = await facade.create_and_wait(spec)
return Response(
content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
)
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
async def stream_existing_run(
thread_id: str,
run_id: str,
request: Request,
action: Literal["interrupt", "rollback"] | None = None,
wait: bool = False,
cancel_on_disconnect: bool = False,
stream_mode: str | None = None,
) -> StreamingResponse | Response:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if action is not None:
with bind_request_actor_context(request):
cancelled = await facade.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
if wait:
with bind_request_actor_context(request):
await facade.join_wait(run_id)
return Response(status_code=204)
adapted = adapt_join_stream_request(
thread_id=thread_id,
run_id=run_id,
headers=dict(request.headers),
query=dict(request.query_params),
)
with bind_request_actor_context(request):
stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id)
return StreamingResponse(
_sse_consumer(
stream,
request,
cancel_on_disconnect=cancel_on_disconnect,
cancel_run=facade.cancel,
run_id=run_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.get("/{thread_id}/runs/{run_id}/join")
async def join_existing_run(
thread_id: str,
run_id: str,
request: Request,
cancel_on_disconnect: bool = False,
) -> JSONValue:
# Accepted for API compatibility; current join_wait path does not change
# behavior based on client disconnect.
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
adapted = adapt_join_wait_request(
thread_id=thread_id,
run_id=run_id,
headers=dict(request.headers),
query=dict(request.query_params),
)
with bind_request_actor_context(request):
return await facade.join_wait(run_id, last_event_id=adapted.last_event_id)
@router.post("/{thread_id}/runs/{run_id}/cancel")
async def cancel_existing_run(
thread_id: str,
run_id: str,
request: Request,
wait: bool = False,
action: Literal["interrupt", "rollback"] = "interrupt",
) -> JSONValue:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
with bind_request_actor_context(request):
cancelled = await facade.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
if wait:
with bind_request_actor_context(request):
return await facade.join_wait(run_id)
return {}
@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse)
async def delete_run(
thread_id: str,
run_id: str,
request: Request,
) -> RunDeleteResponse:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
with bind_request_actor_context(request):
deleted = await facade.delete_run(run_id)
return RunDeleteResponse(deleted=deleted)
@@ -1,132 +0,0 @@
import json
import logging
from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["suggestions"])
class SuggestionMessage(BaseModel):
role: str = Field(..., description="Message role: user|assistant")
content: str = Field(..., description="Message content as plain text")
class SuggestionsRequest(BaseModel):
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
model_name: str | None = Field(default=None, description="Optional model override")
class SuggestionsResponse(BaseModel):
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
def _strip_markdown_code_fence(text: str) -> str:
stripped = text.strip()
if not stripped.startswith("```"):
return stripped
lines = stripped.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
return "\n".join(lines[1:-1]).strip()
return stripped
def _parse_json_string_list(text: str) -> list[str] | None:
candidate = _strip_markdown_code_fence(text)
start = candidate.find("[")
end = candidate.rfind("]")
if start == -1 or end == -1 or end <= start:
return None
candidate = candidate[start : end + 1]
try:
data = json.loads(candidate)
except Exception:
return None
if not isinstance(data, list):
return None
out: list[str] = []
for item in data:
if not isinstance(item, str):
continue
s = item.strip()
if not s:
continue
out.append(s)
return out
def _extract_response_text(content: object) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
text = block.get("text")
if isinstance(text, str):
parts.append(text)
return "\n".join(parts) if parts else ""
if content is None:
return ""
return str(content)
def _format_conversation(messages: list[SuggestionMessage]) -> str:
parts: list[str] = []
for m in messages:
role = m.role.strip().lower()
if role in ("user", "human"):
parts.append(f"User: {m.content.strip()}")
elif role in ("assistant", "ai"):
parts.append(f"Assistant: {m.content.strip()}")
else:
parts.append(f"{m.role}: {m.content.strip()}")
return "\n".join(parts).strip()
@router.post(
"/threads/{thread_id}/suggestions",
response_model=SuggestionsResponse,
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
if not request.messages:
return SuggestionsResponse(suggestions=[])
n = request.n
conversation = _format_conversation(request.messages)
if not conversation:
return SuggestionsResponse(suggestions=[])
system_instruction = (
"You are generating follow-up questions to help the user continue the conversation.\n"
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
"Requirements:\n"
"- Questions must be relevant to the preceding conversation.\n"
"- Questions must be written in the same language as the user.\n"
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
"- Do NOT include numbering, markdown, or any extra text.\n"
"- Output MUST be a JSON array of strings only.\n"
)
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
cleaned = cleaned[:n]
return SuggestionsResponse(suggestions=cleaned)
except Exception as exc:
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
return SuggestionsResponse(suggestions=[])
@@ -1,455 +0,0 @@
"""Thread management endpoints.
Provides CRUD operations for threads and checkpoint state management.
"""
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage
from app.infra.storage import ThreadMetaStorage
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(tags=["threads"])
# ---------------------------------------------------------------------------
# Request / Response Models
# ---------------------------------------------------------------------------
class ThreadCreateRequest(BaseModel):
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
class ThreadSearchRequest(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
offset: int = Field(default=0, ge=0, description="Pagination offset")
status: str | None = Field(default=None, description="Filter by thread status")
user_id: str | None = Field(default=None, description="Filter by user ID")
assistant_id: str | None = Field(default=None, description="Filter by assistant ID")
class ThreadResponse(BaseModel):
thread_id: str = Field(description="Unique thread identifier")
status: str = Field(default="idle", description="Thread status")
created_at: str = Field(default="", description="ISO timestamp")
updated_at: str = Field(default="", description="ISO timestamp")
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
values: dict[str, Any] = Field(default_factory=dict, description="Current state values")
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
class ThreadDeleteResponse(BaseModel):
success: bool
message: str
class ThreadStateUpdateRequest(BaseModel):
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
as_node: str | None = Field(default=None, description="Node identity for the update")
class ThreadStateResponse(BaseModel):
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
next: list[str] = Field(default_factory=list, description="Next nodes to execute")
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
class ThreadHistoryRequest(BaseModel):
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)")
class HistoryEntry(BaseModel):
checkpoint_id: str
parent_checkpoint_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
values: dict[str, Any] = Field(default_factory=dict)
created_at: str | None = None
next: list[str] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def sanitize_log_param(value: str) -> str:
"""Strip control characters to prevent log injection."""
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
"""Delete local filesystem data for a thread."""
path_manager = paths or get_paths()
try:
path_manager.delete_thread_dir(thread_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _thread_or_run_exists(
*,
request: Request,
thread_id: str,
thread_meta_storage: ThreadMetaStorage,
run_repo,
) -> bool:
request_user_id = resolve_request_user_id(request)
if request_user_id is None:
thread = await thread_meta_storage.get_thread(thread_id, user_id=None)
if thread is not None:
return True
runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None)
return bool(runs)
with bind_request_actor_context(request):
thread = await thread_meta_storage.get_thread(thread_id)
if thread is not None:
return True
runs = await run_repo.list_by_thread(thread_id, limit=1)
return bool(runs)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("", response_model=ThreadResponse)
async def create_thread(
body: ThreadCreateRequest,
request: Request,
thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadResponse:
"""Create a new thread."""
thread_id = body.thread_id or str(uuid.uuid4())
request_user_id = resolve_request_user_id(request)
if request_user_id is None:
existing = await thread_meta_storage.get_thread(thread_id, user_id=None)
else:
with bind_request_actor_context(request):
existing = await thread_meta_storage.get_thread(thread_id)
if existing is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing.status,
created_at=existing.created_time.isoformat() if existing.created_time else "",
updated_at=existing.updated_time.isoformat() if existing.updated_time else "",
metadata=existing.metadata,
)
try:
if request_user_id is None:
created = await thread_meta_storage.ensure_thread(
thread_id=thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
user_id=None,
)
else:
with bind_request_actor_context(request):
created = await thread_meta_storage.ensure_thread(
thread_id=thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
except Exception:
logger.exception("Failed to create thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", sanitize_log_param(thread_id))
return ThreadResponse(
thread_id=thread_id,
status=created.status,
created_at=created.created_time.isoformat() if created.created_time else "",
updated_at=created.updated_time.isoformat() if created.updated_time else "",
metadata=created.metadata,
)
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(
body: ThreadSearchRequest,
request: Request,
thread_meta_storage: CurrentThreadMetaStorage,
) -> list[ThreadResponse]:
"""Search threads with filters."""
try:
request_user_id = resolve_request_user_id(request)
if request_user_id is None:
threads = await thread_meta_storage.search_threads(
metadata=body.metadata or None,
status=body.status,
user_id=body.user_id,
assistant_id=body.assistant_id,
limit=body.limit,
offset=body.offset,
)
else:
with bind_request_actor_context(request):
threads = await thread_meta_storage.search_threads(
metadata=body.metadata or None,
status=body.status,
assistant_id=body.assistant_id,
limit=body.limit,
offset=body.offset,
)
except Exception:
logger.exception("Failed to search threads")
raise HTTPException(status_code=500, detail="Failed to search threads")
return [
ThreadResponse(
thread_id=t.thread_id,
status=t.status,
created_at=t.created_time.isoformat() if t.created_time else "",
updated_at=t.updated_time.isoformat() if t.updated_time else "",
metadata=t.metadata,
values={"title": t.display_name} if t.display_name else {},
interrupts={},
)
for t in threads
]
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread(
thread_id: str,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadDeleteResponse:
"""Delete a thread and all associated data."""
response = _delete_thread_data(thread_id)
# Remove checkpoints (best-effort)
try:
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s", sanitize_log_param(thread_id))
# Remove thread_meta (best-effort)
try:
await thread_meta_storage.delete_thread(thread_id)
except Exception:
logger.debug("Could not delete thread_meta for %s", sanitize_log_param(thread_id))
return response
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(
thread_id: str,
request: Request,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
run_repo: CurrentRunRepository,
) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread."""
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
if await _thread_or_run_exists(
request=request,
thread_id=thread_id,
thread_meta_storage=thread_meta_storage,
run_repo=run_repo,
):
return ThreadStateResponse()
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
channel_values = checkpoint.get("channel_values", {})
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
parent_config = getattr(checkpoint_tuple, "parent_config", None)
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=next_nodes,
tasks=tasks,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
metadata=metadata,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(
thread_id: str,
body: ThreadStateUpdateRequest,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadStateResponse:
"""Update thread state (human-in-the-loop or title rename)."""
read_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
if body.checkpoint_id:
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
if body.values:
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
metadata["step"] = metadata.get("step", 0) + 1
metadata["writes"] = {body.as_node: body.values}
write_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title to thread_meta
if body.values and "title" in body.values:
new_title = body.values["title"]
if new_title:
try:
await thread_meta_storage.sync_thread_title(
thread_id=thread_id,
title=new_title,
)
except Exception:
logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id))
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(
thread_id: str,
body: ThreadHistoryRequest,
request: Request,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
run_repo: CurrentRunRepository,
) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
if body.before:
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
is_first = True
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
parent_config = getattr(checkpoint_tuple, "parent_config", None)
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
channel_values = checkpoint.get("channel_values", {})
values: dict[str, Any] = {}
if title := channel_values.get("title"):
values["title"] = title
if is_first and (messages := channel_values.get("messages")):
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
is_first = False
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=metadata,
values=values,
created_at=str(metadata.get("created_at", "")),
next=next_nodes,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread history")
if not entries and await _thread_or_run_exists(
request=request,
thread_id=thread_id,
thread_meta_storage=thread_meta_storage,
run_repo=run_repo,
):
return []
return entries
+20 -12
View File
@@ -3,10 +3,12 @@ import logging
from pathlib import Path
from typing import Literal
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
@@ -69,7 +71,7 @@ class McpConfigUpdateRequest(BaseModel):
summary="Get MCP Configuration",
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
)
async def get_mcp_configuration() -> McpConfigResponse:
async def get_mcp_configuration(config: AppConfig = Depends(get_config)) -> McpConfigResponse:
"""Get the current MCP configuration.
Returns:
@@ -90,9 +92,9 @@ async def get_mcp_configuration() -> McpConfigResponse:
}
```
"""
config = get_extensions_config()
ext = config.extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()})
@router.put(
@@ -101,7 +103,11 @@ async def get_mcp_configuration() -> McpConfigResponse:
summary="Update MCP Configuration",
description="Update Model Context Protocol (MCP) server configurations and save to file.",
)
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
async def update_mcp_configuration(
request: McpConfigUpdateRequest,
http_request: Request,
config: AppConfig = Depends(get_config),
) -> McpConfigResponse:
"""Update the MCP configuration.
This will:
@@ -142,13 +148,13 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration
current_config = get_extensions_config()
# Use injected config to preserve skills configuration
current_ext = config.extensions
# Convert request to dict format for JSON serialization
config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
}
# Write the configuration to file
@@ -160,9 +166,11 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache
reloaded_config = reload_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
# Reload the configuration and swap ``app.state.config`` so subsequent
# ``Depends(get_config)`` calls see the refreshed value.
reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded.extensions.mcp_servers.items()})
except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
+85 -88
View File
@@ -1,9 +1,9 @@
"""Memory API router for retrieving and managing global memory data."""
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.plugins.auth.security.actor_context import bind_request_actor_context
from app.gateway.deps import get_config
from deerflow.agents.memory.updater import (
clear_memory_data,
create_memory_fact,
@@ -13,8 +13,8 @@ from deerflow.agents.memory.updater import (
reload_memory_data,
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.actor_context import get_effective_user_id
from deerflow.config.app_config import AppConfig
from deerflow.runtime.user_context import get_effective_user_id
router = APIRouter(prefix="/api", tags=["memory"])
@@ -115,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.",
)
async def get_memory(request: Request) -> MemoryResponse:
async def get_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Get the current global memory data.
Returns:
@@ -149,9 +149,8 @@ async def get_memory(request: Request) -> MemoryResponse:
}
```
"""
with bind_request_actor_context(request):
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@router.post(
@@ -161,7 +160,7 @@ async def get_memory(request: Request) -> MemoryResponse:
summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.",
)
async def reload_memory(request: Request) -> MemoryResponse:
async def reload_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Reload memory data from file.
This forces a reload of the memory data from the storage file,
@@ -170,9 +169,8 @@ async def reload_memory(request: Request) -> MemoryResponse:
Returns:
The reloaded memory data.
"""
with bind_request_actor_context(request):
memory_data = reload_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
memory_data = reload_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@router.delete(
@@ -182,15 +180,14 @@ async def reload_memory(request: Request) -> MemoryResponse:
summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.",
)
async def clear_memory(request: Request) -> MemoryResponse:
async def clear_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Clear all persisted memory data."""
with bind_request_actor_context(request):
try:
memory_data = clear_memory_data(user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
try:
memory_data = clear_memory_data(app_config.memory, user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.post(
@@ -200,22 +197,22 @@ async def clear_memory(request: Request) -> MemoryResponse:
summary="Create Memory Fact",
description="Create a single saved memory fact manually.",
)
async def create_memory_fact_endpoint(request: Request, payload: FactCreateRequest) -> MemoryResponse:
async def create_memory_fact_endpoint(request: FactCreateRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Create a single fact manually."""
with bind_request_actor_context(request):
try:
memory_data = create_memory_fact(
content=payload.content,
category=payload.category,
confidence=payload.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
try:
memory_data = create_memory_fact(
app_config.memory,
content=request.content,
category=request.category,
confidence=request.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.delete(
@@ -225,17 +222,16 @@ async def create_memory_fact_endpoint(request: Request, payload: FactCreateReque
summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.",
)
async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryResponse:
async def delete_memory_fact_endpoint(fact_id: str, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Delete a single fact from memory by fact id."""
with bind_request_actor_context(request):
try:
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
try:
memory_data = delete_memory_fact(app_config.memory, fact_id, user_id=get_effective_user_id())
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.patch(
@@ -245,25 +241,25 @@ async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryR
summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
)
async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: FactPatchRequest) -> MemoryResponse:
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Partially update a single fact manually."""
with bind_request_actor_context(request):
try:
memory_data = update_memory_fact(
fact_id=fact_id,
content=payload.content,
category=payload.category,
confidence=payload.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
try:
memory_data = update_memory_fact(
app_config.memory,
fact_id=fact_id,
content=request.content,
category=request.category,
confidence=request.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.get(
@@ -273,11 +269,10 @@ async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: F
summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.",
)
async def export_memory(request: Request) -> MemoryResponse:
async def export_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Export the current memory data."""
with bind_request_actor_context(request):
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@router.post(
@@ -287,15 +282,14 @@ async def export_memory(request: Request) -> MemoryResponse:
summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.",
)
async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResponse:
async def import_memory(request: MemoryResponse, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Import and persist memory data."""
with bind_request_actor_context(request):
try:
memory_data = import_memory_data(payload.model_dump(), user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
try:
memory_data = import_memory_data(app_config.memory, request.model_dump(), user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.get(
@@ -304,7 +298,9 @@ async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResp
summary="Get Memory Configuration",
description="Retrieve the current memory system configuration.",
)
async def get_memory_config_endpoint() -> MemoryConfigResponse:
async def get_memory_config_endpoint(
app_config: AppConfig = Depends(get_config),
) -> MemoryConfigResponse:
"""Get the memory system configuration.
Returns:
@@ -323,7 +319,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
}
```
"""
config = get_memory_config()
config = app_config.memory
return MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
@@ -342,25 +338,26 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
summary="Get Memory Status",
description="Retrieve both memory configuration and current data in a single request.",
)
async def get_memory_status(request: Request) -> MemoryStatusResponse:
async def get_memory_status(
app_config: AppConfig = Depends(get_config),
) -> MemoryStatusResponse:
"""Get the memory system status including configuration and data.
Returns:
Combined memory configuration and current data.
"""
with bind_request_actor_context(request):
config = get_memory_config()
memory_data = get_memory_data(user_id=get_effective_user_id())
config = app_config.memory
memory_data = get_memory_data(config, user_id=get_effective_user_id())
return MemoryStatusResponse(
config=MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
),
data=MemoryResponse(**memory_data),
)
return MemoryStatusResponse(
config=MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
),
data=MemoryResponse(**memory_data),
)
+27 -11
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"])
@@ -17,10 +18,17 @@ class ModelResponse(BaseModel):
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
class TokenUsageResponse(BaseModel):
"""Token usage display configuration."""
enabled: bool = Field(default=False, description="Whether token usage display is enabled")
class ModelsListResponse(BaseModel):
"""Response model for listing all models."""
models: list[ModelResponse]
token_usage: TokenUsageResponse
@router.get(
@@ -29,14 +37,14 @@ class ModelsListResponse(BaseModel):
summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.",
)
async def list_models() -> ModelsListResponse:
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
"""List all available models from configuration.
Returns model information suitable for frontend display,
excluding sensitive fields like API keys and internal configuration.
Returns:
A list of all configured models with their metadata.
A list of all configured models with their metadata and token usage display settings.
Example Response:
```json
@@ -44,21 +52,27 @@ async def list_models() -> ModelsListResponse:
"models": [
{
"name": "gpt-4",
"model": "gpt-4",
"display_name": "GPT-4",
"description": "OpenAI GPT-4 model",
"supports_thinking": false
"supports_thinking": false,
"supports_reasoning_effort": false
},
{
"name": "claude-3-opus",
"model": "claude-3-opus",
"display_name": "Claude 3 Opus",
"description": "Anthropic Claude 3 Opus model",
"supports_thinking": true
"supports_thinking": true,
"supports_reasoning_effort": false
}
]
],
"token_usage": {
"enabled": true
}
}
```
"""
config = get_app_config()
models = [
ModelResponse(
name=model.name,
@@ -70,7 +84,10 @@ async def list_models() -> ModelsListResponse:
)
for model in config.models
]
return ModelsListResponse(models=models)
return ModelsListResponse(
models=models,
token_usage=TokenUsageResponse(enabled=config.token_usage.enabled),
)
@router.get(
@@ -79,7 +96,7 @@ async def list_models() -> ModelsListResponse:
summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.",
)
async def get_model(model_name: str) -> ModelResponse:
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
"""Get a specific model by name.
Args:
@@ -101,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
}
```
"""
config = get_app_config()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+142
View File
@@ -0,0 +1,142 @@
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
These endpoints auto-create a temporary thread when no ``thread_id`` is
supplied in the request body. When a ``thread_id`` **is** provided, it
is reused so that conversation history is preserved across calls.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/runs", tags=["runs"])
def _resolve_thread_id(body: RunCreateRequest) -> str:
"""Return the thread_id from the request body, or generate a new one."""
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
if thread_id:
return str(thread_id)
return str(uuid.uuid4())
@router.post("/stream")
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/wait", response_model=dict)
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until completion.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
# ---------------------------------------------------------------------------
# Run-scoped read endpoints
# ---------------------------------------------------------------------------
async def _resolve_run(run_id: str, request: Request) -> dict:
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
run_store = get_run_store(request)
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
if record is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return record
@router.get("/{run_id}/messages")
@require_permission("runs", "read")
async def run_messages(
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a run (cursor-based).
Pagination:
- after_seq: messages with seq > after_seq (forward)
- before_seq: messages with seq < before_seq (backward)
- neither: latest messages
Response: { data: [...], has_more: bool }
"""
run = await _resolve_run(run_id, request)
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
run["thread_id"], run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{run_id}/feedback")
@require_permission("runs", "read")
async def run_feedback(run_id: str, request: Request) -> list[dict]:
"""Return all feedback for a run."""
run = await _resolve_run(run_id, request)
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(run["thread_id"], run_id)
+87 -57
View File
@@ -1,14 +1,17 @@
import errno
import json
import logging
import shutil
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
@@ -100,9 +103,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
async def list_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -115,11 +118,11 @@ async def list_skills() -> SkillsListResponse:
summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
)
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
async def install_skill(request: SkillInstallRequest, app_config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return SkillInstallResponse(**result)
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -135,9 +138,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
async def list_custom_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
skills = [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -145,13 +148,13 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
async def get_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config))
except HTTPException:
raise
except Exception as e:
@@ -160,14 +163,18 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
async def update_custom_skill(
skill_name: str,
request: CustomSkillUpdateRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
ensure_custom_skill_is_editable(skill_name, app_config)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
scan = await scan_skill_content(app_config, request.content, executable=False, location=f"{skill_name}/SKILL.md")
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
skill_file = get_custom_skill_dir(skill_name, app_config) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
@@ -181,9 +188,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
app_config,
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name, app_config)
except HTTPException:
raise
except FileNotFoundError as e:
@@ -196,25 +204,31 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
async def delete_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
ensure_custom_skill_is_editable(skill_name, app_config)
skill_dir = get_custom_skill_dir(skill_name, app_config)
prev_content = read_custom_skill_content(skill_name, app_config)
try:
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
app_config,
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return {"success": True}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -226,11 +240,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
async def get_custom_skill_history(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
return CustomSkillHistoryResponse(history=read_history(skill_name, app_config))
except HTTPException:
raise
except Exception as e:
@@ -239,11 +253,15 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
async def rollback_custom_skill(
skill_name: str,
request: SkillRollbackRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
history = read_history(skill_name, app_config)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
@@ -251,8 +269,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
scan = await scan_skill_content(app_config, target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name, app_config)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
@@ -265,12 +283,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
append_history(skill_name, history_entry, app_config)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
append_history(skill_name, history_entry, app_config)
await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name, app_config)
except HTTPException:
raise
except IndexError:
@@ -290,9 +308,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
async def get_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -312,9 +330,14 @@ async def get_skill(skill_name: str) -> SkillResponse:
summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.",
)
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
async def update_skill(
skill_name: str,
request: SkillUpdateRequest,
http_request: Request,
app_config: AppConfig = Depends(get_config),
) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -325,22 +348,29 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config()
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
# Do not mutate the frozen AppConfig in place. Compose the new skills
# state in a fresh dict, write to disk, and reload AppConfig below so
# every subsequent Depends(get_config) sees the refreshed snapshot.
ext = app_config.extensions
updated_skills = {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()}
updated_skills[skill_name] = {"enabled": request.enabled}
config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
"mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
"skills": updated_skills,
}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config()
await refresh_skills_system_prompt_cache_async()
# Reload AppConfig and swap ``app.state.config`` so subsequent
# ``Depends(get_config)`` sees the refreshed value.
reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
await refresh_skills_system_prompt_cache_async(reloaded)
skills = load_skills(enabled_only=False)
skills = load_skills(reloaded, enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:
+8 -4
View File
@@ -1,10 +1,13 @@
import json
import logging
from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -98,7 +101,8 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request, app_config: AppConfig = Depends(get_config)) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[])
@@ -120,8 +124,8 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=body.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=app_config)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
+373
View File
@@ -0,0 +1,373 @@
"""Runs endpoints — create, stream, wait, cancel.
Implements the LangGraph Platform runs API on top of
:class:`deerflow.agents.runs.RunManager` and
:class:`deerflow.agents.stream_bridge.StreamBridge`.
SSE format is aligned with the LangGraph Platform protocol so that
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
works without modification.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["runs"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class RunCreateRequest(BaseModel):
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
webhook: str | None = Field(default=None, description="Completion callback URL")
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
class RunResponse(BaseModel):
run_id: str
thread_id: str
assistant_id: str | None = None
status: str
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)
multitask_strategy: str = "reject"
created_at: str = ""
updated_at: str = ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _record_to_response(record: RunRecord) -> RunResponse:
return RunResponse(
run_id=record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
metadata=record.metadata,
kwargs=record.kwargs,
multitask_strategy=record.multitask_strategy,
created_at=record.created_at,
updated_at=record.updated_at,
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request)
return _record_to_response(record)
@router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
The response includes a ``Content-Location`` header with the run's
resource URL, matching the LangGraph Platform protocol. The
``useStream`` React hook uses this to extract run metadata.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
# LangGraph Platform includes run metadata in this header.
# The SDK uses a greedy regex to extract the run id from this path,
# so it must point at the canonical run resource without extra suffixes.
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
run_mgr = get_run_manager(request)
records = await run_mgr.list_by_thread(thread_id)
return [_record_to_response(r) for r in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return _record_to_response(record)
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
thread_id: str,
run_id: str,
request: Request,
wait: bool = Query(default=False, description="Block until run completes after cancel"),
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
"""Cancel a running or pending run.
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
- action=rollback: Stop execution, revert to pre-run checkpoint state
- wait=true: Block until the run fully stops, return 204
- wait=false: Return immediately with 202
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
cancelled = await run_mgr.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(
status_code=409,
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
)
if wait and record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
return Response(status_code=204)
return Response(status_code=202)
@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
thread_id: str,
run_id: str,
request: Request,
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
is present the run is cancelled first; the response then streams any
remaining buffered events so the client observes a clean shutdown.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
# Cancel if an action was requested (stop-button / interrupt flow)
if action is not None:
cancelled = await run_mgr.cancel(run_id, action=action)
if cancelled and wait and record.task is not None:
try:
await record.task
except (asyncio.CancelledError, Exception):
pass
return Response(status_code=204)
bridge = get_stream_bridge(request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
thread_id: str,
request: Request,
limit: int = Query(default=50, le=200),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> list[dict]:
"""Return displayable messages for a thread (across all runs), with feedback attached."""
event_store = get_run_event_store(request)
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
# Attach feedback to the last AI message of each run
feedback_repo = get_feedback_repo(request)
user_id = await get_current_user(request)
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
# Find the last ai_message per run_id
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
for i, msg in enumerate(messages):
if msg.get("event_type") == "ai_message":
last_ai_per_run[msg["run_id"]] = i
# Attach feedback field
last_ai_indices = set(last_ai_per_run.values())
for i, msg in enumerate(messages):
if i in last_ai_indices:
run_id = msg["run_id"]
fb = feedback_map.get(run_id)
msg["feedback"] = {
"feedback_id": fb["feedback_id"],
"rating": fb["rating"],
"comment": fb.get("comment"),
} if fb else None
else:
msg["feedback"] = None
return messages
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(
thread_id: str,
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a specific run.
Response: { data: [...], has_more: bool }
"""
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
thread_id, run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
thread_id: str,
run_id: str,
request: Request,
event_types: str | None = Query(default=None),
limit: int = Query(default=500, le=2000),
) -> list[dict]:
"""Return the full event stream for a run (debug/audit)."""
event_store = get_run_event_store(request)
types = event_types.split(",") if event_types else None
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return {"thread_id": thread_id, **agg}
+621
View File
@@ -0,0 +1,621 @@
"""Thread CRUD, state, and history endpoints.
Combines the existing thread-local filesystem cleanup with LangGraph
Platform-compatible thread management backed by the checkpointer.
Channel values returned in state responses are serialized through
:func:`deerflow.runtime.serialization.serialize_channel_values` to
ensure LangChain message objects are converted to JSON-safe dicts
matching the LangGraph Platform wire format expected by the
``useStream`` React hook.
"""
from __future__ import annotations
import logging
import re
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
# Metadata keys that the server controls; clients are not allowed to set
# them. Pydantic ``@field_validator("metadata")`` strips them on every
# inbound model below so a malicious client cannot reflect a forged
# owner identity through the API surface. Defense-in-depth — the
# row-level invariant is still ``threads_meta.user_id`` populated from
# the auth contextvar; this list closes the metadata-blob echo gap.
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
"""Return ``metadata`` with server-controlled keys removed."""
if not metadata:
return metadata or {}
return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
# ---------------------------------------------------------------------------
# Response / request models
# ---------------------------------------------------------------------------
class ThreadDeleteResponse(BaseModel):
"""Response model for thread cleanup."""
success: bool
message: str
class ThreadResponse(BaseModel):
"""Response model for a single thread."""
thread_id: str = Field(description="Unique thread identifier")
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
created_at: str = Field(default="", description="ISO timestamp")
updated_at: str = Field(default="", description="ISO timestamp")
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
offset: int = Field(default=0, ge=0, description="Pagination offset")
status: str | None = Field(default=None, description="Filter by thread status")
class ThreadStateResponse(BaseModel):
"""Response model for thread state."""
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
class ThreadPatchRequest(BaseModel):
"""Request body for patching thread metadata."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadStateUpdateRequest(BaseModel):
"""Request body for updating thread state (human-in-the-loop resume)."""
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
as_node: str | None = Field(default=None, description="Node identity for the update")
class HistoryEntry(BaseModel):
"""Single checkpoint history entry."""
checkpoint_id: str
parent_checkpoint_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
values: dict[str, Any] = Field(default_factory=dict)
created_at: str | None = None
next: list[str] = Field(default_factory=list)
class ThreadHistoryRequest(BaseModel):
"""Request body for checkpoint history."""
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
before: str | None = Field(default=None, description="Cursor for pagination")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread."""
path_manager = paths or get_paths()
try:
path_manager.delete_thread_dir(thread_id, user_id=user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
# Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None:
return "idle"
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
# Check for error in pending writes
for pw in pending_writes:
if len(pw) >= 2 and pw[1] == "__error__":
return "error"
# Check for pending next tasks (indicates interrupt)
tasks = getattr(checkpoint_tuple, "tasks", None)
if tasks:
return "interrupted"
return "idle"
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread_meta row from the configured ThreadMetaStore
(sqlite or memory).
"""
from app.gateway.deps import get_thread_store
# Clean local filesystem
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None:
try:
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id))
# Remove thread_meta row (best-effort) — required for sqlite backend
# so the deleted thread no longer appears in /threads/search.
try:
thread_store = get_thread_store(request)
await thread_store.delete(thread_id)
except Exception:
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
return response
@router.post("", response_model=ThreadResponse)
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread.
Writes a thread_meta record (so the thread appears in /threads/search)
and an empty checkpoint (so state endpoints work immediately).
Idempotent: returns the existing record when ``thread_id`` already exists.
"""
from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present
existing_record = await thread_store.get(thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Write thread_meta so the thread appears in /threads/search immediately
try:
await thread_store.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
metadata=body.metadata,
)
except Exception:
logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = {
"step": -1,
"source": "input",
"writes": None,
"parents": {},
**body.metadata,
"created_at": now,
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception:
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", sanitize_log_param(thread_id))
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
)
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads.
Delegates to the configured ThreadMetaStore implementation
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
"""
from app.gateway.deps import get_thread_store
repo = get_thread_store(request)
rows = await repo.search(
metadata=body.metadata or None,
status=body.status,
limit=body.limit,
offset=body.offset,
)
return [
ThreadResponse(
thread_id=r["thread_id"],
status=r.get("status", "idle"),
created_at=r.get("created_at", ""),
updated_at=r.get("updated_at", ""),
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {},
interrupts={},
)
for r in rows
]
@router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
record = await thread_store.get(thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
try:
await thread_store.update_metadata(thread_id, body.metadata)
except Exception:
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread")
# Re-read to get the merged metadata + refreshed updated_at
record = await thread_store.get(thread_id) or record
return ThreadResponse(
thread_id=thread_id,
status=record.get("status", "idle"),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
)
@router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the ThreadMetaStore and derives the accurate
execution status from the checkpointer. Falls back to the checkpointer
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
"""
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
checkpointer = get_checkpointer(request)
record: dict | None = await thread_store.get(thread_id)
# Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not in thread_meta (e.g.
# legacy data created before thread_meta adoption), synthesize a minimal
# record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = {
"thread_id": thread_id,
"status": "idle",
"created_at": ckpt_meta.get("created_at", ""),
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
channel_values = checkpoint.get("channel_values", {})
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
Channel values are serialized to ensure LangChain message objects
are converted to JSON-safe dicts.
"""
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint_id = None
ckpt_config = getattr(checkpoint_tuple, "config", {})
if ckpt_config:
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
parent_checkpoint_id = None
if parent_config:
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
values = serialize_channel_values(channel_values)
return ThreadStateResponse(
values=values,
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
tasks=tasks,
)
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field through the
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
change immediately in both sqlite and memory backends.
"""
from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
# checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it
# fetches the latest checkpoint for the thread.
read_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
if body.checkpoint_id:
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# Work on mutable copies so we don't accidentally mutate cached objects.
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
if body.values:
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
metadata["step"] = metadata.get("step", 0) + 1
metadata["writes"] = {body.as_node: body.values}
# aput requires checkpoint_ns in the config — use the same config used for the
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
# so that aput generates a fresh checkpoint ID for the new snapshot.
write_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
# reflects them immediately in both sqlite and memory backends.
if body.values and "title" in body.values:
new_title = body.values["title"]
if new_title: # Skip empty strings and None
try:
await thread_store.update_display_name(thread_id, new_title)
except Exception:
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread.
Messages are read from the checkpointer's channel values (the
authoritative source) and serialized via
:func:`~deerflow.runtime.serialization.serialize_channel_values`.
Only the latest (first) checkpoint carries the ``messages`` key to
avoid duplicating them across every entry.
"""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
if body.before:
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
is_latest_checkpoint = True
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
parent_id = None
if parent_config:
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
# Build values from checkpoint channel_values
values: dict[str, Any] = {}
if title := channel_values.get("title"):
values["title"] = title
if thread_data := channel_values.get("thread_data"):
values["thread_data"] = thread_data
# Attach messages only to the latest checkpoint entry.
if is_latest_checkpoint:
messages = channel_values.get("messages")
if messages:
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
is_latest_checkpoint = False
# Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
# Strip LangGraph internal keys from metadata
user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Keep step for ordering context
if "step" in metadata:
user_meta["step"] = metadata["step"]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=user_meta,
values=values,
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries
+99 -64
View File
@@ -4,13 +4,15 @@ import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from app.plugins.auth.security.actor_context import bind_request_actor_context
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
from deerflow.runtime.actor_context import get_effective_user_id
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
delete_file_safe,
@@ -55,79 +57,111 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
os.chmod(file_path, writable_mode, **chmod_kwargs)
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
uploads_cfg = getattr(app_config, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
except Exception:
return False
@router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upload_files(
thread_id: str,
request: Request,
files: list[UploadFile] = File(...),
app_config: AppConfig = Depends(get_config),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
with bind_request_actor_context(request):
try:
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
try:
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sandbox_provider = get_sandbox_provider(app_config)
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled(app_config)
for file in files:
if not file.filename:
continue
for file in files:
if not file.filename:
continue
try:
safe_filename = normalize_filename(file.filename)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
try:
safe_filename = normalize_filename(file.filename)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
try:
content = await file.read()
file_path = uploads_dir / safe_filename
file_path.write_bytes(content)
try:
content = await file.read()
file_path = uploads_dir / safe_filename
file_path.write_bytes(content)
virtual_path = upload_virtual_path(safe_filename)
virtual_path = upload_virtual_path(safe_filename)
if sandbox_id != "local":
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
file_info = {
"filename": safe_filename,
"size": str(len(content)),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
file_info = {
"filename": safe_filename,
"size": str(len(content)),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower()
if file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
file_ext = file_path.suffix.lower()
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
if sandbox_id != "local":
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
file_info["markdown_file"] = md_path.name
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
file_info["markdown_virtual_path"] = md_virtual_path
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
file_info["markdown_file"] = md_path.name
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
file_info["markdown_virtual_path"] = md_virtual_path
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
uploaded_files.append(file_info)
uploaded_files.append(file_info)
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
return UploadResponse(
success=True,
@@ -137,25 +171,26 @@ async def upload_files(
@router.get("/list", response_model=dict)
@require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory."""
with bind_request_actor_context(request):
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
result = list_files_in_dir(uploads_dir)
enrich_file_listing(result, thread_id)
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
result = list_files_in_dir(uploads_dir)
enrich_file_listing(result, thread_id)
# Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
# Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
return result
return result
@router.delete("/{filename}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
"""Delete a file from a thread's uploads directory."""
try:
+345
View File
@@ -0,0 +1,345 @@
"""Run lifecycle service layer.
Centralizes the business logic for creating runs, formatting SSE
frames, and consuming stream bridge events. Router modules
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
"""
from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import re
import time
from collections.abc import Mapping
from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
ConflictError,
DisconnectMode,
RunManager,
RunRecord,
RunStatus,
StreamBridge,
UnsupportedStrategyError,
run_agent,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# SSE formatting
# ---------------------------------------------------------------------------
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
"""Format a single SSE frame.
Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
This matches the LangGraph Platform wire format consumed by the
``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
"""
payload = json.dumps(data, default=str, ensure_ascii=False)
parts = [f"event: {event}", f"data: {payload}"]
if event_id:
parts.append(f"id: {event_id}")
parts.append("")
parts.append("")
return "\n".join(parts)
# ---------------------------------------------------------------------------
# Input / config helpers
# ---------------------------------------------------------------------------
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
"""Normalize the stream_mode parameter to a list.
Default matches what ``useStream`` expects: values + messages-tuple.
"""
if raw is None:
return ["values"]
if isinstance(raw, str):
return [raw]
return raw if raw else ["values"]
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
"""Convert LangGraph Platform input format to LangChain state dict."""
if raw_input is None:
return {}
messages = raw_input.get("messages")
if messages and isinstance(messages, list):
converted = []
for msg in messages:
if isinstance(msg, dict):
role = msg.get("role", msg.get("type", "user"))
content = msg.get("content", "")
if role in ("user", "human"):
converted.append(HumanMessage(content=content))
else:
# TODO: handle other message types (system, ai, tool)
converted.append(HumanMessage(content=content))
else:
converted.append(msg)
return {**raw_input, "messages": converted}
return raw_input
_DEFAULT_ASSISTANT_ID = "lead_agent"
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
injected into ``configurable`` or ``context`` — see
:func:`build_run_config`. All ``assistant_id`` values therefore map to the
same factory; the routing happens inside ``make_lead_agent`` when it reads
``cfg["agent_name"]``.
"""
from deerflow.agents.lead_agent.agent import make_lead_agent
return make_lead_agent
def build_run_config(
thread_id: str,
request_config: dict[str, Any] | None,
metadata: dict[str, Any] | None,
*,
assistant_id: str | None = None,
) -> dict[str, Any]:
"""Build a RunnableConfig dict for the agent.
When *assistant_id* refers to a custom agent (anything other than
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
whichever runtime options container is active: ``context`` for
LangGraph >= 0.6.0 requests, otherwise ``configurable``.
``make_lead_agent`` reads this key to load the matching
``agents/<name>/SOUL.md`` and per-agent config — without it the agent
silently runs as the default lead agent.
This mirrors the channel manager's ``_resolve_run_params`` logic so that
the LangGraph Platform-compatible HTTP API and the IM channel path behave
identically.
"""
config: dict[str, Any] = {"recursion_limit": 100}
if request_config:
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
# pass thread-level data and rejects requests that include both
# ``configurable`` and ``context``. If the caller already sends
# ``context``, honour it and skip our own ``configurable`` dict.
if "context" in request_config:
if "configurable" in request_config:
logger.warning(
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
thread_id,
list(request_config.get("configurable", {}).keys()),
)
context_value = request_config["context"]
if context_value is None:
context = {}
elif isinstance(context_value, Mapping):
context = dict(context_value)
else:
raise ValueError("request config 'context' must be a mapping or null.")
config["context"] = context
else:
configurable = {"thread_id": thread_id}
configurable.update(request_config.get("configurable", {}))
config["configurable"] = configurable
for k, v in request_config.items():
if k not in ("configurable", "context"):
config[k] = v
else:
config["configurable"] = {"thread_id": thread_id}
# Inject custom agent name when the caller specified a non-default assistant.
# Honour an explicit agent_name in the active runtime options container.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
if "configurable" in config:
target = config["configurable"]
elif "context" in config:
target = config["context"]
else:
target = config.setdefault("configurable", {})
if target is not None and "agent_name" not in target:
target["agent_name"] = normalized
if metadata:
config.setdefault("metadata", {}).update(metadata)
return config
# ---------------------------------------------------------------------------
# Run lifecycle
# ---------------------------------------------------------------------------
async def start_run(
body: Any,
thread_id: str,
request: Request,
) -> RunRecord:
"""Create a RunRecord and launch the background agent task.
Parameters
----------
body : RunCreateRequest
The validated request body (typed as Any to avoid circular import
with the router module that defines the Pydantic model).
thread_id : str
Target thread.
request : Request
FastAPI request — used to retrieve singletons from ``app.state``.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
run_ctx = get_run_context(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try:
record = await run_mgr.create_or_reject(
thread_id,
body.assistant_id,
on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
follow_up_to_run_id=follow_up_to_run_id,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Upsert thread metadata so the thread appears in /threads/search,
# even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await run_ctx.thread_store.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Merge DeerFlow-specific context overrides into configurable.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
context = getattr(body, "context", None)
if context:
_CONTEXT_CONFIGURABLE_KEYS = {
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
"agent_name",
"is_bootstrap",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
stream_modes = normalize_stream_modes(body.stream_mode)
task = asyncio.create_task(
run_agent(
bridge,
run_mgr,
record,
ctx=run_ctx,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
stream_modes=stream_modes,
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
)
)
record.task = task
# Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_store.update_display_name
# after the run completes.
return record
async def sse_consumer(
bridge: StreamBridge,
record: RunRecord,
request: Request,
run_mgr: RunManager,
):
"""Async generator that yields SSE frames from the bridge.
The ``finally`` block implements ``on_disconnect`` semantics:
- ``cancel``: abort the background task on client disconnect.
- ``continue``: let the task run; events are discarded.
"""
last_event_id = request.headers.get("Last-Event-ID")
try:
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
if await request.is_disconnected():
break
if entry is HEARTBEAT_SENTINEL:
yield ": heartbeat\n\n"
continue
if entry is END_SENTINEL:
yield format_sse("end", None, event_id=entry.id or None)
return
yield format_sse(entry.event, entry.data, event_id=entry.id or None)
finally:
if record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)
-5
View File
@@ -1,5 +0,0 @@
"""Gateway service layer."""
"""Compatibility package for app service submodules."""
__all__: list[str] = []
@@ -1,29 +0,0 @@
"""Runs app layer services."""
from app.infra.storage import StorageRunObserver
from .input import (
AdaptedRunRequest,
RunSpecBuilder,
UnsupportedRunFeatureError,
adapt_create_run_request,
adapt_create_stream_request,
adapt_create_wait_request,
adapt_join_stream_request,
adapt_join_wait_request,
)
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
__all__ = [
"AdaptedRunRequest",
"AppRunCreateStore",
"AppRunDeleteStore",
"AppRunQueryStore",
"RunSpecBuilder",
"StorageRunObserver",
"UnsupportedRunFeatureError",
"adapt_create_run_request",
"adapt_create_stream_request",
"adapt_create_wait_request",
"adapt_join_stream_request",
"adapt_join_wait_request",
]
@@ -1,150 +0,0 @@
"""Facade factory - assembles RunsFacade with dependencies."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
from fastapi import HTTPException, Request
from app.gateway.dependencies import get_checkpointer, get_stream_bridge
from deerflow.runtime.runs.facade import RunsFacade
from deerflow.runtime.runs.facade import RunsRuntime
from deerflow.runtime.runs.internal.execution.supervisor import RunSupervisor
from deerflow.runtime.runs.internal.planner import ExecutionPlanner
from deerflow.runtime.runs.internal.registry import RunRegistry
from deerflow.runtime.runs.internal.streams import RunStreamService
from deerflow.runtime.runs.internal.wait import RunWaitService
from app.infra.storage import StorageRunObserver, ThreadMetaStorage
from app.infra.storage.runs import RunDeleteRepository, RunReadRepository, RunWriteRepository
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
if TYPE_CHECKING:
from deerflow.runtime.stream_bridge import StreamBridge
type AgentFactory = Callable[..., object]
# Module-level singleton registry (shared across requests)
_registry: RunRegistry | None = None
_supervisor: RunSupervisor | None = None
def _get_state(request: Request, attr: str, label: str):
value = getattr(request.app.state, attr, None)
if value is None:
raise HTTPException(status_code=503, detail=f"{label} not available")
return value
def get_registry() -> RunRegistry:
"""Get or create singleton registry."""
global _registry
if _registry is None:
_registry = RunRegistry()
return _registry
def get_supervisor() -> RunSupervisor:
"""Get or create singleton run supervisor."""
global _supervisor
if _supervisor is None:
_supervisor = RunSupervisor()
return _supervisor
def resolve_agent_factory(assistant_id: str | None) -> AgentFactory:
"""Resolve the agent factory callable from config."""
from deerflow.agents.lead_agent.agent import make_lead_agent
return make_lead_agent
def build_runs_facade(
*,
stream_bridge: "StreamBridge",
checkpointer: object,
store: object | None = None,
run_read_repo: RunReadRepository | None = None,
run_write_repo: RunWriteRepository | None = None,
run_delete_repo: RunDeleteRepository | None = None,
thread_meta_storage: ThreadMetaStorage | None = None,
run_event_store: object | None = None,
) -> RunsFacade:
"""
Build RunsFacade with all dependencies.
Args:
stream_bridge: StreamBridge instance
checkpointer: LangGraph checkpointer
store: Optional LangGraph runtime store
run_read_repo: Optional run repository for durable reads
run_write_repo: Optional run repository for durable writes
run_delete_repo: Optional run repository for durable deletes
thread_meta_storage: Optional thread metadata storage adapter
Returns:
Configured RunsFacade instance
"""
registry = get_registry()
planner = ExecutionPlanner()
supervisor = get_supervisor()
stream_service = RunStreamService(stream_bridge)
wait_service = RunWaitService(stream_service)
query_store = AppRunQueryStore(run_read_repo) if run_read_repo else None
create_store = (
AppRunCreateStore(run_write_repo, thread_meta_storage=thread_meta_storage)
if run_write_repo
else None
)
delete_store = AppRunDeleteStore(run_delete_repo) if run_delete_repo else None
# Build storage observer if repositories provided
storage_observer = None
if run_write_repo or thread_meta_storage:
storage_observer = StorageRunObserver(
run_write_repo=run_write_repo,
thread_meta_storage=thread_meta_storage,
)
return RunsFacade(
registry=registry,
planner=planner,
supervisor=supervisor,
stream_service=stream_service,
wait_service=wait_service,
runtime=RunsRuntime(
bridge=stream_bridge,
checkpointer=checkpointer,
store=store,
event_store=run_event_store,
agent_factory_resolver=resolve_agent_factory,
),
observer=storage_observer,
query_store=query_store,
create_store=create_store,
delete_store=delete_store,
)
def build_runs_facade_from_request(request: "Request") -> RunsFacade:
"""
Build RunsFacade from FastAPI request context.
Extracts dependencies from request.app.state.
"""
app_state = request.app.state
return build_runs_facade(
stream_bridge=get_stream_bridge(request),
checkpointer=get_checkpointer(request),
store=getattr(request.app.state, "store", None),
run_read_repo=getattr(app_state, "run_read_repo", None),
run_write_repo=getattr(app_state, "run_write_repo", None),
run_delete_repo=getattr(app_state, "run_delete_repo", None),
thread_meta_storage=getattr(app_state, "thread_meta_storage", None),
run_event_store=getattr(app_state, "run_event_store", None),
)
@@ -1,22 +0,0 @@
"""Input adapters for app-owned runs entrypoints."""
from .request_adapter import (
AdaptedRunRequest,
adapt_create_run_request,
adapt_create_stream_request,
adapt_create_wait_request,
adapt_join_stream_request,
adapt_join_wait_request,
)
from .spec_builder import RunSpecBuilder, UnsupportedRunFeatureError
__all__ = [
"AdaptedRunRequest",
"RunSpecBuilder",
"UnsupportedRunFeatureError",
"adapt_create_run_request",
"adapt_create_stream_request",
"adapt_create_wait_request",
"adapt_join_stream_request",
"adapt_join_wait_request",
]
@@ -1,127 +0,0 @@
"""App-owned request adapter for runs entrypoints."""
from __future__ import annotations
from dataclasses import dataclass
from deerflow.runtime.stream_bridge import JSONValue
from deerflow.runtime.runs.types import RunIntent
type RequestBody = dict[str, JSONValue]
type RequestQuery = dict[str, str]
@dataclass(frozen=True)
class AdaptedRunRequest:
"""
统一的内部请求 DTO.
路由层只负责提取 path/query/body,适配器负责转成稳定内部结构。
"""
intent: RunIntent
thread_id: str | None
run_id: str | None
body: RequestBody
headers: dict[str, str]
query: RequestQuery
@property
def last_event_id(self) -> str | None:
"""Extract Last-Event-ID from headers."""
return self.headers.get("last-event-id") or self.headers.get("Last-Event-ID")
@property
def is_stateless(self) -> bool:
"""Check if this is a stateless request."""
return self.thread_id is None
def adapt_create_run_request(
*,
thread_id: str | None,
body: RequestBody,
headers: dict[str, str] | None = None,
query: RequestQuery | None = None,
) -> AdaptedRunRequest:
"""Adapt POST /threads/{thread_id}/runs or POST /runs."""
return AdaptedRunRequest(
intent="create_background",
thread_id=thread_id,
run_id=None,
body=body,
headers=headers or {},
query=query or {},
)
def adapt_create_stream_request(
*,
thread_id: str | None,
body: RequestBody,
headers: dict[str, str] | None = None,
query: RequestQuery | None = None,
) -> AdaptedRunRequest:
"""Adapt POST /threads/{thread_id}/runs/stream or POST /runs/stream."""
return AdaptedRunRequest(
intent="create_and_stream",
thread_id=thread_id,
run_id=None,
body=body,
headers=headers or {},
query=query or {},
)
def adapt_create_wait_request(
*,
thread_id: str | None,
body: RequestBody,
headers: dict[str, str] | None = None,
query: RequestQuery | None = None,
) -> AdaptedRunRequest:
"""Adapt POST /threads/{thread_id}/runs/wait or POST /runs/wait."""
return AdaptedRunRequest(
intent="create_and_wait",
thread_id=thread_id,
run_id=None,
body=body,
headers=headers or {},
query=query or {},
)
def adapt_join_stream_request(
*,
thread_id: str,
run_id: str,
headers: dict[str, str] | None = None,
query: RequestQuery | None = None,
) -> AdaptedRunRequest:
"""Adapt GET /threads/{thread_id}/runs/{run_id}/stream."""
return AdaptedRunRequest(
intent="join_stream",
thread_id=thread_id,
run_id=run_id,
body={},
headers=headers or {},
query=query or {},
)
def adapt_join_wait_request(
*,
thread_id: str,
run_id: str,
headers: dict[str, str] | None = None,
query: RequestQuery | None = None,
) -> AdaptedRunRequest:
"""Adapt GET /threads/{thread_id}/runs/{run_id}/join."""
return AdaptedRunRequest(
intent="join_wait",
thread_id=thread_id,
run_id=run_id,
body={},
headers=headers or {},
query=query or {},
)
@@ -1,254 +0,0 @@
"""App-owned RunSpec builder."""
from __future__ import annotations
import re
import uuid
from langchain_core.messages import HumanMessage
from deerflow.runtime.runs.types import CheckpointRequest, RunScope, RunSpec
from deerflow.runtime.stream_bridge import JSONValue
from .request_adapter import AdaptedRunRequest
type JSONMapping = dict[str, JSONValue]
type GraphInput = dict[str, object]
type RunnableConfigDict = dict[str, object]
class UnsupportedRunFeatureError(ValueError):
"""Raised when a phase1-unsupported feature is requested."""
pass
class RunSpecBuilder:
"""
Build RunSpec from AdaptedRunRequest.
Phase 1 rules:
1. messages-tuple normalized to messages
2. enqueue not supported
3. rollback not supported
4. after_seconds not supported
5. stream_resumable accepted
6. stateless auto-generates temporary thread
"""
# Phase 1 unsupported features
UNSUPPORTED_MULTITASK_STRATEGIES = {"enqueue"}
UNSUPPORTED_ACTIONS = {"rollback"}
# Default stream modes
DEFAULT_STREAM_MODES = ["values", "messages"]
CONTEXT_CONFIGURABLE_KEYS = frozenset({
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
})
DEFAULT_ASSISTANT_ID = "lead_agent"
@staticmethod
def _as_json_mapping(value: JSONValue | None) -> JSONMapping | None:
return value if isinstance(value, dict) else None
@staticmethod
def _as_string_list(value: JSONValue | None) -> list[str] | None:
if not isinstance(value, list):
return None
return [item for item in value if isinstance(item, str)]
def build(self, request: AdaptedRunRequest) -> RunSpec:
"""Build RunSpec from adapted request."""
body = request.body
# Validate phase1 constraints
self._validate_constraints(body)
# Build scope
scope = self._build_scope(request)
# Normalize stream modes
stream_modes = self._normalize_stream_modes(body.get("stream_mode"))
# Build checkpoint request
checkpoint_request = self._build_checkpoint_request(body)
config = self._build_runnable_config(
thread_id=scope.thread_id,
request_config=self._as_json_mapping(body.get("config")),
metadata=self._as_json_mapping(body.get("metadata")),
assistant_id=body.get("assistant_id"),
context=self._as_json_mapping(body.get("context")),
)
return RunSpec(
intent=request.intent,
scope=scope,
assistant_id=body.get("assistant_id") if isinstance(body.get("assistant_id"), str) else None,
input=self._normalize_input(self._as_json_mapping(body.get("input"))),
command=self._as_json_mapping(body.get("command")),
runnable_config=config,
context=self._as_json_mapping(body.get("context")),
metadata=self._as_json_mapping(body.get("metadata")) or {},
stream_modes=stream_modes,
stream_subgraphs=bool(body.get("stream_subgraphs", False)),
stream_resumable=bool(body.get("stream_resumable", False)),
on_disconnect=body.get("on_disconnect", "cancel") if body.get("on_disconnect") in {"cancel", "continue"} else "cancel",
on_completion=body.get("on_completion", "keep") if body.get("on_completion") in {"delete", "keep"} else "keep",
multitask_strategy=body.get("multitask_strategy", "reject") if body.get("multitask_strategy") in {"reject", "interrupt"} else "reject",
interrupt_before="*" if body.get("interrupt_before") == "*" else self._as_string_list(body.get("interrupt_before")),
interrupt_after="*" if body.get("interrupt_after") == "*" else self._as_string_list(body.get("interrupt_after")),
checkpoint_request=checkpoint_request,
follow_up_to_run_id=body.get("follow_up_to_run_id") if isinstance(body.get("follow_up_to_run_id"), str) else None,
webhook=body.get("webhook") if isinstance(body.get("webhook"), str) else None,
feedback_keys=self._as_string_list(body.get("feedback_keys")),
)
def _validate_constraints(self, body: JSONMapping) -> None:
"""Validate phase1 constraints, raise UnsupportedRunFeatureError if violated."""
# Check multitask_strategy
strategy = body.get("multitask_strategy", "reject")
if strategy in self.UNSUPPORTED_MULTITASK_STRATEGIES:
raise UnsupportedRunFeatureError(
f"multitask_strategy '{strategy}' is not supported in phase1. "
f"Supported: reject, interrupt"
)
# Check for rollback action
command = self._as_json_mapping(body.get("command")) or {}
if command.get("action") in self.UNSUPPORTED_ACTIONS:
raise UnsupportedRunFeatureError(
f"action '{command.get('action')}' is not supported in phase1"
)
# Check for after_seconds
if body.get("after_seconds") is not None:
raise UnsupportedRunFeatureError("after_seconds is not supported in phase1")
def _build_scope(self, request: AdaptedRunRequest) -> RunScope:
"""Build RunScope from request."""
if request.is_stateless:
# Stateless: generate temporary thread
return RunScope(
kind="stateless",
thread_id=str(uuid.uuid4()),
temporary=True,
)
else:
assert request.thread_id is not None
return RunScope(
kind="stateful",
thread_id=request.thread_id,
temporary=False,
)
def _normalize_stream_modes(self, stream_mode: JSONValue | None) -> list[str]:
"""Normalize stream_mode to list, convert messages-tuple to messages."""
if stream_mode is None:
return self.DEFAULT_STREAM_MODES.copy()
if isinstance(stream_mode, str):
modes = [stream_mode]
elif isinstance(stream_mode, list):
modes = [mode for mode in stream_mode if isinstance(mode, str)]
else:
return self.DEFAULT_STREAM_MODES.copy()
return ["messages" if m == "messages-tuple" else m for m in modes]
def _build_checkpoint_request(self, body: JSONMapping) -> CheckpointRequest | None:
"""Build CheckpointRequest if checkpoint data is provided."""
checkpoint_id = body.get("checkpoint_id")
checkpoint = self._as_json_mapping(body.get("checkpoint"))
if not isinstance(checkpoint_id, str) and checkpoint is None:
return None
return CheckpointRequest(
checkpoint_id=checkpoint_id if isinstance(checkpoint_id, str) else None,
checkpoint=checkpoint,
)
def _normalize_input(self, raw_input: JSONMapping | None) -> GraphInput | None:
"""Convert HTTP-friendly message dicts into LangChain message objects."""
if raw_input is None:
return None
messages = raw_input.get("messages")
if not messages or not isinstance(messages, list):
return raw_input
converted: list[object] = []
for msg in messages:
if isinstance(msg, dict):
role = msg.get("role", msg.get("type", "user"))
content = msg.get("content", "")
if role in ("user", "human"):
converted.append(HumanMessage(content=content))
else:
converted.append(HumanMessage(content=content))
else:
converted.append(msg)
return {**raw_input, "messages": converted}
def _build_runnable_config(
self,
*,
thread_id: str,
request_config: JSONMapping | None,
metadata: JSONMapping | None,
assistant_id: str | None,
context: JSONMapping | None,
) -> RunnableConfigDict:
"""Build RunnableConfig from request payload and app-side rules."""
config: RunnableConfigDict = {"recursion_limit": 100}
if request_config:
if "context" in request_config:
config["context"] = request_config["context"]
else:
configurable = {"thread_id": thread_id}
raw_configurable = request_config.get("configurable")
if isinstance(raw_configurable, dict):
configurable.update(raw_configurable)
config["configurable"] = configurable
for key, value in request_config.items():
if key not in ("configurable", "context"):
config[key] = value
else:
config["configurable"] = {"thread_id": thread_id}
configurable = config.get("configurable")
if (
assistant_id
and assistant_id != self.DEFAULT_ASSISTANT_ID
and isinstance(configurable, dict)
and "agent_name" not in configurable
):
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(
f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization."
)
configurable["agent_name"] = normalized
if metadata:
existing_metadata = config.get("metadata")
if isinstance(existing_metadata, dict):
existing_metadata.update(metadata)
else:
config["metadata"] = dict(metadata)
if context and isinstance(configurable, dict):
for key in self.CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
return config
@@ -1,5 +0,0 @@
"""Compatibility wrapper for the app-owned storage observer."""
from app.infra.storage.runs import StorageRunObserver
__all__ = ["StorageRunObserver"]
@@ -1,11 +0,0 @@
"""App-owned runs store adapters."""
from .create_store import AppRunCreateStore
from .delete_store import AppRunDeleteStore
from .query_store import AppRunQueryStore
__all__ = [
"AppRunCreateStore",
"AppRunDeleteStore",
"AppRunQueryStore",
]
@@ -1,38 +0,0 @@
"""App-owned durable run creation adapter."""
from __future__ import annotations
from deerflow.runtime.runs.store import RunCreateStore
from deerflow.runtime.runs.types import RunRecord
from app.infra.storage import ThreadMetaStorage
from app.infra.storage.runs import RunWriteRepository
class AppRunCreateStore(RunCreateStore):
"""Write the initial durable row for a newly created run."""
def __init__(self, repo: RunWriteRepository, thread_meta_storage: ThreadMetaStorage | None = None) -> None:
self._repo = repo
self._thread_meta_storage = thread_meta_storage
async def create_run(self, record: RunRecord) -> None:
await self._repo.create(
run_id=record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=str(record.status),
metadata=record.metadata,
follow_up_to_run_id=record.follow_up_to_run_id,
created_at=record.created_at,
)
if self._thread_meta_storage is not None and record.assistant_id:
thread = await self._thread_meta_storage.ensure_thread(
thread_id=record.thread_id,
assistant_id=record.assistant_id,
)
if thread.assistant_id != record.assistant_id:
await self._thread_meta_storage.sync_thread_assistant_id(
thread_id=record.thread_id,
assistant_id=record.assistant_id,
)
@@ -1,17 +0,0 @@
"""App-owned durable run deletion adapter."""
from __future__ import annotations
from deerflow.runtime.runs.store import RunDeleteStore
from app.infra.storage.runs import RunDeleteRepository
class AppRunDeleteStore(RunDeleteStore):
"""Delete durable run rows via the app storage adapter."""
def __init__(self, repo: RunDeleteRepository) -> None:
self._repo = repo
async def delete_run(self, run_id: str) -> bool:
return await self._repo.delete(run_id)
@@ -1,47 +0,0 @@
"""App-owned durable run query adapter."""
from __future__ import annotations
from deerflow.runtime.runs.store import RunQueryStore
from deerflow.runtime.runs.types import RunRecord, RunStatus
from app.infra.storage.runs import RunReadRepository, RunRow
class AppRunQueryStore(RunQueryStore):
"""Map app-side durable run rows into harness RunRecord DTOs."""
def __init__(self, repo: RunReadRepository) -> None:
self._repo = repo
async def get_run(self, run_id: str) -> RunRecord | None:
row = await self._repo.get(run_id)
if row is None:
return None
return self._to_run_record(row)
async def list_runs(
self,
thread_id: str,
*,
limit: int = 100,
) -> list[RunRecord]:
rows = await self._repo.list_by_thread(thread_id, limit=limit)
return [self._to_run_record(row) for row in rows]
def _to_run_record(self, row: RunRow) -> RunRecord:
return RunRecord(
run_id=row["run_id"],
thread_id=row["thread_id"],
assistant_id=row.get("assistant_id"),
status=RunStatus(row.get("status", "pending")),
temporary=False,
multitask_strategy=row.get("multitask_strategy", "reject"),
metadata=row.get("metadata", {}),
follow_up_to_run_id=row.get("follow_up_to_run_id"),
created_at=row.get("created_at", ""),
updated_at=row.get("updated_at", ""),
started_at=row.get("started_at"),
ended_at=row.get("ended_at"),
error=row.get("error"),
)
+6
View File
@@ -0,0 +1,6 @@
"""Shared utility helpers for the Gateway layer."""
def sanitize_log_param(value: str) -> str:
"""Strip control characters to prevent log injection."""
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
-1
View File
@@ -1 +0,0 @@
"""Application-owned infrastructure adapters and wiring."""
-6
View File
@@ -1,6 +0,0 @@
"""Run event store backends owned by app infrastructure."""
from .factory import build_run_event_store
from .jsonl_store import JsonlRunEventStore
__all__ = ["JsonlRunEventStore", "build_run_event_store"]
-25
View File
@@ -1,25 +0,0 @@
"""Factory for app-owned run event store backends."""
from __future__ import annotations
from pathlib import Path
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.infra.storage import AppRunEventStore
from deerflow.config import get_app_config
from .jsonl_store import JsonlRunEventStore
def build_run_event_store(session_factory: async_sessionmaker[AsyncSession]) -> AppRunEventStore | JsonlRunEventStore:
"""Build the run event store selected by app configuration."""
config = get_app_config().run_events
if config.backend == "db":
return AppRunEventStore(session_factory)
if config.backend == "jsonl":
return JsonlRunEventStore(
base_dir=Path(config.jsonl_base_dir),
)
raise ValueError(f"Unsupported run event backend: {config.backend}")
-210
View File
@@ -1,210 +0,0 @@
"""JSONL run event store backend owned by app infrastructure."""
from __future__ import annotations
import asyncio
import json
import shutil
from collections.abc import Iterable
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
class JsonlRunEventStore:
"""Append-only JSONL implementation of the runs RunEventStore protocol."""
def __init__(
self,
base_dir: Path | str = ".deer-flow/run-events",
) -> None:
self._base_dir = Path(base_dir)
self._locks: dict[str, asyncio.Lock] = {}
self._locks_guard = asyncio.Lock()
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not events:
return []
grouped: dict[str, list[dict[str, Any]]] = {}
for event in events:
grouped.setdefault(str(event["thread_id"]), []).append(event)
records_by_thread: dict[str, list[dict[str, Any]]] = {}
for thread_id, thread_events in grouped.items():
async with await self._thread_lock(thread_id):
records_by_thread[thread_id] = self._append_thread_events(thread_id, thread_events)
indexes = {thread_id: 0 for thread_id in records_by_thread}
ordered: list[dict[str, Any]] = []
for event in events:
thread_id = str(event["thread_id"])
index = indexes[thread_id]
ordered.append(records_by_thread[thread_id][index])
indexes[thread_id] = index + 1
return ordered
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict[str, Any]]:
events = [event for event in await self._read_thread_events(thread_id) if event.get("category") == "message"]
if before_seq is not None:
events = [event for event in events if int(event["seq"]) < before_seq]
return events[-limit:]
if after_seq is not None:
events = [event for event in events if int(event["seq"]) > after_seq]
return events[:limit]
return events[-limit:]
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
) -> list[dict[str, Any]]:
event_type_set = set(event_types or [])
events = [
event
for event in await self._read_thread_events(thread_id)
if event.get("run_id") == run_id and (not event_type_set or event.get("event_type") in event_type_set)
]
return events[:limit]
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict[str, Any]]:
events = [
event
for event in await self._read_thread_events(thread_id)
if event.get("run_id") == run_id and event.get("category") == "message"
]
if before_seq is not None:
events = [event for event in events if int(event["seq"]) < before_seq]
return events[-limit:]
if after_seq is not None:
events = [event for event in events if int(event["seq"]) > after_seq]
return events[:limit]
return events[-limit:]
async def count_messages(self, thread_id: str) -> int:
return len(await self.list_messages(thread_id, limit=10**9))
async def delete_by_thread(self, thread_id: str) -> int:
async with await self._thread_lock(thread_id):
count = len(self._read_thread_events_sync(thread_id))
shutil.rmtree(self._thread_dir(thread_id), ignore_errors=True)
return count
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
async with await self._thread_lock(thread_id):
events = self._read_thread_events_sync(thread_id)
kept = [event for event in events if event.get("run_id") != run_id]
deleted = len(events) - len(kept)
if deleted:
self._write_thread_events(thread_id, kept)
return deleted
async def _thread_lock(self, thread_id: str) -> asyncio.Lock:
async with self._locks_guard:
lock = self._locks.get(thread_id)
if lock is None:
lock = asyncio.Lock()
self._locks[thread_id] = lock
return lock
def _append_thread_events(self, thread_id: str, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
thread_dir = self._thread_dir(thread_id)
thread_dir.mkdir(parents=True, exist_ok=True)
seq = self._read_seq(thread_id)
records: list[dict[str, Any]] = []
with self._events_path(thread_id).open("a", encoding="utf-8") as file:
for event in events:
seq += 1
record = self._normalize_event(event, seq=seq)
file.write(json.dumps(record, ensure_ascii=False, default=str))
file.write("\n")
records.append(record)
self._write_seq(thread_id, seq)
return records
def _normalize_event(self, event: dict[str, Any], *, seq: int) -> dict[str, Any]:
created_at = event.get("created_at")
if isinstance(created_at, datetime):
created_at_value = created_at.isoformat()
elif created_at:
created_at_value = str(created_at)
else:
created_at_value = datetime.now(UTC).isoformat()
return {
"thread_id": str(event["thread_id"]),
"run_id": str(event["run_id"]),
"seq": seq,
"event_type": str(event["event_type"]),
"category": str(event["category"]),
"content": event.get("content", ""),
"metadata": dict(event.get("metadata") or {}),
"created_at": created_at_value,
}
async def _read_thread_events(self, thread_id: str) -> list[dict[str, Any]]:
async with await self._thread_lock(thread_id):
return self._read_thread_events_sync(thread_id)
def _read_thread_events_sync(self, thread_id: str) -> list[dict[str, Any]]:
path = self._events_path(thread_id)
if not path.exists():
return []
events: list[dict[str, Any]] = []
with path.open(encoding="utf-8") as file:
for line in file:
stripped = line.strip()
if stripped:
events.append(json.loads(stripped))
return events
def _write_thread_events(self, thread_id: str, events: Iterable[dict[str, Any]]) -> None:
thread_dir = self._thread_dir(thread_id)
thread_dir.mkdir(parents=True, exist_ok=True)
temp_path = self._events_path(thread_id).with_suffix(".jsonl.tmp")
with temp_path.open("w", encoding="utf-8") as file:
for event in events:
file.write(json.dumps(event, ensure_ascii=False, default=str))
file.write("\n")
temp_path.replace(self._events_path(thread_id))
def _read_seq(self, thread_id: str) -> int:
path = self._seq_path(thread_id)
if not path.exists():
return 0
try:
return int(path.read_text(encoding="utf-8").strip() or "0")
except ValueError:
return 0
def _write_seq(self, thread_id: str, seq: int) -> None:
self._seq_path(thread_id).write_text(str(seq), encoding="utf-8")
def _thread_dir(self, thread_id: str) -> Path:
return self._base_dir / "threads" / thread_id
def _events_path(self, thread_id: str) -> Path:
return self._thread_dir(thread_id) / "events.jsonl"
def _seq_path(self, thread_id: str) -> Path:
return self._thread_dir(thread_id) / "seq"
-14
View File
@@ -1,14 +0,0 @@
"""Storage-facing adapters owned by the app layer."""
from .run_events import AppRunEventStore
from .runs import FeedbackStoreAdapter, RunStoreAdapter, StorageRunObserver
from .thread_meta import ThreadMetaStorage, ThreadMetaStoreAdapter
__all__ = [
"AppRunEventStore",
"FeedbackStoreAdapter",
"RunStoreAdapter",
"StorageRunObserver",
"ThreadMetaStorage",
"ThreadMetaStoreAdapter",
]
-166
View File
@@ -1,166 +0,0 @@
"""App-owned adapter from runs callbacks to storage run event repository."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from store.repositories import RunEvent, RunEventCreate, build_run_event_repository, build_thread_meta_repository
from deerflow.runtime.actor_context import get_actor_context
class AppRunEventStore:
"""Implements the harness RunEventStore protocol using storage repositories."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._session_factory = session_factory
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not events:
return []
denied = {str(event["thread_id"]) for event in events if not await self._thread_visible(str(event["thread_id"]))}
if denied:
raise PermissionError(f"actor is not allowed to append events for thread(s): {', '.join(sorted(denied))}")
async with self._session_factory() as session:
try:
repo = build_run_event_repository(session)
rows = await repo.append_batch([_event_create_from_dict(event) for event in events])
await session.commit()
except Exception:
await session.rollback()
raise
return [_event_to_dict(row) for row in rows]
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict[str, Any]]:
if not await self._thread_visible(thread_id):
return []
async with self._session_factory() as session:
repo = build_run_event_repository(session)
rows = await repo.list_messages(
thread_id,
limit=limit,
before_seq=before_seq,
after_seq=after_seq,
)
return [_event_to_dict(row) for row in rows]
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
) -> list[dict[str, Any]]:
if not await self._thread_visible(thread_id):
return []
async with self._session_factory() as session:
repo = build_run_event_repository(session)
rows = await repo.list_events(thread_id, run_id, event_types=event_types, limit=limit)
return [_event_to_dict(row) for row in rows]
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict[str, Any]]:
if not await self._thread_visible(thread_id):
return []
async with self._session_factory() as session:
repo = build_run_event_repository(session)
rows = await repo.list_messages_by_run(
thread_id,
run_id,
limit=limit,
before_seq=before_seq,
after_seq=after_seq,
)
return [_event_to_dict(row) for row in rows]
async def count_messages(self, thread_id: str) -> int:
if not await self._thread_visible(thread_id):
return 0
async with self._session_factory() as session:
repo = build_run_event_repository(session)
return await repo.count_messages(thread_id)
async def delete_by_thread(self, thread_id: str) -> int:
if not await self._thread_visible(thread_id):
return 0
async with self._session_factory() as session:
try:
repo = build_run_event_repository(session)
count = await repo.delete_by_thread(thread_id)
await session.commit()
except Exception:
await session.rollback()
raise
return count
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
if not await self._thread_visible(thread_id):
return 0
async with self._session_factory() as session:
try:
repo = build_run_event_repository(session)
count = await repo.delete_by_run(thread_id, run_id)
await session.commit()
except Exception:
await session.rollback()
raise
return count
async def _thread_visible(self, thread_id: str) -> bool:
actor = get_actor_context()
if actor is None or actor.user_id is None:
return True
async with self._session_factory() as session:
thread_repo = build_thread_meta_repository(session)
thread = await thread_repo.get_thread_meta(thread_id)
if thread is None:
return True
return thread.user_id is None or thread.user_id == actor.user_id
def _event_create_from_dict(event: dict[str, Any]) -> RunEventCreate:
created_at = event.get("created_at")
return RunEventCreate(
thread_id=str(event["thread_id"]),
run_id=str(event["run_id"]),
event_type=str(event["event_type"]),
category=str(event["category"]),
content=event.get("content", ""),
metadata=dict(event.get("metadata") or {}),
created_at=datetime.fromisoformat(created_at) if isinstance(created_at, str) else created_at,
)
def _event_to_dict(event: RunEvent) -> dict[str, Any]:
return {
"thread_id": event.thread_id,
"run_id": event.run_id,
"event_type": event.event_type,
"category": event.category,
"content": event.content,
"metadata": event.metadata,
"seq": event.seq,
"created_at": event.created_at.isoformat(),
}
-515
View File
@@ -1,515 +0,0 @@
"""Run lifecycle persistence adapters owned by the app layer."""
from __future__ import annotations
import logging
import uuid
from collections.abc import Callable
from typing import Protocol, TypedDict, Unpack, cast
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from store.repositories import FeedbackCreate, Run, RunCreate, build_feedback_repository, build_run_repository
from deerflow.runtime.actor_context import AUTO, resolve_user_id
from deerflow.runtime.serialization import serialize_lc_object
from deerflow.runtime.runs.observer import LifecycleEventType, RunLifecycleEvent, RunObserver
from deerflow.runtime.stream_bridge import JSONValue
from .thread_meta import ThreadMetaStorage
logger = logging.getLogger(__name__)
class RunCreateFields(TypedDict, total=False):
status: str
created_at: str
started_at: str
ended_at: str
assistant_id: str | None
user_id: str | None
follow_up_to_run_id: str | None
metadata: dict[str, JSONValue]
kwargs: dict[str, JSONValue]
class RunStatusUpdateFields(TypedDict, total=False):
started_at: str
ended_at: str
metadata: dict[str, JSONValue]
class RunCompletionFields(TypedDict, total=False):
total_input_tokens: int
total_output_tokens: int
total_tokens: int
llm_call_count: int
lead_agent_tokens: int
subagent_tokens: int
middleware_tokens: int
message_count: int
last_ai_message: str | None
first_human_message: str | None
error: str | None
class RunRow(TypedDict, total=False):
run_id: str
thread_id: str
assistant_id: str | None
status: str
multitask_strategy: str
follow_up_to_run_id: str | None
metadata: dict[str, JSONValue]
created_at: str
updated_at: str
started_at: str | None
ended_at: str | None
error: str | None
class RunReadRepository(Protocol):
"""Protocol for durable run queries."""
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None: ...
async def list_by_thread(
self,
thread_id: str,
*,
limit: int = 100,
user_id: str | None | object = AUTO,
) -> list[RunRow]: ...
class RunWriteRepository(Protocol):
"""Protocol for durable run writes."""
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None: ...
async def update_status(
self,
run_id: str,
status: str,
**kwargs: Unpack[RunStatusUpdateFields],
) -> None: ...
async def set_error(self, run_id: str, error: str) -> None: ...
async def update_run_completion(
self,
run_id: str,
*,
status: str,
**kwargs: Unpack[RunCompletionFields],
) -> None: ...
class RunDeleteRepository(Protocol):
"""Protocol for durable run deletion."""
async def delete(self, run_id: str) -> bool: ...
class _RepositoryContext:
def __init__(
self,
session_factory: async_sessionmaker[AsyncSession],
build_repo: Callable[[AsyncSession], object],
*,
commit: bool,
) -> None:
self._session_factory = session_factory
self._build_repo = build_repo
self._commit = commit
self._session: AsyncSession | None = None
async def __aenter__(self):
self._session = self._session_factory()
return self._build_repo(self._session)
async def __aexit__(self, exc_type, exc, tb) -> None:
if self._session is None:
return
try:
if self._commit:
if exc_type is None:
await self._session.commit()
else:
await self._session.rollback()
finally:
await self._session.close()
def _run_to_row(row: Run) -> RunRow:
return {
"run_id": row.run_id,
"thread_id": row.thread_id,
"assistant_id": row.assistant_id,
"user_id": row.user_id,
"status": row.status,
"model_name": row.model_name,
"multitask_strategy": row.multitask_strategy,
"follow_up_to_run_id": row.follow_up_to_run_id,
"metadata": cast(dict[str, JSONValue], row.metadata),
"kwargs": cast(dict[str, JSONValue], row.kwargs),
"created_at": row.created_time.isoformat(),
"updated_at": row.updated_time.isoformat() if row.updated_time else "",
"total_input_tokens": row.total_input_tokens,
"total_output_tokens": row.total_output_tokens,
"total_tokens": row.total_tokens,
"llm_call_count": row.llm_call_count,
"lead_agent_tokens": row.lead_agent_tokens,
"subagent_tokens": row.subagent_tokens,
"middleware_tokens": row.middleware_tokens,
"message_count": row.message_count,
"first_human_message": row.first_human_message,
"last_ai_message": row.last_ai_message,
"error": row.error,
}
class FeedbackStoreAdapter:
"""Expose feedback route semantics on top of storage package repositories."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._session_factory = session_factory
async def create(
self,
*,
run_id: str,
thread_id: str,
rating: int,
owner_id: str | None = None,
user_id: str | None = None,
message_id: str | None = None,
comment: str | None = None,
) -> dict[str, object]:
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
effective_user_id = user_id if user_id is not None else owner_id
async with self._transaction() as repo:
row = await repo.create_feedback(
FeedbackCreate(
feedback_id=str(uuid.uuid4()),
run_id=run_id,
thread_id=thread_id,
rating=rating,
user_id=effective_user_id,
message_id=message_id,
comment=comment,
)
)
return _feedback_to_dict(row)
async def get(self, feedback_id: str) -> dict[str, object] | None:
async with self._read() as repo:
row = await repo.get_feedback(feedback_id)
return _feedback_to_dict(row) if row is not None else None
async def list_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 100,
user_id: str | None = None,
) -> list[dict[str, object]]:
async with self._read() as repo:
rows = await repo.list_feedback_by_run(run_id)
filtered = [row for row in rows if row.thread_id == thread_id]
if user_id is not None:
filtered = [row for row in filtered if row.user_id == user_id]
return [_feedback_to_dict(row) for row in filtered][:limit]
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict[str, object]]:
async with self._read() as repo:
rows = await repo.list_feedback_by_thread(thread_id)
return [_feedback_to_dict(row) for row in rows][:limit]
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict[str, object]:
rows = await self.list_by_run(thread_id, run_id)
positive = sum(1 for row in rows if row["rating"] == 1)
negative = sum(1 for row in rows if row["rating"] == -1)
return {"run_id": run_id, "total": len(rows), "positive": positive, "negative": negative}
async def delete(self, feedback_id: str) -> bool:
async with self._transaction() as repo:
return await repo.delete_feedback(feedback_id)
async def upsert(
self,
*,
run_id: str,
thread_id: str,
rating: int,
user_id: str,
comment: str | None = None,
) -> dict[str, object]:
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
async with self._transaction() as repo:
rows = await repo.list_feedback_by_run(run_id)
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
feedback_id = existing.feedback_id if existing is not None else str(uuid.uuid4())
if existing is not None:
await repo.delete_feedback(existing.feedback_id)
row = await repo.create_feedback(
FeedbackCreate(
feedback_id=feedback_id,
run_id=run_id,
thread_id=thread_id,
rating=rating,
user_id=user_id,
comment=comment,
)
)
return _feedback_to_dict(row)
async def delete_by_run(self, *, thread_id: str, run_id: str, user_id: str) -> bool:
async with self._transaction() as repo:
rows = await repo.list_feedback_by_run(run_id)
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
if existing is None:
return False
return await repo.delete_feedback(existing.feedback_id)
async def list_by_thread_grouped(self, thread_id: str, *, user_id: str) -> dict[str, dict[str, object]]:
rows = await self.list_by_thread(thread_id)
return {
row["run_id"]: row
for row in rows
if row["user_id"] == user_id
}
def _read(self) -> _RepositoryContext:
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=False)
def _transaction(self) -> _RepositoryContext:
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=True)
def _feedback_to_dict(row) -> dict[str, object]:
return {
"feedback_id": row.feedback_id,
"run_id": row.run_id,
"thread_id": row.thread_id,
"user_id": row.user_id,
"owner_id": row.user_id,
"message_id": row.message_id,
"rating": row.rating,
"comment": row.comment,
"created_at": row.created_time.isoformat(),
}
class RunStoreAdapter:
"""Expose runs facade storage semantics on top of storage package repositories."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._session_factory = session_factory
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None:
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.get")
async with self._read() as repo:
row = await repo.get_run(run_id)
if row is None:
return None
if effective_user_id is not None and row.user_id != effective_user_id:
return None
return _run_to_row(row)
async def list_by_thread(
self,
thread_id: str,
*,
limit: int = 100,
user_id: str | None | object = AUTO,
) -> list[RunRow]:
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.list_by_thread")
async with self._read() as repo:
rows = await repo.list_runs_by_thread(thread_id, limit=limit, offset=0)
if effective_user_id is not None:
rows = [row for row in rows if row.user_id == effective_user_id]
return [_run_to_row(row) for row in rows]
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None:
metadata = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("metadata") or {}))
run_kwargs = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("kwargs") or {}))
effective_user_id = resolve_user_id(kwargs.get("user_id", AUTO), method_name="RunStoreAdapter.create")
async with self._transaction() as repo:
await repo.create_run(
RunCreate(
run_id=run_id,
thread_id=thread_id,
assistant_id=kwargs.get("assistant_id"),
user_id=effective_user_id,
status=kwargs.get("status", "pending"),
metadata=dict(metadata),
kwargs=dict(run_kwargs),
follow_up_to_run_id=kwargs.get("follow_up_to_run_id"),
)
)
async def delete(self, run_id: str, *, user_id: str | None | object = AUTO) -> bool:
async with self._transaction() as repo:
existing = await repo.get_run(run_id)
if existing is None:
return False
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.delete")
if effective_user_id is not None and existing.user_id != effective_user_id:
return False
await repo.delete_run(run_id)
return True
async def update_status(
self,
run_id: str,
status: str,
**kwargs: Unpack[RunStatusUpdateFields],
) -> None:
async with self._transaction() as repo:
await repo.update_run_status(run_id, status)
async def set_error(self, run_id: str, error: str) -> None:
async with self._transaction() as repo:
await repo.update_run_status(run_id, "error", error=error)
async def update_run_completion(
self,
run_id: str,
*,
status: str,
**kwargs: Unpack[RunCompletionFields],
) -> None:
async with self._transaction() as repo:
await repo.update_run_completion(
run_id,
status=status,
total_input_tokens=kwargs.get("total_input_tokens", 0),
total_output_tokens=kwargs.get("total_output_tokens", 0),
total_tokens=kwargs.get("total_tokens", 0),
llm_call_count=kwargs.get("llm_call_count", 0),
lead_agent_tokens=kwargs.get("lead_agent_tokens", 0),
subagent_tokens=kwargs.get("subagent_tokens", 0),
middleware_tokens=kwargs.get("middleware_tokens", 0),
message_count=kwargs.get("message_count", 0),
last_ai_message=kwargs.get("last_ai_message"),
first_human_message=kwargs.get("first_human_message"),
error=kwargs.get("error"),
)
def _read(self) -> _RepositoryContext:
return _RepositoryContext(self._session_factory, build_run_repository, commit=False)
def _transaction(self) -> _RepositoryContext:
return _RepositoryContext(self._session_factory, build_run_repository, commit=True)
class StorageRunObserver(RunObserver):
"""Persist run lifecycle state into app-owned repositories."""
def __init__(
self,
run_write_repo: RunWriteRepository | None = None,
thread_meta_storage: ThreadMetaStorage | None = None,
) -> None:
self._run_write_repo = run_write_repo
self._thread_meta_storage = thread_meta_storage
async def on_event(self, event: RunLifecycleEvent) -> None:
try:
await self._dispatch(event)
except Exception:
logger.exception(
"StorageRunObserver failed to persist event %s for run %s",
event.event_type,
event.run_id,
)
async def _dispatch(self, event: RunLifecycleEvent) -> None:
handlers = {
LifecycleEventType.RUN_STARTED: self._handle_run_started,
LifecycleEventType.RUN_COMPLETED: self._handle_run_completed,
LifecycleEventType.RUN_FAILED: self._handle_run_failed,
LifecycleEventType.RUN_CANCELLED: self._handle_run_cancelled,
LifecycleEventType.THREAD_STATUS_UPDATED: self._handle_thread_status,
}
handler = handlers.get(event.event_type)
if handler:
await handler(event)
async def _handle_run_started(self, event: RunLifecycleEvent) -> None:
if self._run_write_repo:
await self._run_write_repo.update_status(
run_id=event.run_id,
status="running",
started_at=event.occurred_at.isoformat(),
)
async def _handle_run_completed(self, event: RunLifecycleEvent) -> None:
payload = dict(event.payload) if event.payload else {}
if self._run_write_repo:
completion_data = payload.get("completion_data")
if isinstance(completion_data, dict):
await self._run_write_repo.update_run_completion(
run_id=event.run_id,
status="success",
**cast(RunCompletionFields, completion_data),
)
else:
await self._run_write_repo.update_status(
run_id=event.run_id,
status="success",
ended_at=event.occurred_at.isoformat(),
)
if self._thread_meta_storage and "title" in payload:
await self._thread_meta_storage.sync_thread_title(
thread_id=event.thread_id,
title=payload["title"],
)
async def _handle_run_failed(self, event: RunLifecycleEvent) -> None:
if self._run_write_repo:
payload = dict(event.payload) if event.payload else {}
error = payload.get("error", "Unknown error")
completion_data = payload.get("completion_data")
if isinstance(completion_data, dict):
await self._run_write_repo.update_run_completion(
run_id=event.run_id,
status="error",
error=str(error),
**cast(RunCompletionFields, completion_data),
)
else:
await self._run_write_repo.update_status(
run_id=event.run_id,
status="error",
ended_at=event.occurred_at.isoformat(),
)
await self._run_write_repo.set_error(run_id=event.run_id, error=str(error))
async def _handle_run_cancelled(self, event: RunLifecycleEvent) -> None:
if self._run_write_repo:
payload = dict(event.payload) if event.payload else {}
completion_data = payload.get("completion_data")
if isinstance(completion_data, dict):
await self._run_write_repo.update_run_completion(
run_id=event.run_id,
status="interrupted",
**cast(RunCompletionFields, completion_data),
)
else:
await self._run_write_repo.update_status(
run_id=event.run_id,
status="interrupted",
ended_at=event.occurred_at.isoformat(),
)
async def _handle_thread_status(self, event: RunLifecycleEvent) -> None:
if self._thread_meta_storage:
payload = dict(event.payload) if event.payload else {}
status = payload.get("status", "idle")
await self._thread_meta_storage.sync_thread_status(
thread_id=event.thread_id,
status=status,
)
-208
View File
@@ -1,208 +0,0 @@
"""Thread metadata storage adapter owned by the app layer."""
from __future__ import annotations
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from store.repositories import build_thread_meta_repository
from store.repositories.contracts import (
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from deerflow.runtime.actor_context import AUTO, resolve_user_id
class ThreadMetaStoreAdapter:
"""Use storage package thread repositories with per-call sessions."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._session_factory = session_factory
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
async with self._transaction() as repo:
return await repo.create_thread_meta(data)
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
async with self._read() as repo:
return await repo.get_thread_meta(thread_id)
async def update_thread_meta(
self,
thread_id: str,
*,
assistant_id: str | None = None,
display_name: str | None = None,
status: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
async with self._transaction() as repo:
await repo.update_thread_meta(
thread_id,
assistant_id=assistant_id,
display_name=display_name,
status=status,
metadata=metadata,
)
async def delete_thread(self, thread_id: str) -> None:
async with self._transaction() as repo:
await repo.delete_thread(thread_id)
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None = None,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
async with self._read() as repo:
return await repo.search_threads(
metadata=metadata,
status=status,
user_id=user_id,
assistant_id=assistant_id,
limit=limit,
offset=offset,
)
def _read(self):
return _ThreadMetaRepositoryContext(self._session_factory, commit=False)
def _transaction(self):
return _ThreadMetaRepositoryContext(self._session_factory, commit=True)
class _ThreadMetaRepositoryContext:
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, commit: bool) -> None:
self._session_factory = session_factory
self._commit = commit
self._session: AsyncSession | None = None
async def __aenter__(self):
self._session = self._session_factory()
return build_thread_meta_repository(self._session)
async def __aexit__(self, exc_type, exc, tb) -> None:
if self._session is None:
return
try:
if self._commit:
if exc_type is None:
await self._session.commit()
else:
await self._session.rollback()
finally:
await self._session.close()
class ThreadMetaStorage:
"""App-facing adapter around the storage thread metadata contract."""
def __init__(self, repo: ThreadMetaRepositoryProtocol) -> None:
self._repo = repo
async def get_thread(self, thread_id: str, *, user_id: str | None | object = AUTO) -> ThreadMeta | None:
thread = await self._repo.get_thread_meta(thread_id)
if thread is None:
return None
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.get_thread")
if effective_user_id is not None and thread.user_id != effective_user_id:
return None
return thread
async def ensure_thread(
self,
*,
thread_id: str,
assistant_id: str | None = None,
metadata: dict[str, Any] | None = None,
user_id: str | None | object = AUTO,
) -> ThreadMeta:
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.ensure_thread")
existing = await self.get_thread(thread_id, user_id=effective_user_id)
if existing is not None:
return existing
return await self._repo.create_thread_meta(
ThreadMetaCreate(
thread_id=thread_id,
assistant_id=assistant_id,
user_id=effective_user_id,
metadata=metadata or {},
)
)
async def ensure_thread_running(
self,
*,
thread_id: str,
assistant_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> ThreadMeta | None:
existing = await self._repo.get_thread_meta(thread_id)
if existing is None:
return await self._repo.create_thread_meta(
ThreadMetaCreate(
thread_id=thread_id,
assistant_id=assistant_id,
status="running",
metadata=metadata or {},
)
)
await self._repo.update_thread_meta(thread_id, status="running")
return await self._repo.get_thread_meta(thread_id)
async def sync_thread_title(self, *, thread_id: str, title: str) -> None:
await self._repo.update_thread_meta(thread_id, display_name=title)
async def sync_thread_assistant_id(self, *, thread_id: str, assistant_id: str) -> None:
await self._repo.update_thread_meta(thread_id, assistant_id=assistant_id)
async def sync_thread_status(self, *, thread_id: str, status: str) -> None:
await self._repo.update_thread_meta(thread_id, status=status)
async def sync_thread_metadata(
self,
*,
thread_id: str,
metadata: dict[str, Any],
) -> None:
await self._repo.update_thread_meta(thread_id, metadata=metadata)
async def delete_thread(self, thread_id: str) -> None:
await self._repo.delete_thread(thread_id)
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None | object = AUTO,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
normalized_status = status.strip() if status is not None else None
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.search_threads")
normalized_user_id = resolved_user_id.strip() if resolved_user_id is not None else None
normalized_assistant_id = (
assistant_id.strip() if assistant_id is not None else None
)
return await self._repo.search_threads(
metadata=metadata,
status=normalized_status or None,
user_id=normalized_user_id or None,
assistant_id=normalized_assistant_id or None,
limit=limit,
offset=offset,
)
__all__ = ["ThreadMetaStorage", "ThreadMetaStoreAdapter"]
@@ -1,6 +0,0 @@
"""App-owned stream bridge adapters and factory."""
from .factory import build_stream_bridge
from .adapters import MemoryStreamBridge, RedisStreamBridge
__all__ = ["MemoryStreamBridge", "RedisStreamBridge", "build_stream_bridge"]
@@ -1,6 +0,0 @@
"""Concrete stream bridge adapters owned by the app layer."""
from .memory import MemoryStreamBridge
from .redis import RedisStreamBridge
__all__ = ["MemoryStreamBridge", "RedisStreamBridge"]
@@ -1,450 +0,0 @@
"""In-memory stream bridge implementation owned by the app layer."""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any, Literal
from deerflow.runtime.stream_bridge import (
CANCELLED_SENTINEL,
END_SENTINEL,
HEARTBEAT_SENTINEL,
TERMINAL_STATES,
ResumeResult,
StreamBridge,
StreamEvent,
StreamStatus,
)
from deerflow.runtime.stream_bridge.exceptions import (
BridgeClosedError,
StreamCapacityExceededError,
StreamTerminatedError,
)
logger = logging.getLogger(__name__)
@dataclass
class _RunStream:
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
events: list[StreamEvent] = field(default_factory=list)
id_to_offset: dict[str, int] = field(default_factory=dict)
start_offset: int = 0
current_bytes: int = 0
seq: int = 0
status: StreamStatus = StreamStatus.ACTIVE
created_at: float = field(default_factory=time.monotonic)
last_publish_at: float | None = None
ended_at: float | None = None
subscriber_count: int = 0
last_subscribe_at: float | None = None
awaiting_input: bool = False
awaiting_since: float | None = None
class MemoryStreamBridge(StreamBridge):
"""Per-run in-memory event log implementation."""
def __init__(
self,
*,
max_events_per_stream: int = 256,
max_bytes_per_stream: int = 10 * 1024 * 1024,
max_active_streams: int = 1000,
stream_eviction_policy: Literal["reject", "lru"] = "lru",
terminal_retention_ttl: float = 300.0,
active_no_publish_timeout: float = 600.0,
orphan_timeout: float = 60.0,
max_stream_age: float = 86400.0,
hitl_extended_timeout: float = 7200.0,
cleanup_interval: float = 30.0,
queue_maxsize: int | None = None,
) -> None:
if queue_maxsize is not None:
max_events_per_stream = queue_maxsize
self._max_events = max_events_per_stream
self._max_bytes = max_bytes_per_stream
self._max_streams = max_active_streams
self._eviction_policy = stream_eviction_policy
self._terminal_ttl = terminal_retention_ttl
self._active_timeout = active_no_publish_timeout
self._orphan_timeout = orphan_timeout
self._max_age = max_stream_age
self._hitl_timeout = hitl_extended_timeout
self._cleanup_interval = cleanup_interval
self._streams: dict[str, _RunStream] = {}
self._registry_lock = asyncio.Lock()
self._closed = False
self._cleanup_task: asyncio.Task[None] | None = None
async def start(self) -> None:
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info(
"MemoryStreamBridge started (max_events=%d, max_bytes=%d, max_streams=%d)",
self._max_events,
self._max_bytes,
self._max_streams,
)
async def close(self) -> None:
async with self._registry_lock:
self._closed = True
if self._cleanup_task is not None:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
for stream in self._streams.values():
async with stream.condition:
stream.status = StreamStatus.CLOSED
stream.condition.notify_all()
self._streams.clear()
logger.info("MemoryStreamBridge closed")
async def _get_or_create_stream(self, run_id: str) -> _RunStream:
stream = self._streams.get(run_id)
if stream is not None:
return stream
async with self._registry_lock:
if self._closed:
raise BridgeClosedError("Stream bridge is closed")
stream = self._streams.get(run_id)
if stream is not None:
return stream
if len(self._streams) >= self._max_streams:
if self._eviction_policy == "reject":
raise StreamCapacityExceededError(
f"Max {self._max_streams} active streams reached"
)
evicted = self._evict_oldest_terminal()
if evicted is None:
raise StreamCapacityExceededError("All streams active, cannot evict")
logger.info("Evicted stream %s to make room", evicted)
stream = _RunStream()
self._streams[run_id] = stream
logger.debug("Created stream for run %s", run_id)
return stream
def _evict_oldest_terminal(self) -> str | None:
oldest_run_id: str | None = None
oldest_ended_at: float = float("inf")
for run_id, stream in self._streams.items():
if stream.status in TERMINAL_STATES and stream.ended_at is not None:
if stream.ended_at < oldest_ended_at:
oldest_ended_at = stream.ended_at
oldest_run_id = run_id
if oldest_run_id is not None:
del self._streams[oldest_run_id]
return oldest_run_id
return None
def _next_id(self, stream: _RunStream) -> str:
stream.seq += 1
return f"{int(time.time() * 1000)}-{stream.seq}"
def _estimate_size(self, event: StreamEvent) -> int:
base = len(event.id) + len(event.event) + 100
if event.data is None:
return base
if isinstance(event.data, str):
return base + len(event.data)
if isinstance(event.data, (dict, list)):
try:
return base + len(json.dumps(event.data, default=str))
except (TypeError, ValueError):
return base + 200
return base + 50
def _evict_overflow(self, stream: _RunStream) -> None:
while len(stream.events) > self._max_events or stream.current_bytes > self._max_bytes:
if not stream.events:
break
evicted = stream.events.pop(0)
stream.id_to_offset.pop(evicted.id, None)
stream.current_bytes -= self._estimate_size(evicted)
stream.start_offset += 1
async def publish(self, run_id: str, event: str, data: Any) -> str:
stream = await self._get_or_create_stream(run_id)
async with stream.condition:
if stream.status != StreamStatus.ACTIVE:
raise StreamTerminatedError(
f"Cannot publish to {stream.status.value} stream"
)
entry = StreamEvent(id=self._next_id(stream), event=event, data=data)
absolute_offset = stream.start_offset + len(stream.events)
stream.events.append(entry)
stream.id_to_offset[entry.id] = absolute_offset
stream.current_bytes += self._estimate_size(entry)
stream.last_publish_at = time.monotonic()
self._evict_overflow(stream)
stream.condition.notify_all()
return entry.id
async def publish_end(self, run_id: str) -> str:
return await self.publish_terminal(run_id, StreamStatus.ENDED)
async def publish_terminal(
self,
run_id: str,
kind: StreamStatus,
data: Any = None,
) -> str:
if kind not in TERMINAL_STATES:
raise ValueError(f"Invalid terminal kind: {kind}")
stream = await self._get_or_create_stream(run_id)
async with stream.condition:
if stream.status != StreamStatus.ACTIVE:
for evt in reversed(stream.events):
if evt.event in ("end", "cancel", "error", "dead_letter"):
return evt.id
return ""
event_name = {
StreamStatus.ENDED: "end",
StreamStatus.CANCELLED: "cancel",
StreamStatus.ERRORED: "error",
}[kind]
entry = StreamEvent(id=self._next_id(stream), event=event_name, data=data)
absolute_offset = stream.start_offset + len(stream.events)
stream.events.append(entry)
stream.id_to_offset[entry.id] = absolute_offset
stream.current_bytes += self._estimate_size(entry)
stream.status = kind
stream.ended_at = time.monotonic()
stream.awaiting_input = False
stream.condition.notify_all()
logger.debug("Stream %s terminal: %s", run_id, kind.value)
return entry.id
async def cancel(self, run_id: str) -> None:
await self.publish_terminal(run_id, StreamStatus.CANCELLED)
async def subscribe(
self,
run_id: str,
*,
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[StreamEvent]:
stream = await self._get_or_create_stream(run_id)
resume = self._resolve_resume_point(stream, last_event_id)
next_offset = resume.next_offset
async with stream.condition:
stream.subscriber_count += 1
stream.last_subscribe_at = time.monotonic()
try:
while True:
entry_to_yield: StreamEvent | None = None
sentinel_to_yield: StreamEvent | None = None
should_return = False
should_wait = False
async with stream.condition:
if self._closed or stream.status == StreamStatus.CLOSED:
sentinel_to_yield = CANCELLED_SENTINEL
should_return = True
elif next_offset < stream.start_offset:
next_offset = stream.start_offset
else:
local_index = next_offset - stream.start_offset
if 0 <= local_index < len(stream.events):
entry_to_yield = stream.events[local_index]
next_offset += 1
if entry_to_yield.event in ("end", "cancel", "error", "dead_letter"):
should_return = True
elif stream.status in TERMINAL_STATES:
sentinel_to_yield = END_SENTINEL
should_return = True
else:
should_wait = True
try:
await asyncio.wait_for(
stream.condition.wait(),
timeout=heartbeat_interval,
)
except TimeoutError:
pass
if sentinel_to_yield is not None:
yield sentinel_to_yield
if should_return:
return
continue
if entry_to_yield is not None:
yield entry_to_yield
if should_return:
return
continue
if should_wait:
async with stream.condition:
local_index = next_offset - stream.start_offset
has_events = 0 <= local_index < len(stream.events)
is_terminal = stream.status in TERMINAL_STATES
if not has_events and not is_terminal:
yield HEARTBEAT_SENTINEL
finally:
async with stream.condition:
stream.subscriber_count = max(0, stream.subscriber_count - 1)
async def mark_awaiting_input(self, run_id: str) -> None:
stream = self._streams.get(run_id)
if stream is None:
return
async with stream.condition:
if stream.status == StreamStatus.ACTIVE:
stream.awaiting_input = True
stream.awaiting_since = time.monotonic()
logger.debug("Stream %s marked as awaiting input", run_id)
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
if delay > 0:
await asyncio.sleep(delay)
await self._do_cleanup(run_id, "manual")
async def _do_cleanup(self, run_id: str, reason: str) -> None:
async with self._registry_lock:
stream = self._streams.pop(run_id, None)
if stream is not None:
async with stream.condition:
stream.status = StreamStatus.CLOSED
stream.condition.notify_all()
logger.debug("Cleaned up stream %s (reason: %s)", run_id, reason)
async def _mark_dead_letter(self, run_id: str, reason: str) -> None:
stream = self._streams.get(run_id)
if stream is None:
return
async with stream.condition:
if stream.status != StreamStatus.ACTIVE:
return
entry = StreamEvent(
id=self._next_id(stream),
event="dead_letter",
data={"reason": reason, "timestamp": time.time()},
)
absolute_offset = stream.start_offset + len(stream.events)
stream.events.append(entry)
stream.id_to_offset[entry.id] = absolute_offset
stream.current_bytes += self._estimate_size(entry)
stream.status = StreamStatus.ERRORED
stream.ended_at = time.monotonic()
stream.condition.notify_all()
logger.warning("Stream %s marked as dead letter: %s", run_id, reason)
async def _cleanup_loop(self) -> None:
while not self._closed:
try:
await asyncio.sleep(self._cleanup_interval)
except asyncio.CancelledError:
break
now = time.monotonic()
to_cleanup: list[tuple[str, str]] = []
to_mark_dead: list[tuple[str, str]] = []
async with self._registry_lock:
for run_id, stream in list(self._streams.items()):
if now - stream.created_at > self._max_age:
to_cleanup.append((run_id, "max_age_exceeded"))
continue
if stream.status == StreamStatus.ACTIVE:
timeout = self._hitl_timeout if stream.awaiting_input else self._active_timeout
last_activity = stream.last_publish_at or stream.created_at
if now - last_activity > timeout:
to_mark_dead.append((run_id, "no_publish_timeout"))
continue
if stream.status in TERMINAL_STATES and stream.ended_at:
if stream.subscriber_count > 0:
continue
last_sub = stream.last_subscribe_at or stream.ended_at
if now - last_sub > self._orphan_timeout:
to_cleanup.append((run_id, "orphan"))
continue
if now - stream.ended_at > self._terminal_ttl:
to_cleanup.append((run_id, "ttl_expired"))
for run_id, reason in to_mark_dead:
await self._mark_dead_letter(run_id, reason)
for run_id, reason in to_cleanup:
await self._do_cleanup(run_id, reason)
def get_stats(self) -> dict[str, Any]:
active = sum(1 for s in self._streams.values() if s.status == StreamStatus.ACTIVE)
terminal = sum(1 for s in self._streams.values() if s.status in TERMINAL_STATES)
total_events = sum(len(s.events) for s in self._streams.values())
total_bytes = sum(s.current_bytes for s in self._streams.values())
total_subs = sum(s.subscriber_count for s in self._streams.values())
return {
"total_streams": len(self._streams),
"active_streams": active,
"terminal_streams": terminal,
"total_events": total_events,
"total_bytes": total_bytes,
"total_subscribers": total_subs,
"closed": self._closed,
}
def _resolve_resume_point(
self,
stream: _RunStream,
last_event_id: str | None,
) -> ResumeResult:
if last_event_id is None:
return ResumeResult(next_offset=stream.start_offset, status="fresh")
if last_event_id in stream.id_to_offset:
return ResumeResult(
next_offset=stream.id_to_offset[last_event_id] + 1,
status="resumed",
)
parts = last_event_id.split("-")
if len(parts) != 2:
return ResumeResult(next_offset=stream.start_offset, status="invalid")
try:
event_ts = int(parts[0])
_event_seq = int(parts[1])
except ValueError:
return ResumeResult(next_offset=stream.start_offset, status="invalid")
if stream.events:
try:
oldest_parts = stream.events[0].id.split("-")
oldest_ts = int(oldest_parts[0])
if event_ts < oldest_ts:
return ResumeResult(
next_offset=stream.start_offset,
status="evicted",
gap_count=stream.start_offset,
)
except (ValueError, IndexError):
pass
return ResumeResult(next_offset=stream.start_offset, status="unknown")
__all__ = ["MemoryStreamBridge"]
@@ -1,37 +0,0 @@
"""Redis-backed stream bridge placeholder owned by the app layer."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
from deerflow.runtime.stream_bridge import StreamBridge, StreamEvent
class RedisStreamBridge(StreamBridge):
"""Reserved app-owned Redis implementation.
Phase 1 intentionally keeps Redis out of the harness package. The concrete
implementation will live here once cross-process streaming is introduced.
"""
def __init__(self, *, redis_url: str) -> None:
self._redis_url = redis_url
async def publish(self, run_id: str, event: str, data: Any) -> str:
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
async def publish_end(self, run_id: str) -> str:
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
def subscribe(
self,
run_id: str,
*,
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[StreamEvent]:
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
@@ -1,50 +0,0 @@
"""App-owned stream bridge factory."""
from __future__ import annotations
import logging
from collections.abc import AsyncIterator
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from deerflow.config.stream_bridge_config import get_stream_bridge_config
from deerflow.runtime.stream_bridge import StreamBridge
from .adapters import MemoryStreamBridge, RedisStreamBridge
logger = logging.getLogger(__name__)
def build_stream_bridge(config=None) -> AbstractAsyncContextManager[StreamBridge]:
"""Build the configured app-owned stream bridge."""
return _build_stream_bridge_impl(config)
@asynccontextmanager
async def _build_stream_bridge_impl(config=None) -> AsyncIterator[StreamBridge]:
if config is None:
config = get_stream_bridge_config()
if config is None or config.type == "memory":
maxsize = config.queue_maxsize if config is not None else 256
bridge = MemoryStreamBridge(queue_maxsize=maxsize)
await bridge.start()
logger.info("Stream bridge initialised: memory (queue_maxsize=%d)", maxsize)
try:
yield bridge
finally:
await bridge.close()
return
if config.type == "redis":
if not config.redis_url:
raise ValueError("Redis stream bridge requires redis_url")
bridge = RedisStreamBridge(redis_url=config.redis_url)
await bridge.start()
logger.info("Stream bridge initialised: redis (%s)", config.redis_url)
try:
yield bridge
finally:
await bridge.close()
return
raise ValueError(f"Unknown stream bridge type: {config.type!r}")
-15
View File
@@ -1,15 +0,0 @@
"""Entry point for running the Gateway API via `python app/main.py`.
Useful for IDE debugging (e.g., PyCharm / VS Code debug configurations).
Equivalent to: PYTHONPATH=. uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
"""
import uvicorn
if __name__ == "__main__":
uvicorn.run(
"app.gateway.app:app",
host="0.0.0.0",
port=8001,
reload=True,
)
-314
View File
@@ -1,314 +0,0 @@
# app.plugins Design Overview
This document describes the current role of `backend/app/plugins`, its plugin design contract, dependency boundaries, and how the current `auth` plugin provides services with minimal intrusion into the host application.
## 1. Overall Role
`app.plugins` is the application-side plugin boundary.
Its purpose is not to implement a generic plugin marketplace. Instead, it provides a clear boundary inside `app` for separable business capabilities, so that a capability can:
1. carry its own domain model, runtime state, and adapters inside the plugin
2. interact with the host application only through a limited set of seams
3. remain replaceable, removable, and extensible over time
The only real plugin currently implemented under this directory is [`auth`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth).
The current direction is not “put all logic into app”. It is:
1. the host application owns unified bootstrap, shared infrastructure, and top-level router assembly
2. each plugin owns its own business contract, persistence definitions, runtime state, and outward-facing adapters
## 2. Plugin Design Contract
### 2.1 A plugin should carry its own implementation
The primary contract visible in the current codebase is:
A plugins own ORM, runtime, domain, and adapters should be implemented inside the plugin itself. Core business behavior should not be scattered into unrelated external modules.
The `auth` plugin already follows that pattern with a fairly complete internal structure:
1. `domain`
- config, errors, JWT, password logic, domain models, service
2. `storage`
- plugin-owned ORM models, repository contracts, and repository implementations
3. `runtime`
- plugin-owned runtime config state
4. `api`
- plugin-owned HTTP router and schemas
5. `security`
- plugin-owned middleware, dependencies, CSRF logic, and LangGraph adapter
6. `authorization`
- plugin-owned permission model, policy resolution, and hooks
7. `injection`
- plugin-owned route-policy loading, injection, and validation
In other words, a plugin should be a self-contained capability module, not a bag of helpers.
### 2.2 The host app should provide shared infrastructure, not plugin internals
The current contract is not that every plugin must be fully infrastructure-independent.
It is:
1. a plugin may reuse the applications shared `engine`, `session_factory`, FastAPI app, and router tree
2. but the plugin must still own its table definitions, repositories, runtime config, and business/auth behavior
This is stated explicitly in [`auth/plugin.toml`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/plugin.toml):
1. `storage.mode = "shared_infrastructure"`
2. the plugin owns its storage definitions and repositories
3. but it reuses the applications shared persistence infrastructure
So the real rule is not “never reuse infrastructure”. The real rule is “do not outsource plugin business semantics to the rest of the app”.
### 2.3 Dependencies should remain one-way
The intended dependency direction in the current design is:
```text
gateway / app bootstrap
-> plugin public adapters
-> plugin domain / storage / runtime
```
Not:
```text
plugin domain
-> depends on app business modules
```
A plugin may depend on:
1. shared persistence infrastructure
2. `app.state` provided by the host application
3. generic framework capabilities such as FastAPI / Starlette
But its core business rules should not depend on unrelated app business modules, otherwise hot-swappability becomes unrealistic.
## 3. The Current auth Plugin Structure
The current `auth` plugin is effectively a self-contained authentication and authorization package with its own models, services, and adapters.
### 3.1 domain
[`auth/domain`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/domain) owns:
1. `config.py`
- auth-related configuration definition and loading
2. `errors.py`
- error codes and response contracts
3. `jwt.py`
- token encoding and decoding
4. `password.py`
- password hashing and verification
5. `models.py`
- auth domain models
6. `service.py`
- `AuthService` as the core business service
`AuthService` depends only on the plugins own `DbUserRepository` plus the shared session factory. The auth business logic is not reimplemented in `gateway`.
### 3.2 storage
[`auth/storage`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage) clearly shows the “ORM is owned by the plugin” contract:
1. [`models.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/models.py)
- defines the plugin-owned `users` table model
2. [`contracts.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/contracts.py)
- defines `User`, `UserCreate`, and `UserRepositoryProtocol`
3. [`repositories.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/repositories.py)
- implements `DbUserRepository`
The key point is:
1. the plugin defines its own ORM model
2. the plugin defines its own repository protocol
3. the plugin implements its own repository
4. external code only needs to provide a session or session factory
That is the minimal shared seam the boundary should preserve.
### 3.3 runtime
[`auth/runtime/config_state.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/runtime/config_state.py) keeps plugin-owned runtime config state:
1. `get_auth_config()`
2. `set_auth_config()`
3. `reset_auth_config()`
This matters because runtime state is also part of the plugin boundary. If future plugins need their own caches, state holders, or feature flags, they should follow the same pattern and keep them inside the plugin.
### 3.4 adapters
The `auth` plugin exposes capability through four main adapter groups:
1. `api/router.py`
- HTTP endpoints
2. `security/*`
- middleware, dependencies, request-user resolution, actor-context bridge
3. `authorization/*`
- capabilities, policy evaluators, auth hooks
4. `injection/*`
- route-policy registry, guard injection, startup validation
These adapters all follow the same rule:
1. entry-point behavior is defined inside the plugin
2. the host app only assembles and wires it
## 4. How a Plugin Interacts with the Host App
### 4.1 The top-level router only includes plugin routers
[`app/gateway/router.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/router.py) simply:
1. imports `app.plugins.auth.api.router`
2. calls `include_router(auth_router)`
That means the host app integrates auth HTTP behavior by assembly, not by duplicating login/register logic in `gateway`.
### 4.2 registrar performs wiring, not takeover
In [`app/gateway/registrar.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/registrar.py), the host app mainly does this:
1. `app.state.authz_hooks = build_authz_hooks()`
2. loads and validates the route-policy registry
3. calls `install_route_guards(app)`
4. calls `app.add_middleware(CSRFMiddleware)`
5. calls `app.add_middleware(AuthMiddleware)`
So the host app only wires the plugin in:
1. register middleware
2. install route guards
3. expose hooks and registries through `app.state`
The actual auth logic, authz logic, and route-policy semantics still live inside the plugin.
### 4.3 The plugin reuses shared sessions, but still owns business repositories
In [`auth/security/dependencies.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/security/dependencies.py):
1. the plugin reads the shared session factory from `request.app.state.persistence.session_factory`
2. constructs `DbUserRepository` itself
3. constructs `AuthService` itself
This is a good low-intrusion seam:
1. the outside world provides only shared infrastructure handles
2. the plugin decides how to instantiate its internal dependencies
## 5. Hot-Swappability and Low-Intrusion Principles
### 5.1 If a plugin serves other modules, it should minimize intrusion
When a plugin provides services to the rest of the app, the preferred patterns are:
1. expose a router
2. expose middleware or dependencies
3. expose hooks or protocols
4. inject a small number of shared objects through `app.state`
5. use config-driven route policies or capabilities instead of hardcoding checks inside business routes
Patterns to avoid:
1. large plugin-specific branches spread across `gateway`
2. unrelated business modules importing plugin ORM internals and rebuilding plugin logic themselves
3. plugin state being maintained across many global modules
### 5.2 Low-intrusion seams already visible in auth
The current `auth` plugin already uses four important low-intrusion seams:
1. router integration
- `gateway.router` only calls `include_router`
2. middleware integration
- `registrar` only registers `AuthMiddleware` and `CSRFMiddleware`
3. policy injection
- `install_route_guards(app)` appends `Depends(enforce_route_policy)` uniformly to routes
4. hook seam
- `authz_hooks` is exposed via `app.state`, so permission providers and policy builders can be replaced
This structure has three practical benefits:
1. host-app changes stay concentrated in the assembly layer
2. plugin core logic stays concentrated inside the plugin directory
3. swapping implementations does not require editing business routes one by one
### 5.3 Route policy is a key low-intrusion mechanism
[`auth/injection/registry_loader.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/registry_loader.py), [`validation.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/validation.py), and [`route_injector.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/route_injector.py) together form an important contract:
1. route policies live in the plugin-owned `route_policies.yaml`
2. startup validates that policy entries and real routes stay aligned
3. guards are attached by uniform injection instead of manual per-endpoint code
That allows the plugin to:
1. describe which routes are public, which capabilities are required, and which owner policies apply
2. avoid large invasive changes to the host routing layer
3. remain easier to replace or trim down later
## 6. What “ORM and runtime are implemented inside the plugin” Should Mean
That contract should be read as three concrete rules:
1. data models belong to the plugin
- the plugins own tables, Pydantic contracts, repository protocols, and repository implementations stay inside the plugin directory
2. runtime state belongs to the plugin
- plugin-owned config caches, context bridges, and plugin-level hooks stay inside the plugin
3. the outside world exposes infrastructure, not plugin semantics
- for example shared `session_factory`, FastAPI app, and `app.state`
Using `auth` as the example:
1. the `users` table is defined inside the plugin, not in `app.infra`
2. `AuthService` is implemented inside the plugin, not in `gateway`
3. `get_auth_config()` is maintained inside the plugin, not cached elsewhere
4. `AuthMiddleware`, `route_guard`, and `AuthzHooks` are all provided by the plugin itself
This is the structural prerequisite for meaningful pluginization later.
## 7. Current Scope and Non-Goals
At the current stage, the role of `app.plugins` is mainly:
1. to create module boundaries for separable application-side capabilities
2. to let each plugin own its own domain/storage/runtime/adapters
3. to connect plugins to the host app through assembly-oriented seams
The current non-goals are also clear:
1. this is not yet a full generic plugin discovery/installation system
2. plugins are not dynamically enabled or disabled at runtime
3. shared infrastructure is not being duplicated into every plugin
So at this stage, “hot-swappable” should be interpreted more precisely as:
1. plugin boundaries stay as independent as possible
2. integration points stay concentrated in the assembly layer
3. replacing or removing a plugin should mostly affect a small number of places such as `registrar`, router includes, and `app.state` hooks
## 8. Suggested Evolution Rules
If `app.plugins` is going to become a more stable plugin boundary, the codebase should keep following these rules:
1. each plugin directory should keep a `domain` / `storage` / `runtime` / `adapter` split
2. plugin-owned ORM and repositories should not drift into shared business directories
3. when a plugin serves the rest of the app, it should prefer exposing protocols, hooks, routers, and middleware over forcing external code to import internal implementation details
4. seams between a plugin and the host app should stay mostly limited to:
- `router.include_router(...)`
- `app.add_middleware(...)`
- `app.state.*`
- lifespan/bootstrap wiring
5. config-driven integration should be preferred over scattered hardcoded integration
6. startup validation should be preferred over implicit runtime failure
## 9. Summary
The current `app.plugins` contract can be summarized in one sentence:
Each plugin owns its own business implementation, ORM, and runtime; the host application provides shared infrastructure and assembly seams; and services should be integrated through low-intrusion, replaceable boundaries so the system can evolve toward real hot-swappability.
-310
View File
@@ -1,310 +0,0 @@
# app.plugins 设计说明
本文基于当前代码实现,说明 `backend/app/plugins` 的定位、插件设计契约、依赖边界,以及当前 `auth` 插件是如何在尽量少侵入宿主应用的前提下提供服务的。
## 1. 总体定位
`app.plugins` 是应用侧插件边界。它的目标不是做一个通用插件市场,而是在 `app` 这一层给可拆分的业务能力预留清晰边界,使某一类能力可以:
1. 在插件内部自带领域模型、运行时状态和适配器
2. 只通过有限的接缝与宿主应用交互
3. 在未来保持“可替换、可裁剪、可扩展”
当前目录下实际落地的插件是 [`auth`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth)。
从当前实现看,`app.plugins` 的方向不是“所有逻辑都塞进 app”,而是:
1. 宿主应用负责统一启动、共享基础设施和总路由装配
2. 插件负责自己的业务契约、持久化定义、运行时状态和外部适配器
## 2. 插件设计契约
### 2.1 插件内部要自带完整能力
当前代码体现出的首要契约是:
插件自己的 ORM、runtime、domain、adapter,原则上都应由插件内部实现,不要把核心业务依赖散落到外部模块。
`auth` 插件为例,它内部已经自带了完整分层:
1. `domain`
- 配置、错误、JWT、密码、领域模型、服务
2. `storage`
- 插件自己的 ORM 模型、仓储契约和仓储实现
3. `runtime`
- 插件自己的运行时配置状态
4. `api`
- 插件自己的 HTTP router 和 schema
5. `security`
- 插件自己的 middleware、dependency、csrf、LangGraph 适配
6. `authorization`
- 插件自己的权限模型、policy 解析和 hook
7. `injection`
- 插件自己的路由策略注册、注入和校验逻辑
换句话说,插件不是一组零散 helper,而应该是一个自闭合的功能模块。
### 2.2 宿主应用只提供共享基础设施,不承接插件内部逻辑
当前约束不是“插件完全独立进程”,而是:
1. 插件可以复用应用共享的 `engine``session_factory`、FastAPI app、路由树
2. 但插件自己的表结构、仓储、运行时配置、鉴权逻辑,仍然应由插件自己拥有
这一点在 [`auth/plugin.toml`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/plugin.toml) 里写得很明确:
1. `storage.mode = "shared_infrastructure"`
2. 说明插件拥有自己的 storage definitions 和 repositories
3. 但复用应用共享的 persistence infrastructure
所以这里的契约不是“禁止复用基础设施”,而是“不要把插件内部业务实现外包给 app 其他模块”。
### 2.3 依赖方向要单向
按当前实现,比较理想的依赖方向是:
```text
gateway / app bootstrap
-> plugin public adapters
-> plugin domain / storage / runtime
```
而不是:
```text
plugin domain
-> 依赖 app 里的业务模块
```
插件可以使用:
1. 共享持久化基础设施
2. 宿主应用提供的 `app.state`
3. FastAPI / Starlette 等通用框架能力
但不应该把自己的核心业务规则建立在别的业务模块之上,否则后续无法热插拔。
## 3. 当前 auth 插件的实际结构
当前 `auth` 插件可以概括为一套“自带模型、自带服务、自带适配器”的认证授权包。
### 3.1 domain
[`auth/domain`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/domain) 负责:
1. `config.py`
- 认证相关配置定义与加载
2. `errors.py`
- 错误码和错误响应契约
3. `jwt.py`
- token 编解码
4. `password.py`
- 密码哈希和校验
5. `models.py`
- auth 域模型
6. `service.py`
- `AuthService`,作为核心业务服务
`AuthService` 本身只依赖插件内部的 `DbUserRepository` 和共享 session factory,没有把认证逻辑散到 `gateway`
### 3.2 storage
[`auth/storage`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage) 明确体现了“ORM 由插件自己内部实现”的契约:
1. [`models.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/models.py)
- 定义插件自己的 `users` 表模型
2. [`contracts.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/contracts.py)
- 定义 `User``UserCreate``UserRepositoryProtocol`
3. [`repositories.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/repositories.py)
- 实现 `DbUserRepository`
这里的关键点是:
1. 插件自己定义 ORM model
2. 插件自己定义 repository protocol
3. 插件自己实现 repository
4. 外部只需要给它 session / session_factory
这就是插件边界应该保持的最小共享面。
### 3.3 runtime
[`auth/runtime/config_state.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/runtime/config_state.py) 维护插件自己的 runtime config state
1. `get_auth_config()`
2. `set_auth_config()`
3. `reset_auth_config()`
这说明运行时配置状态也属于插件内部,而不是由外部模块代持。后续如果别的插件需要自己的缓存、状态机、feature flag,也应沿这个模式内聚在插件内部。
### 3.4 adapters
`auth` 插件对外暴露能力主要通过四类 adapter:
1. `api/router.py`
- HTTP 接口
2. `security/*`
- middleware、dependency、request user 解析、actor context bridge
3. `authorization/*`
- capability、policy evaluator、auth hooks
4. `injection/*`
- route policy registry、guard 注入、启动校验
这类 adapter 的共同特征是:
1. 入口能力在插件内定义
2. 宿主应用只负责调用和装配
## 4. 插件如何与宿主应用交互
### 4.1 总路由只 include,不重写插件逻辑
[`app/gateway/router.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/router.py) 只是:
1. 引入 `app.plugins.auth.api.router`
2. `include_router(auth_router)`
这说明宿主应用对 auth HTTP 能力的接入是装配式的,而不是在 `gateway` 里重写一套登录/注册逻辑。
### 4.2 registrar 负责启动装配,不负责接管插件实现
[`app/gateway/registrar.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/registrar.py) 里,宿主应用做的事情主要是:
1. `app.state.authz_hooks = build_authz_hooks()`
2. 加载并校验 route policy registry
3. `install_route_guards(app)`
4. `app.add_middleware(CSRFMiddleware)`
5. `app.add_middleware(AuthMiddleware)`
也就是说,宿主应用只负责把插件接进来:
1. 注册 middleware
2. 安装 route guard
3. 把 hooks 和 registry 放到 `app.state`
真正的鉴权逻辑、认证逻辑、路由策略语义仍然在插件内部。
### 4.3 共享会话工厂,但业务仓储仍归插件
在 [`auth/security/dependencies.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/security/dependencies.py) 中:
1. 插件从 `request.app.state.persistence.session_factory` 取得共享 session factory
2. 然后自己构造 `DbUserRepository`
3. 再自己构造 `AuthService`
这就是一个很典型的低侵入接缝:
1. 外部只提供共享基础设施句柄
2. 插件自己决定如何实例化内部依赖
## 5. 热插拔与低侵入原则
### 5.1 如果要向其他模块提供服务,应尽量减少入侵
插件给其他模块提供服务时,优先选下面这些方式:
1. 暴露 router
2. 暴露 middleware / dependency
3. 暴露 hook 或 protocol
4. 通过 `app.state` 注入少量共享对象
5. 使用配置驱动的 route policy / capability,而不是把判断逻辑硬编码进业务路由
不推荐的方式是:
1. 在 `gateway` 大量写插件特定分支
2. 让别的业务模块直接 import 插件内部 ORM 细节后自行拼逻辑
3. 把插件状态散落到全局多个模块中共同维护
### 5.2 当前 auth 插件已经体现出的低侵入点
当前 `auth` 插件的低侵入接入点主要有四个:
1. 路由接入
- `gateway.router``include_router`
2. 中间件接入
- `registrar` 只注册 `AuthMiddleware` / `CSRFMiddleware`
3. 策略注入
- `install_route_guards(app)` 给路由统一追加 `Depends(enforce_route_policy)`
4. hook 接缝
- `authz_hooks` 通过 `app.state` 暴露,策略构建和权限提供器可以替换
这套结构的好处是:
1. 宿主应用改动面集中在装配层
2. 插件核心实现集中在插件目录内部
3. 替换实现时,不需要在业务路由里逐个修改
### 5.3 route policy 是低侵入的关键机制
[`auth/injection/registry_loader.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/registry_loader.py)、[`validation.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/validation.py) 和 [`route_injector.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/route_injector.py) 共同形成了一套很关键的契约:
1. 路由策略写在插件自己的 `route_policies.yaml`
2. 启动时会校验策略表和真实路由是否一致
3. guard 通过统一注入附着到路由,而不是每个 endpoint 手写一遍
这使得插件能够:
1. 用配置描述“哪些路由公开、需要哪些 capability、需要哪些 owner policy”
2. 避免对宿主路由层做大规模侵入
3. 在未来更容易替换或裁剪某个插件
## 6. 关于“ORM、runtime 都由自己内部实现”的具体说明
这条契约建议明确理解为以下三点:
1. 数据模型归插件
- 插件自己的表、Pydantic contract、repository protocol、repository implementation 都放在插件目录内
2. 运行时状态归插件
- 插件自己的配置缓存、上下文桥、插件级 hooks 都在插件内部维护
3. 外部只暴露基础设施,不接管插件语义
- 例如共享 `session_factory`、FastAPI app、`app.state`
`auth` 举例:
1. `users` 表在插件里定义,不在 `app.infra` 定义
2. `AuthService` 在插件里实现,不在 `gateway` 实现
3. `get_auth_config()` 在插件里维护,不由别的模块缓存
4. `AuthMiddleware``route_guard``AuthzHooks` 都由插件自己提供
这是后续做插件化时最重要的结构前提。
## 7. 当前作用范围与非目标
就当前实现而言,`app.plugins` 的作用范围主要是:
1. 为应用侧可拆分能力建立模块边界
2. 让插件拥有自己的 domain/storage/runtime/adapter
3. 通过装配式接缝接入宿主应用
当前非目标也很明确:
1. 还不是一个完整的通用插件发现/安装系统
2. 还没有做到运行时动态启停插件
3. 也不是把共享基础设施完全复制进每个插件
所以“热插拔”在当前阶段更准确的含义是:
1. 插件边界尽量独立
2. 接入点尽量集中在装配层
3. 替换或移除时,改动尽量局限在 `registrar``router include``app.state` hooks 这些少数位置
## 8. 后续演进建议
如果后续要继续把 `app.plugins` 做成更稳定的插件边界,建议保持这些规则:
1. 每个插件目录内部都保持 `domain` / `storage` / `runtime` / `adapter` 分层
2. 插件自己的 ORM 与 repository 不要下沉到共享业务目录
3. 插件向外提供服务时优先暴露 protocol、hook、router、middleware,而不是要求外部 import 内部实现细节
4. 插件与宿主应用的接缝尽量限制在:
- `router.include_router(...)`
- `app.add_middleware(...)`
- `app.state.*`
- 生命周期装配
5. 配置驱动优先于散落的硬编码接入
6. 启动期校验优先于运行时隐式失败
## 9. 设计总结
可以把当前 `app.plugins` 的契约总结为一句话:
插件内部拥有自己的业务实现、ORM 和 runtime;宿主应用只提供共享基础设施和装配接缝;对外服务时尽量通过低侵入、可替换的方式接入,以便后续做到真正的热插拔和边界演进。
-1
View File
@@ -1 +0,0 @@
"""Application plugin packages."""
-21
View File
@@ -1,21 +0,0 @@
# Auth Plugin
This package is the future Level 2 auth plugin boundary for DeerFlow.
Scope:
- Auth domain logic: config, errors, models, JWT, password hashing, service
- Auth adapters: HTTP router, FastAPI dependencies, middleware, LangGraph adapter
- Auth storage: user/account models and repositories
Non-scope:
- Shared app/container bootstrap
- Shared persistence engine/session lifecycle
- Generic plugin discovery/registration framework
Target architecture:
- The plugin owns its storage definitions and business logic
- The plugin reuses the application's shared persistence infrastructure
- The gateway only assembles the plugin instead of owning auth logic directly
-14
View File
@@ -1,14 +0,0 @@
"""Auth plugin package.
Level 2 plugin goal:
- Own auth domain logic
- Own auth adapters (router, dependencies, middleware, LangGraph adapter)
- Own auth storage definitions
- Reuse the application's shared persistence/session infrastructure
"""
from app.plugins.auth.authorization.hooks import build_authz_hooks
__all__ = [
"build_authz_hooks",
]
-17
View File
@@ -1,17 +0,0 @@
"""HTTP API layer for the auth plugin."""
from app.plugins.auth.api.router import (
ChangePasswordRequest,
LoginResponse,
MessageResponse,
RegisterRequest,
router,
)
__all__ = [
"ChangePasswordRequest",
"LoginResponse",
"MessageResponse",
"RegisterRequest",
"router",
]
-171
View File
@@ -1,171 +0,0 @@
"""Authentication endpoints for the auth plugin."""
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from app.plugins.auth.api.schemas import (
ChangePasswordRequest,
InitializeAdminRequest,
LoginResponse,
MessageResponse,
RegisterRequest,
_check_rate_limit,
_get_client_ip,
_login_attempts,
_record_login_failure,
_record_login_success,
)
from app.plugins.auth.domain.errors import AuthErrorResponse
from app.plugins.auth.domain.jwt import create_access_token
from app.plugins.auth.domain.models import UserResponse
from app.plugins.auth.domain.service import AuthServiceError
from app.plugins.auth.runtime.config_state import get_auth_config
from app.plugins.auth.security.csrf import is_secure_request
from app.plugins.auth.security.dependencies import CurrentAuthService, get_current_user_from_request
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
)
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
auth_service: CurrentAuthService,
form_data: OAuth2PasswordRequestForm = Depends(),
):
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
try:
user = await auth_service.login_local(form_data.username, form_data.password)
except AuthServiceError as exc:
_record_login_failure(client_ip)
raise HTTPException(
status_code=exc.status_code,
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
) from exc
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest, auth_service: CurrentAuthService):
try:
user = await auth_service.register(body.email, body.password)
except AuthServiceError as exc:
raise HTTPException(
status_code=exc.status_code,
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
) from exc
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(
request: Request,
response: Response,
body: ChangePasswordRequest,
auth_service: CurrentAuthService,
):
user = await get_current_user_from_request(request)
try:
user = await auth_service.change_password(
user,
current_password=body.current_password,
new_password=body.new_password,
new_email=body.new_email,
)
except AuthServiceError as exc:
raise HTTPException(
status_code=exc.status_code,
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
) from exc
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
@router.get("/setup-status")
async def setup_status(auth_service: CurrentAuthService):
return {"needs_setup": await auth_service.get_setup_status()}
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def initialize_admin(
request: Request,
response: Response,
body: InitializeAdminRequest,
auth_service: CurrentAuthService,
):
try:
user = await auth_service.initialize_admin(body.email, body.password)
except AuthServiceError as exc:
raise HTTPException(
status_code=exc.status_code,
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
) from exc
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
if provider not in ["github", "google"]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported OAuth provider: {provider}")
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="OAuth login not yet implemented")
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="OAuth callback not yet implemented")
__all__ = [
"ChangePasswordRequest",
"InitializeAdminRequest",
"LoginResponse",
"MessageResponse",
"RegisterRequest",
"_check_rate_limit",
"_get_client_ip",
"_login_attempts",
"_record_login_failure",
"_record_login_success",
"router",
]
-176
View File
@@ -1,176 +0,0 @@
"""HTTP schemas and request helpers for the auth plugin API."""
from __future__ import annotations
import os
import time
from ipaddress import ip_address, ip_network
from fastapi import HTTPException, Request
from pydantic import BaseModel, EmailStr, Field, field_validator
_COMMON_PASSWORDS: frozenset[str] = frozenset(
{
"password",
"password1",
"password12",
"password123",
"password1234",
"12345678",
"123456789",
"1234567890",
"qwerty12",
"qwertyui",
"qwerty123",
"abc12345",
"abcd1234",
"iloveyou",
"letmein1",
"welcome1",
"welcome123",
"admin123",
"administrator",
"passw0rd",
"p@ssw0rd",
"monkey12",
"trustno1",
"sunshine",
"princess",
"football",
"baseball",
"superman",
"batman123",
"starwars",
"dragon123",
"master123",
"shadow12",
"michael1",
"jennifer",
"computer",
}
)
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300
_MAX_TRACKED_IPS = 10000
_login_attempts: dict[str, tuple[int, float]] = {}
class LoginResponse(BaseModel):
expires_in: int
needs_setup: bool = False
class RegisterRequest(BaseModel):
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class ChangePasswordRequest(BaseModel):
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class MessageResponse(BaseModel):
message: str
class InitializeAdminRequest(BaseModel):
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
def _password_is_common(password: str) -> bool:
return password.lower() in _COMMON_PASSWORDS
def _validate_strong_password(value: str) -> str:
if _password_is_common(value):
raise ValueError("Password is too common; choose a stronger password.")
return value
def _trusted_proxies() -> list:
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
if not raw:
return []
nets = []
for entry in raw.split(","):
entry = entry.strip()
if not entry:
continue
try:
nets.append(ip_network(entry, strict=False))
except ValueError:
pass
return nets
def _get_client_ip(request: Request) -> str:
peer_host = request.client.host if request.client else None
trusted = _trusted_proxies()
if trusted and peer_host:
try:
peer_ip = ip_address(peer_host)
if any(peer_ip in net for net in trusted):
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
except ValueError:
pass
return peer_host or "unknown"
def _check_rate_limit(ip: str) -> None:
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(status_code=429, detail="Too many login attempts. Try again later.")
del _login_attempts[ip]
def _record_login_failure(ip: str) -> None:
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for key in expired:
del _login_attempts[key]
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for key, _ in by_time[: len(by_time) // 2]:
del _login_attempts[key]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
_login_attempts.pop(ip, None)
__all__ = [
"ChangePasswordRequest",
"InitializeAdminRequest",
"LoginResponse",
"MessageResponse",
"RegisterRequest",
"_check_rate_limit",
"_get_client_ip",
"_login_attempts",
"_record_login_failure",
"_record_login_success",
]
@@ -1,31 +0,0 @@
"""Authorization layer for the auth plugin."""
from app.plugins.auth.authorization.authentication import get_auth_context
from app.plugins.auth.authorization.hooks import (
AuthzHooks,
build_authz_hooks,
build_permission_provider,
build_policy_chain_builder,
get_authz_hooks,
get_default_authz_hooks,
)
from app.plugins.auth.authorization.types import (
AuthContext,
Permissions,
ALL_PERMISSIONS,
)
_ALL_PERMISSIONS = ALL_PERMISSIONS
__all__ = [
"AuthContext",
"AuthzHooks",
"Permissions",
"_ALL_PERMISSIONS",
"build_authz_hooks",
"build_permission_provider",
"build_policy_chain_builder",
"get_auth_context",
"get_authz_hooks",
"get_default_authz_hooks",
]

Some files were not shown because too many files have changed in this diff Show More