Compare commits

...

51 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] dad3997459 fix(sandbox): cleanup dead containers and avoid lock-held liveness checks
Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/96707445-0f8b-4901-8ef3-d8e5667f8a05

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>
2026-05-11 00:09:09 +00:00
Willem Jiang b67c2a4e56 fix(sandbox): auto-restart crashed containers transparently (#2788)
When a sandbox container crashes (e.g. due to an internal error), the
  agent enters a connection-refused loop because AioSandboxProvider.get()
  returns a cached but dead sandbox object. Add a liveness check in get()
  that detects crashed containers via backend.is_alive() and evicts them
  from all caches, allowing ensure_sandbox_initialized() to transparently
  recreate a fresh container on the next acquire().

  The behavior is controlled by a new  config option
  (default: true). Set to false to skip health checks and preserve the
  old behavior of returning stale cached sandboxes.

  Closes #2788
2026-05-10 22:53:58 +08:00
Xinmin Zeng 94da8f67d7 fix(scripts): preserve uv extras across make dev restarts (#2754) (#2767)
`make dev` ran `uv sync` unconditionally on every restart, wiping any
optional extras the user had installed manually with
`uv sync --all-packages --extra postgres`. The Docker image-build path
already solved this via the `UV_EXTRAS` build-arg in backend/Dockerfile;
the local serve.sh path and the docker-compose-dev startup command
were the remaining outliers.

`scripts/serve.sh` now resolves extras before `uv sync`:
  1. honors `UV_EXTRAS` (parity with backend/Dockerfile and
     docker/docker-compose.yaml — no new convention introduced);
  2. falls back to parsing config.yaml — `database.backend: postgres`
     or legacy `checkpointer.type: postgres` auto-pins
     `--extra postgres`, so the common case needs zero extra config.
  3. detector stderr is no longer suppressed, so whitelist warnings or
     crashes surface to the dev terminal (review feedback).

Detection lives in `scripts/detect_uv_extras.py` (stdlib-only — has to
run before the venv exists). Extra names are validated against
`^[A-Za-z][A-Za-z0-9_-]*$` so a stray shell metacharacter in `.env`
cannot reach `uv sync` downstream (defense in depth).

`docker/docker-compose-dev.yaml`'s startup command is now extracted to
`docker/dev-entrypoint.sh` (review feedback — the inline command had
grown to a ~350-char one-liner). The script:
  - parses comma/whitespace-separated UV_EXTRAS, applying the same
    `^[A-Za-z][A-Za-z0-9_-]*$` whitelist as the local detector;
  - emits one `--extra X` flag per token, so `UV_EXTRAS=postgres,ollama`
    works in Docker dev too (harmonized with local — review feedback);
  - calls `uv sync --all-packages` (PR #2584) so workspace member
    extras (deerflow-harness's postgres extra) are installed;
  - keeps the existing self-heal `(uv sync || (recreate venv && retry))`
    branch;
  - exposes `--print-extras` for dry-run testing.

The compose file mounts the script read-only at runtime, so script
edits take effect on `make docker-restart` without an image rebuild.

The `--no-sync` alternative (a separate suggestion in the issue thread)
was considered but rejected for dev paths because it would drop the
self-heal branch and the auto-pickup of new pyproject deps. `--no-sync`
is already in use for the production CMD (`backend/Dockerfile:101`)
where it's appropriate.

Updates the asyncpg-missing error message to include the
`--all-packages` flag (matching #2584) plus the persistent install flow,
and expands `config.example.yaml` so all three install paths
(local / docker dev / docker image build) are documented with their
multi-extra capabilities.

Tests:
  - `tests/test_detect_uv_extras.py` (21 tests) — local-path env parsing,
    YAML edge cases, env-vs-config precedence, whitelist rejection of
    shell metacharacters.
  - `tests/test_dev_entrypoint.py` (15 tests) — docker-path validation
    via `--print-extras`, multi-extra parsing, metacharacter abort.
  - `tests/test_persistence_scaffold.py` (22 tests, unchanged) — passes
    with the merged `--all-packages --extra postgres` error message.

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-10 22:28:29 +08:00
YuJitang 5127f08e1a enable token usage by default (#2841) 2026-05-10 22:00:57 +08:00
DanielWalnut dfa4eb0c1a [codex] fix follow-up suggestions layout (#2836)
* fix follow-up suggestions layout

* fix agent chat welcome layout transition

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-10 15:10:44 +08:00
DanielWalnut 08ee7adeba fix(lint): remove duplicate is_dynamic_context_reminder definition (#2837)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 23:40:46 +08:00
Eilen Shin 1c96a6afc8 fix: keep new agent bootstrap in user scope (#2784) 2026-05-09 19:43:50 +08:00
YuJitang 417416087b fix: use backend thread token usage for header total (#2800)
* fix: use backend thread token usage for header total

* Refactor thread token usage fetch
2026-05-09 19:40:32 +08:00
DanielWalnut 881ff71252 fix(harness): preserve dynamic context across summarization (#2823) 2026-05-09 19:39:36 +08:00
DanielWalnut f76e4e35c8 fix title generation with dynamic context reminder (#2830) 2026-05-09 18:22:58 +08:00
yangyufan 0d1053ca44 fix(uploads): add Windows support for safe symlink-protected uploads (#2794)
* fix(uploads): add Windows support for safe symlink-protected uploads

* fix(uploads): update tests and translate comments;
2026-05-09 18:21:54 +08:00
He Wang 4063dd7157 feat(debug): print presented file paths with physical resolution (#2825)
Surface artifacts produced via the present_files tool in the CLI debug
REPL so headless clients without a frontend (VS Code launch configs,
etc.) can locate output files. Each turn prints newly added artifacts
plus their resolved host path. Works for any source that goes through
present_files — ACP agents, subagents, or sandbox writes.

Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
2026-05-09 18:21:01 +08:00
ChenglongZ 7a3c58a733 Fix duplicate gateway upload filenames (#2789) 2026-05-09 18:02:40 +08:00
dependabot[bot] 1edc9d9fae chore(deps): bump langchain-core from 1.3.2 to 1.3.3 in /backend (#2807)
Bumps [langchain-core](https://github.com/langchain-ai/langchain) from 1.3.2 to 1.3.3.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/langchain-core==1.3.2...langchain-core==1.3.3)

---
updated-dependencies:
- dependency-name: langchain-core
  dependency-version: 1.3.3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-09 15:51:18 +08:00
KiteEater 7caf03e97c fix(packaging): add postgres extra for store/checkpointer supportFix postgres extra install guidance (#2584)
* Fix postgres extra install guidance

* Fix postgres install message lint

* Format postgres install messages

* Fix postgres install guidance and config docs
2026-05-09 09:49:08 +08:00
dependabot[bot] 41b04a556f chore(deps): bump uuid from 10.0.0 to 14.0.0 in /frontend (#2802)
Bumps [uuid](https://github.com/uuidjs/uuid) from 10.0.0 to 14.0.0.
- [Release notes](https://github.com/uuidjs/uuid/releases)
- [Changelog](https://github.com/uuidjs/uuid/blob/main/CHANGELOG.md)
- [Commits](https://github.com/uuidjs/uuid/compare/v10.0.0...v14.0.0)

---
updated-dependencies:
- dependency-name: uuid
  dependency-version: 14.0.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-09 09:33:00 +08:00
DanielWalnut c1b7f1d189 feat: static system prompt with DynamicContextMiddleware for prefix-cache optimization (#2801)
* feat(middleware): inject dynamic context via DynamicContextMiddleware

Move memory and current date out of the system prompt and into a
dedicated <system-reminder> HumanMessage injected once per session
(frozen-snapshot pattern) via a new DynamicContextMiddleware.

This keeps the system prompt byte-exact across all users and sessions,
enabling maximum Anthropic/Bedrock prefix-cache reuse.

Key design decisions:
- ID-swap technique: reminder takes the first HumanMessage's ID
  (replacing it in-place via add_messages), original content gets a
  derived `{id}__user` ID (appended after). Preserves correct ordering.
- hide_from_ui: True on reminder messages so frontend filters them out.
- Midnight crossing: date-update reminder injected before the current
  turn's HumanMessage when the conversation spans midnight.
- INFO-level logging for production diagnostics.

Also adds prompt-caching breakpoint budget enforcement tests and
updates ClaudeChatModel docs to reference the new pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(token-usage): log input/output token detail breakdown in middleware

Extend the LLM token usage log line to include input_token_details and
output_token_details (cache_creation, cache_read, reasoning, audio, etc.)
when present. Adds tests covering Anthropic cache detail logging from
both usage_metadata and response_metadata.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: fix nginx

* fix(middleware): always inject date; gate memory on injection_enabled

Date injection is now unconditional — it is part of the static system
prompt replacement and should always be present. Memory injection
remains gated by `memory.injection_enabled` in the app config.

Previously the entire DynamicContextMiddleware was skipped when
injection_enabled was False, which also suppressed the date.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(lint): format files and correct test assertions for token usage middleware

- ruff format dynamic_context_middleware.py and test_claude_provider_prompt_caching.py
- Remove unused pytest import from test_dynamic_context_middleware.py
- Fix two tests that asserted response_metadata fallback logic that
  doesn't exist: replace with tests that match actual middleware behavior

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(middleware): address Copilot review comments on DynamicContextMiddleware

- Use additional_kwargs flag for reminder detection instead of content
  substring matching, so user messages containing '<system-reminder>'
  are not mistakenly treated as injected reminders
- Generate stable UUID when original HumanMessage.id is None to prevent
  ambiguous 'None__user' derived IDs and message collisions
- Downgrade per-turn no-op log to DEBUG; keep actual injection events at INFO
- Add two new tests: missing-id UUID fallback and user-text false-positive

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 09:27:02 +08:00
dependabot[bot] 109490da25 chore(deps): bump python-multipart from 0.0.26 to 0.0.27 in /backend (#2799)
Bumps [python-multipart](https://github.com/Kludex/python-multipart) from 0.0.26 to 0.0.27.
- [Release notes](https://github.com/Kludex/python-multipart/releases)
- [Changelog](https://github.com/Kludex/python-multipart/blob/main/CHANGELOG.md)
- [Commits](https://github.com/Kludex/python-multipart/compare/0.0.26...0.0.27)

---
updated-dependencies:
- dependency-name: python-multipart
  dependency-version: 0.0.27
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-08 22:58:15 +08:00
dependabot[bot] 14c0a32ee6 chore(deps): bump mako from 1.3.11 to 1.3.12 in /backend (#2798)
Bumps [mako](https://github.com/sqlalchemy/mako) from 1.3.11 to 1.3.12.
- [Release notes](https://github.com/sqlalchemy/mako/releases)
- [Changelog](https://github.com/sqlalchemy/mako/blob/main/CHANGES)
- [Commits](https://github.com/sqlalchemy/mako/commits)

---
updated-dependencies:
- dependency-name: mako
  dependency-version: 1.3.12
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-08 22:57:48 +08:00
Willem Jiang 70737af7cd fix(nignx):resolve CSRF auth failure on non-standard ports (#2796) 2026-05-08 22:40:38 +08:00
DanielWalnut 2b1fcb3e43 fix(task): remove max_turns parameter from task tool interface (#2783)
* fix(task): remove max_turns parameter from task tool interface

Subagents should always use their configured max_turns value. Exposing
this parameter allowed callers to override the admin-configured limit,
which is undesirable. The value is now exclusively driven by subagent
config (per-agent overrides and global defaults in config.yaml).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-08 15:05:24 +08:00
He Wang 7de9b5828b fix(tools): introduce Runtime type alias to eliminate Pydantic serialization warning (#2774)
* fix(tools): introduce Runtime type alias to eliminate Pydantic serialization warning

Add deerflow/tools/types.py with:

    Runtime = ToolRuntime[dict[str, Any], ThreadState]

Replace every runtime: ToolRuntime[ContextT, ThreadState] and
runtime: ToolRuntime[dict[str, Any], ThreadState] annotation in
sandbox/tools.py, present_file_tool.py, task_tool.py, view_image_tool.py,
and skill_manage_tool.py with the new Runtime alias.

The unbound ContextT TypeVar (default None) caused
PydanticSerializationUnexpectedValue warnings on every tool call because
LangChain's BaseTool._parse_input calls model_dump() on the auto-generated
args_schema while DeerFlow passes a dict as runtime context.
Binding the context to dict[str, Any] aligns Pydantic's serialization
expectations with reality and removes the noise from all run modes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(tools): extend Runtime alias to setup_agent and update_agent tools

Replace bare ToolRuntime annotations in setup_agent_tool.py and
update_agent_tool.py with the shared Runtime alias introduced in the
previous commit, and add both tools to the Pydantic serialization
warning regression test (13 cases total).

Co-authored-by: Cursor <cursoragent@cursor.com>

* test(tools): loosen Pydantic warning filter to avoid version-specific format

Replace the brittle "field_name='context'" substring check with a looser
"context" match so the assertion stays valid if Pydantic changes its
internal warning format across versions.

Co-authored-by: Cursor <cursoragent@cursor.com>

* test(tools): simplify warning filter and clean up docstring

Remove the "context" substring condition from the Pydantic warning
filter — asserting that no PydanticSerializationUnexpectedValue fires
at all is both simpler and more comprehensive, since the test payload
contains only the tool's own args plus runtime.

Also update the module docstring to remove the version-specific warning
format example that was inconsistent with the looser filter.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-08 14:50:33 +08:00
Eilen Shin 37db689349 fix(events): serialize structured db event content (#2762) 2026-05-08 10:17:17 +08:00
Eilen Shin bd45cb2846 fix(sandbox): disable msys path conversion (#2766) 2026-05-08 10:13:11 +08:00
Eilen Shin 5fd0e6ac89 fix(middleware): sync raw tool call metadata (#2757) 2026-05-08 10:08:53 +08:00
YuJitang 530bda7107 fix: dedupe token usage aggregation by message id (#2770) 2026-05-08 09:54:20 +08:00
Willem Jiang 6c220a9aef fix(chat): prevent first user message from being swallowed in new conversations (#2731)
* fix(chat): prevent first user message from being swallowed in new conversations

  The optimistic message clearing effect cleared too eagerly — any stream
  message (including AI messages from messages-tuple events) triggered the
  clear before the server's human message had arrived via values events.
  For new threads this caused the user's first prompt to disappear permanently.

  Only clear optimistic messages once the server's human message has been
  confirmed to arrive in thread.messages, not just when any message arrives.

  Fixes #2730

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-07 17:31:48 +08:00
Tao Liu daa3ffc29b feat(loop-detection): make loop detection configurable with per-tool frequency overrides (#2711)
* Make loop detection configurable

Expose LoopDetectionMiddleware thresholds through config.yaml while preserving existing defaults and allowing the middleware to be disabled.

Refs bytedance/deer-flow#2517

* feat(loop-detection): add per-tool tool_freq_overrides to Phase 1

Adds ToolFreqOverride model and tool_freq_overrides field to
LoopDetectionConfig, wires it through LoopDetectionMiddleware, and
documents the option in config.example.yaml.

Resolves the gap flagged in the #2586 review: without per-tool overrides,
users hit by #2510/#2511 (RNA-seq workflows exceeding the bash hard limit)
had no way to raise thresholds for one tool without loosening the global
limit for every tool.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* docs(loop-detection): document tool_freq_overrides in LoopDetectionMiddleware docstring

Add the missing Args entry for tool_freq_overrides, explaining the
(warn, hard_limit) tuple structure and how per-tool thresholds supersede
the global tool_freq_warn / tool_freq_hard_limit for named tools.
Also run ruff format on the three files flagged by the lint check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(loop-detection): validate LoopDetectionMiddleware __init__ params eagerly

Raise clear ValueError at construction time instead of crashing at
unpack-time inside _track_and_check when bad values are passed:
- tool_freq_overrides: must be 2-tuples of positive ints with hard_limit >= warn
- scalar thresholds: warn_threshold, hard_limit, tool_freq_warn,
  tool_freq_hard_limit must be >= 1 and hard limits must >= their warn pairs
- window_size, max_tracked_threads must be >= 1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(test): isolate credential loader directory-path test from real ~/.claude

The test didn't monkeypatch HOME, so on any machine with real Claude Code
credentials at ~/.claude/.credentials.json the function fell through to
those credentials and the assertion failed. Adding HOME redirect ensures
the default credential path doesn't exist during the test.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style(test): add blank lines after import pytest in TestInitValidation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(loop-detection): collapse dual validation to LoopDetectionConfig

Modifications
  - LoopDetectionMiddleware.__init__: stripped of all ValueError raises;
    becomes a plain field-assignment constructor.
  - LoopDetectionMiddleware.from_config: classmethod that builds the
    middleware from a Pydantic-validated LoopDetectionConfig and handles
    the ToolFreqOverride -> tuple[int, int] conversion.
  - agents/factory.py: SDK construction routed through
    LoopDetectionMiddleware.from_config(LoopDetectionConfig()) so the
    defaults path is Pydantic-validated too.
  - agents/lead_agent/agent.py: uses from_config instead of unpacking
    config fields by hand.
  - tests/test_loop_detection_middleware.py: deleted TestInitValidation
    (16 methods exercising the removed __init__ checks); added
    TestFromConfig (4 tests: scalar field mapping, override tuple
    conversion, empty overrides, behavioral smoke test).

Result: one validation layer (Pydantic), zero duplication, no __new__
hacks. Both production construction sites flow through LoopDetectionConfig.

Test results
  make test   -> 2977 passed, 18 skipped, 0 failed (137s)
  make format -> All checks passed; 411 files left unchanged

* feat(agents): make loop_detection configurable in create_deerflow_agent

Adds a `loop_detection: bool | AgentMiddleware = True` field to
RuntimeFeatures, mirroring the existing pattern used by `sandbox`,
`memory`, and `vision`. SDK users can now disable LoopDetectionMiddleware
or replace it with a custom instance built from their own
LoopDetectionConfig — e.g.
`LoopDetectionMiddleware.from_config(my_cfg)` — instead of being stuck
with the hardcoded defaults previously installed by the SDK factory.

The lead-agent path (which already reads AppConfig.loop_detection) is
unchanged, and the default `True` preserves prior always-on behavior for
all existing callers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

---------

Co-authored-by: knight0940 <631532668@qq.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Amorend <142649913+knight0940@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-07 16:15:15 +08:00
Xinmin Zeng 27559f3675 fix(frontend): defer thread id to onStart to avoid 404 on new chat (#2749)
* fix(frontend): defer thread id to onStart to avoid 404 on new chat

The LangGraph SDK's useStream eagerly fetches /threads/{id}/history the
moment it receives a thread id, and the local useThreadRuns issues
GET /threads/{id}/runs for the same reason. The chats page used to flip
isNewThread=false (and forward the client-generated thread id) inside
the synchronous onSend callback, before thread.submit had created the
thread on the backend. The two queries therefore raced ahead of
POST /runs/stream and returned 404 on the very first send.

Drop the onSend handler so isNewThread stays true until onStart fires
from useStream's onCreated — by then the backend has the thread, and
the SDK's submittingRef guard naturally suppresses the redundant
history fetch. The agent chat page already uses this pattern, so this
also unifies the two flows.

Adds an E2E regression that records request ordering and asserts
GET /history and GET /runs are never issued before POST /runs/stream
on the first send from /chats/new.

Closes #2746

* fix(frontend): split welcome layout from backend thread state

Removing onSend kept GET /history and GET /runs from racing ahead of
POST /runs/stream, but it also coupled the welcome layout (centered
input, hero, quick actions) to backend thread creation.  Until onCreated
returned, the user's optimistic message and the welcome hero rendered on
top of each other.

Introduce a dedicated `isWelcomeMode` UI flag, separate from
`isNewThread`:
- `isNewThread` still tracks "backend has no thread yet" and gates the
  thread id forwarded to useStream.
- `isWelcomeMode` drives the visual layout (header background, input
  box position, max width, hero, quick actions, autoFocus) and flips to
  false inside onSend so the layout animates immediately.

`isWelcomeMode` is kept in sync with `isNewThread` via an effect so
sidebar navigation and "new chat" still behave correctly.  All 15 E2E
tests pass, including the ordering regression added in the previous
commit.

* test(e2e): use monotonic sequence for thread-init ordering check

Date.now() is millisecond-resolution, so two requests emitted within
the same tick would share a timestamp and slip past the strict `<`
ordering assertions. Replace the timestamp with a monotonic counter
that increments on every observed request/requestfinished event so the
ordering check is robust regardless of scheduling.

Per PR #2749 review feedback from copilot-pull-request-reviewer.

* refactor(input-box): rename isNewThread prop to isWelcomeMode

Inside InputBox, the prop named `isNewThread` is only ever consulted
for visual layout decisions — gating follow-up suggestions, the bottom
background strip, and the welcome-mode quick-action SuggestionList. It
never reflects "the backend has created the thread", which after #2746
is tracked separately via `isNewThread` in the chat pages themselves.

Rename the prop to `isWelcomeMode` and update both call sites
(workspace chats page and agent chats page) so the prop name matches
its actual semantics. No behavior change.

Per PR #2749 review feedback from @WillemJiang.
2026-05-07 16:11:44 +08:00
AochenShen99 cef4224381 fix(skills): enforce allowed-tools metadata (#2626)
* fix(skills): parse allowed-tools frontmatter

* fix(skills): validate allowed-tools metadata

* fix(skills): add shared allowed-tools policy

* fix(subagents): enforce skill allowed-tools

* fix(agent): enforce skill allowed-tools

* refactor(skills): dedupe TypeVar and reuse cached enabled skills

- Drop redundant module-level TypeVar in tool_policy; rely on PEP 695 syntax.
- Expose get_cached_enabled_skills() and have the lead agent reuse it
  instead of synchronously rescanning skills on every request.

* fix(agent): expose config-scoped skill cache

* fix(subagents): pass filtered tools explicitly

* fix(skills): clean allowed-tools policy feedback
2026-05-07 08:34:43 +08:00
Hinotobi 2b0e62f679 [security] fix(auth): reject cross-site auth POSTs (#2740)
* fix(security): reject cross-site auth posts

* fix(auth): align secure cookie proxy scheme handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-07 07:58:06 +08:00
Eilen Shin 1336872b15 fix(channels): authenticate gateway command requests (#2742) 2026-05-06 15:27:34 +08:00
KiteEater 4ead2c6b19 fix(config): reset config-backed singletons on hot reload (#2588)
* Fix stale config singletons on reload

* fix(config): update checkpointer imports after runtime move

* Fix config reload singleton mutation on validation failure

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-06 10:17:55 +08:00
yangzheli 59c4a3f0a4 feat(agent): add custom-agent self-updates with user isolation (#2713)
* feat(agent): add update_agent tool for in-chat custom-agent self-updates (#2616)

Custom agents had no built-in way to persist updates to their own SOUL.md /
config.yaml from a normal chat — `setup_agent` was only bound during the
bootstrap flow, so when the user asked the agent to refine its description
or personality, the agent would shell out via bash/write_file and the edits
landed in a temporary sandbox/tool workspace instead of
`{base_dir}/agents/{agent_name}/`.

Changes:
- New `update_agent` builtin tool with partial-update semantics (only the
  fields you pass are written) and atomic temp-file + os.replace writes so
  a failed update never corrupts existing SOUL.md / config.yaml.
- Lead agent now binds `update_agent` in the non-bootstrap path whenever
  `agent_name` is set in the runtime context. Default agent (no
  agent_name) and bootstrap flow are unchanged.
- New `<self_update>` system-prompt section is injected for custom agents,
  instructing them to use `update_agent` — and explicitly NOT bash /
  write_file — to persist self-updates.
- Tests: 11 new cases in `tests/test_update_agent_tool.py` covering
  validation (missing/invalid agent_name, unknown agent, no fields),
  partial updates (soul-only, description-only, skills=[] vs omitted),
  no-op detection, atomic-write safety, and AgentConfig round-tripping;
  plus 2 new cases in `tests/test_lead_agent_prompt.py` covering the
  self-update prompt section.
- Docs: updated backend/CLAUDE.md builtin tools list and tools.mdx
  (en/zh) with the new tool description.

* feat(agent): isolate custom agents per user

Store custom agent definitions under the effective user, keep legacy agents readable until migration, and cover API/tool/migration behavior with tests.

Co-authored-by: Cursor <cursoragent@cursor.com>

* feat: consistent write/delete targets & add --user-id to migration

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-05 23:17:42 +08:00
Nan Gao e8675f266d fix(loop-detection): keep tool-call pairing on warn injection (#2724) (#2725)
* fix(loop-detection): keep tool-call pairing on warn injection (#2724)

* make format

* fix(loop-detection): avoid IMMessage leak to downstream consumer

* fix(channels): filter loop warning text from IM replies
2026-05-05 18:53:49 +08:00
Xun 680187ddc2 fix: Supplement list_running in RemoteSandboxBackend (#2716)
* fix: Supplement list_running in RemoteSandboxBackend

* fix

* except requests.RequestException as exc:

* fix
2026-05-05 18:53:10 +08:00
Xinmin Zeng aded753de3 fix(frontend): restore localhost fallback for getGatewayConfig in prod mode (#2705) (#2718)
* fix(frontend): unify gateway-config localhost fallback for prod (#2705)

`getGatewayConfig()` only fell back to localhost defaults when
`NODE_ENV === "development"`, while `next.config.js` always falls back
to `127.0.0.1:8001`. Running `make start` (which sets NODE_ENV=production
via `next start`) without `DEER_FLOW_INTERNAL_GATEWAY_BASE_URL` /
`DEER_FLOW_TRUSTED_ORIGINS` therefore caused zod to throw inside SSR
layouts and surfaced as a 500.

Drop the NODE_ENV gating and use localhost defaults everywhere — the
"force explicit config in prod" intent should be enforced by deployment
templates (docker-compose already sets both vars), not by request-time
crashes. Document the two vars in both .env.example files and add unit
coverage for the dev/prod env-unset paths.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Update internalGatewayUrl in gateway config tests

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-05 16:27:29 +08:00
Willem Jiang 028493bfd8 fix(docker):force ngix to resolve upstream names at request time (#2717)
* fix(docker):force ngix to resolve upstream names at request time

* fix(docker): set resolver valid=0s to eliminate DNS cache window for request-time re-resolution

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/07bdb872-022f-4fd2-9fa8-d800a4ce34a7

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* Update DNS resolver valid time and add upstreams

* fix the unit test error

* Remove upstream server configurations from nginx.conf

Removed upstream server configurations for gateway and frontend.

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-05-05 14:35:55 +08:00
Willem Jiang 8e48b7e85c fix(channels): preserve clarification conversation history across follow-up turns (#2444)
* fix(channels): preserve clarification conversation history across follow-up turns

Pin channel-triggered runs to the root checkpoint namespace and ensure thread_id is always present in configurable run config so follow-up replies resume the same conversation state.

Add regression coverage to channel tests:

assert checkpoint_ns/thread_id are passed in wait and stream paths
add an integration-style clarification flow test that verifies the second user reply continues prior context instead of starting a new session
This addresses history loss after ask_clarification interruptions (issue #2425).

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix(channels): copy configurable dict before injecting run-scoped fields

  When configurable was already a plain dict, _resolve_run_params mutated
  it in place, leaking checkpoint_ns and thread_id back into the shared
  session config. Always copy via dict() before mutating to prevent
  cross-user or cross-channel config pollution.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-05-04 16:14:07 +08:00
Willem Jiang af6e48ccaa fix(i18n): add Chinese translations for account settings page (#2712)
The account settings page had all user-facing strings (profile labels,
  password form placeholders, validation messages, button text) hardcoded
  in English. Replace them with i18n translation keys so the page renders
  correctly when the locale is set to Chinese.

 Fixed #2710
2026-05-04 11:15:16 +08:00
Willem Jiang b10eb7bafc feat(github): Added container push workflow (#2709)
* feat(github):Added container push workflow

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-04 11:14:34 +08:00
YuJitang d02f762ab0 feat: refine token usage display modes (#2329)
* feat: refine token usage display modes

* docs: clarify token usage accounting semantics

* fix: avoid duplicate subtask debug keys

* style: format token usage tests

* chore: address token attribution review feedback

* Update test_token_usage_middleware.py

* Update test_token_usage_middleware.py

* chore: simplify token attribution fallback

* fix token usage metadata follow-up handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-04 09:56:16 +08:00
Willem Jiang 82e7936d36 fix(docker): set UTF-8 locale to prevent ASCII encoding errors in minimal containers (#2707)
* fix(docker): set UTF-8 locale to prevent ASCII encoding errors in minimal containers

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-04 09:41:10 +08:00
Nan Gao 222a7773cb fix(frontend): avoid misleading error message when agent api is disable (#2697) (#2698) 2026-05-04 09:38:05 +08:00
Nan Gao f80ac961ec fix(harness): restore legacy skills path fallback (#2694) (#2696)
* fix(harness): restore legacy skills path fallback (#2694)

* fix(format): make format

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-03 23:40:59 +08:00
wanxsb 44ab21fc44 feat(community): add Serper web search provider (#2630)
* feat(community): add Serper web search provider

Add a new community search provider backed by the Serper Google Search
API (https://serper.dev). Serper returns real-time Google results via a
simple JSON API and requires only an API key — no extra Python package.

Changes:
- backend/packages/harness/deerflow/community/serper/__init__.py
- backend/packages/harness/deerflow/community/serper/tools.py
  Implements web_search_tool using httpx (already a project dependency).
  API key is read from config.yaml `api_key` field or SERPER_API_KEY env var.
  Follows the same interface / output shape as the existing ddg_search provider.
  Exposes max_results parameter (default 5) with config override logic.
- backend/tests/test_serper_tools.py
  Unit tests covering API key resolution, config overrides, HTTP errors,
  empty results, and parameter passing.
- config.example.yaml: add commented-out Serper example alongside other providers
- .env.example: add SERPER_API_KEY placeholder

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Fix the lint error

* Fix the lint error

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-02 16:22:35 +08:00
Hinotobi e543bbf5d6 [security] fix(upload): reject symlinked upload destinations (#2623)
* fix: reject symlinked upload destinations

* test: harden upload destination checks

* fix: address PR feedback for #2623

* test: cover safe upload re-uploads

* fix: preserve upload limit checks after rebase

* fix(upload): stream safe HTTP upload writes
2026-05-02 15:19:28 +08:00
Xinmin Zeng ca3332f8bf fix(gateway): return ISO 8601 timestamps from threads endpoints (#2599)
* fix(gateway): return ISO 8601 timestamps from threads endpoints (#2594)

ThreadResponse documents created_at / updated_at as ISO timestamps,
matching the LangGraph Platform schema (langgraph_sdk.schema.Thread
exposes them as datetime, JSON-encoded as ISO 8601). The gateway
threads router was instead emitting str(time.time()) — unix-second
floats — breaking frontend new Date() parsing and producing a mixed
ISO/unix wire format that also corrupted the search sort order.

Centralize timestamp generation in deerflow.utils.time:
- now_iso()       — datetime.now(UTC).isoformat()
- coerce_iso(x)   — heals legacy unix-timestamp strings on read so the
                    store converges to ISO without a one-shot migration

threads.py: replace 6 time.time() call sites with now_iso(); wrap all
read paths and Phase-2 checkpoint metadata with coerce_iso(); _store_upsert
opportunistically heals legacy created_at on update; drop unused time import.

thread_runs.py: reuse now_iso() instead of a private duplicate _now_iso(),
preventing future drift between the two timestamp call sites.

Tests: 9 unit tests for the helper; 5 integration tests pinning the ISO
contract for create/get/patch/search and the legacy-healing path on the
internal store upsert. Full suite: 2144 passed, 15 skipped, 0 failed.

Closes #2594

* fix(gateway): coerce checkpoint metadata timestamps to ISO on read

After the merge with main, three additional read paths in ``threads.py``
were still emitting raw ``str(metadata.get("created_at", ""))`` —
``get_thread_state``, ``update_thread_state``, and ``get_thread_history``.

Same root cause as #2594: when the checkpoint metadata's ``created_at``
is a unix-second float (legacy data, or a checkpoint written by an older
Gateway version), ``str(float)`` produces ``"1777252410.411327"`` and the
frontend's ``new Date(...)`` returns ``Invalid Date``. The fix on the
``/threads/{id}`` GET path was already in place; these three sibling
endpoints needed the same treatment.

All four call sites now flow through ``coerce_iso``, so:
- legacy float metadata heals to ISO on the way out,
- ISO metadata passes through unchanged,
- ``datetime`` instances (which the new ``coerce_iso`` branch handles
  explicitly) emit with the ``T`` separator instead of falling through
  to the space-separated ``str(datetime)`` form.

Coverage added for the two endpoints not already pinned by the merge:
- ``test_get_thread_state_returns_iso_for_legacy_checkpoint_metadata``
- ``test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata``

Both pre-seed a checkpoint whose metadata carries the literal float
from the issue body and assert the wire format is ISO.
2026-05-02 15:16:16 +08:00
Willem Jiang bb8b234d85 chroe(2585): keep polishing the code of codex token usage (#2689) 2026-05-02 15:04:11 +08:00
KiteEater 17447fccbe fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582)
* Restore rollback checkpoints with fresh ids

* Tighten rollback checkpoint tests and imports

* Update test_run_worker_rollback.py

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-02 11:25:45 +08:00
KiteEater 866d1ca409 Populate Codex usage metadata for token accounting (#2585) 2026-05-02 11:16:03 +08:00
164 changed files with 10772 additions and 996 deletions
+14
View File
@@ -1,3 +1,6 @@
# Serper API Key (Google Search) - https://serper.dev
SERPER_API_KEY=your-serper-api-key
# TAVILY API Key
TAVILY_API_KEY=your-tavily-api-key
@@ -45,3 +48,14 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
# GATEWAY_ENABLE_DOCS=false
# ── Frontend SSR → Gateway wiring ─────────────────────────────────────────────
# The Next.js server uses these to reach the Gateway during SSR (auth checks,
# /api/* rewrites). They default to localhost values that match `make dev` and
# `make start`, so most local users do not need to set them.
#
# Override only when the Gateway is not on localhost:8001 (e.g. when the
# frontend and gateway run on different hosts, in containers with a service
# alias, or behind a different port). docker-compose already sets these.
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
+101
View File
@@ -0,0 +1,101 @@
name: Publish Containers
on:
push:
tags:
- "v*"
jobs:
backend-container:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
attestations: write
id-token: write
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-backend
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Log in to the Container registry
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=tag
type=ref,event=branch
type=sha
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push Docker image
id: push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
with:
context: .
file: backend/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
frontend-container:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
attestations: write
id-token: write
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-frontend
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Log in to the Container registry
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 #v3.4.0
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 #v5.7.0
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=tag
type=ref,event=branch
type=sha
type=raw,value=latest,enable={{is_default_branch}}
- name: Build and push Docker image
id: push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 #v6.18.0
with:
context: .
file: frontend/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true
+6 -2
View File
@@ -263,8 +263,10 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
- `view_image` - Read image as base64 (added only if model supports vision)
- `setup_agent` - Bootstrap-only: persist a brand-new custom agent's `SOUL.md` and `config.yaml`. Bound only when `is_bootstrap=True`.
- `update_agent` - Custom-agent-only: persist self-updates to the current agent's `SOUL.md` / `config.yaml` from inside a normal chat (partial update + atomic write). Bound when `agent_name` is set and `is_bootstrap=False`.
4. **Subagent tool** (if enabled):
- `task` - Delegate to subagent (description, prompt, subagent_type, max_turns)
- `task` - Delegate to subagent (description, prompt, subagent_type)
**Community tools** (`packages/harness/deerflow/community/`):
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
@@ -354,10 +356,11 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
**Per-User Isolation**:
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
- Custom agent definitions (`SOUL.md` + `config.yaml`) are also per-user at `{base_dir}/users/{user_id}/agents/{agent_name}/`. The legacy shared layout `{base_dir}/agents/{agent_name}/` remains read-only fallback for unmigrated installations
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
- Absolute `storage_path` in config opts out of per-user isolation
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json`, `threads/`, and `agents/` into per-user layout. Supports `--dry-run` (preview changes) and `--user-id USER_ID` (assign unowned legacy data to a user, defaults to `default`).
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
@@ -517,6 +520,7 @@ Multi-file upload with automatic document conversion:
- Rejects directory inputs before copying so uploads stay all-or-nothing
- Reuses one conversion worker per request when called from an active event loop
- Files stored in thread-isolated directories
- Duplicate filenames in a single upload request are auto-renamed with `_N` suffixes so later files do not truncate earlier files
- Agent receives uploaded file list via `UploadsMiddleware`
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
+10
View File
@@ -50,6 +50,12 @@ COPY backend ./backend
RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
# UTF-8 locale prevents UnicodeEncodeError on Chinese/emoji content in minimal
# containers where locale configuration may be missing and the default encoding is not UTF-8.
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
ENV PYTHONIOENCODING=utf-8
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build
# source distributions in development containers.
@@ -66,6 +72,10 @@ CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
FROM python:3.12-slim-bookworm
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
ENV PYTHONIOENCODING=utf-8
# Copy Node.js runtime from builder (provides npx for MCP servers)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+1 -1
View File
@@ -124,7 +124,7 @@ FastAPI application providing REST endpoints for frontend integration:
| `POST /api/memory/reload` | Force memory reload |
| `GET /api/memory/config` | Memory configuration |
| `GET /api/memory/status` | Combined config + data |
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths) |
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths, auto-renames duplicate filenames in one request) |
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
+42 -4
View File
@@ -146,6 +146,13 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
return normalized
def _strip_loop_warning_text(text: str) -> str:
"""Remove middleware-authored loop warning lines from display text."""
if "[LOOP DETECTED]" not in text:
return text
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
def _extract_response_text(result: dict | list) -> str:
"""Extract the last AI message text from a LangGraph runs.wait result.
@@ -155,7 +162,7 @@ def _extract_response_text(result: dict | list) -> str:
Handles special cases:
- Regular AI text responses
- Clarification interrupts (``ask_clarification`` tool messages)
- AI messages with tool_calls but no text content
- Strips loop-detection warnings attached to tool-call AI messages
"""
if isinstance(result, list):
messages = result
@@ -185,7 +192,12 @@ def _extract_response_text(result: dict | list) -> str:
# Regular AI message with text content
if msg_type == "ai":
content = msg.get("content", "")
has_tool_calls = bool(msg.get("tool_calls"))
if isinstance(content, str) and content:
if has_tool_calls:
content = _strip_loop_warning_text(content)
if not content:
continue
return content
# content can be a list of content blocks
if isinstance(content, list):
@@ -196,6 +208,8 @@ def _extract_response_text(result: dict | list) -> str:
elif isinstance(block, str):
parts.append(block)
text = "".join(parts)
if has_tool_calls:
text = _strip_loop_warning_text(text)
if text:
return text
return ""
@@ -420,7 +434,13 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
if not msg.files:
return []
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
from deerflow.uploads.manager import (
UnsafeUploadPathError,
claim_unique_filename,
ensure_uploads_dir,
normalize_filename,
write_upload_file_no_symlink,
)
uploads_dir = ensure_uploads_dir(thread_id)
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
@@ -471,7 +491,10 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
dest = uploads_dir / safe_name
try:
dest.write_bytes(data)
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
except UnsafeUploadPathError:
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
continue
except Exception:
logger.exception("[Manager] failed to write inbound file: %s", dest)
continue
@@ -580,6 +603,17 @@ class ChannelManager:
user_layer.get("config"),
)
configurable = run_config.get("configurable")
if isinstance(configurable, Mapping):
configurable = dict(configurable)
else:
configurable = {}
run_config["configurable"] = configurable
# Pin channel-triggered runs to the root graph namespace so follow-up
# turns continue from the same conversation checkpoint.
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id
run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT,
self._default_session.get("context"),
@@ -963,7 +997,11 @@ class ChannelManager:
try:
async with httpx.AsyncClient() as http:
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
resp = await http.get(
f"{self._gateway_url}{path}",
timeout=10,
headers=create_internal_auth_headers(),
)
resp.raise_for_status()
data = resp.json()
except Exception:
+112 -1
View File
@@ -4,8 +4,10 @@ Per RFC-001:
State-changing operations require CSRF protection.
"""
import os
import secrets
from collections.abc import Callable
from urllib.parse import urlsplit
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
@@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
return _request_scheme(request) == "https"
def generate_csrf_token() -> str:
@@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
"""Return normalized host[:port], omitting default ports."""
host = hostname.lower()
if ":" in host and not host.startswith("["):
host = f"[{host}]"
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
return host
return f"{host}:{port}"
def _normalize_origin(origin: str) -> str | None:
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
try:
parsed = urlsplit(origin.strip())
port = parsed.port
except ValueError:
return None
scheme = parsed.scheme.lower()
if scheme not in {"http", "https"} or not parsed.hostname:
return None
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
return None
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
def _configured_cors_origins() -> set[str]:
"""Return explicit configured browser origins that may call auth routes."""
origins = set()
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
origin = raw_origin.strip()
if not origin or origin == "*":
continue
normalized = _normalize_origin(origin)
if normalized:
origins.add(normalized)
return origins
def _first_header_value(value: str | None) -> str | None:
"""Return the first value from a comma-separated proxy header."""
if not value:
return None
first = value.split(",", 1)[0].strip()
return first or None
def _forwarded_param(request: Request, name: str) -> str | None:
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
forwarded = _first_header_value(request.headers.get("forwarded"))
if not forwarded:
return None
for part in forwarded.split(";"):
key, sep, value = part.strip().partition("=")
if sep and key.lower() == name:
return value.strip().strip('"') or None
return None
def _request_scheme(request: Request) -> str:
"""Resolve the original request scheme from trusted proxy headers."""
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
return scheme.lower()
def _request_origin(request: Request) -> str | None:
"""Build the origin for the URL the browser is targeting."""
scheme = _request_scheme(request)
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
host = f"{host}:{forwarded_port}"
return _normalize_origin(f"{scheme}://{host}")
def is_allowed_auth_origin(request: Request) -> bool:
"""Allow auth POSTs only from the same origin or explicit configured origins.
Login/register/initialize are exempt from the double-submit token because
first-time browser clients do not have a CSRF token yet. They still create
a session cookie, so browser requests with a hostile Origin header must be
rejected to prevent login CSRF / session fixation. Requests without Origin
are allowed for non-browser clients such as curl and mobile integrations.
"""
origin = request.headers.get("origin")
if not origin:
return True
normalized_origin = _normalize_origin(origin)
if normalized_origin is None:
return False
request_origin = _request_origin(request)
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
@@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
return JSONResponse(
status_code=403,
content={"detail": "Cross-site auth request denied."},
)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
+43 -18
View File
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["agents"])
@@ -86,11 +87,11 @@ def _require_agents_api_enabled() -> None:
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False, *, user_id: str | None = None) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
if include_soul:
soul = load_agent_soul(agent_cfg.name) or ""
soul = load_agent_soul(agent_cfg.name, user_id=user_id) or ""
return AgentResponse(
name=agent_cfg.name,
@@ -116,9 +117,10 @@ async def list_agents() -> AgentsListResponse:
"""
_require_agents_api_enabled()
user_id = get_effective_user_id()
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
agents = list_custom_agents(user_id=user_id)
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True, user_id=user_id) for a in agents])
except Exception as e:
logger.error(f"Failed to list agents: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
@@ -144,7 +146,12 @@ async def check_agent_name(name: str) -> dict:
_require_agents_api_enabled()
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
user_id = get_effective_user_id()
paths = get_paths()
# Treat the name as taken if either the per-user path or the legacy shared
# path holds an agent — picking a name that collides with an unmigrated
# legacy agent would shadow the legacy entry once migration runs.
available = not paths.user_agent_dir(user_id, normalized).exists() and not paths.agent_dir(normalized).exists()
return {"available": available, "name": normalized}
@@ -169,10 +176,11 @@ async def get_agent(name: str) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
try:
agent_cfg = load_agent_config(name)
return _agent_config_to_response(agent_cfg, include_soul=True)
agent_cfg = load_agent_config(name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
except Exception as e:
@@ -202,10 +210,13 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = get_paths().agent_dir(normalized_name)
agent_dir = paths.user_agent_dir(user_id, normalized_name)
legacy_dir = paths.agent_dir(normalized_name)
if agent_dir.exists():
if agent_dir.exists() or legacy_dir.exists():
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
try:
@@ -232,8 +243,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name)
return _agent_config_to_response(agent_cfg, include_soul=True)
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except HTTPException:
raise
@@ -267,13 +278,20 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
try:
agent_cfg = load_agent_config(name)
agent_cfg = load_agent_config(name, user_id=user_id)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
agent_dir = get_paths().agent_dir(name)
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists() and paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating."),
)
try:
# Update config if any config fields changed
@@ -314,8 +332,8 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
logger.info(f"Updated agent '{name}'")
refreshed_cfg = load_agent_config(name)
return _agent_config_to_response(refreshed_cfg, include_soul=True)
refreshed_cfg = load_agent_config(name, user_id=user_id)
return _agent_config_to_response(refreshed_cfg, include_soul=True, user_id=user_id)
except HTTPException:
raise
@@ -402,15 +420,22 @@ async def delete_agent(name: str) -> None:
name: The agent name.
Raises:
HTTPException: 404 if agent not found.
HTTPException: 404 if no per-user copy exists; 409 if only a legacy
shared copy exists (suggesting the migration script).
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
agent_dir = get_paths().agent_dir(name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists():
if paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
)
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
try:
+24 -3
View File
@@ -68,6 +68,27 @@ class RunResponse(BaseModel):
updated_at: str = ""
class ThreadTokenUsageModelBreakdown(BaseModel):
tokens: int = 0
runs: int = 0
class ThreadTokenUsageCallerBreakdown(BaseModel):
lead_agent: int = 0
subagent: int = 0
middleware: int = 0
class ThreadTokenUsageResponse(BaseModel):
thread_id: str
total_tokens: int = 0
total_input_tokens: int = 0
total_output_tokens: int = 0
total_runs: int = 0
by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -368,10 +389,10 @@ async def list_run_events(
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return {"thread_id": thread_id, **agg}
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
+23 -21
View File
@@ -13,11 +13,11 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from langgraph.checkpoint.base import empty_checkpoint
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
@@ -26,6 +26,7 @@ from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.time import coerce_iso, now_iso
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
@@ -233,7 +234,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
now = now_iso()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
@@ -243,8 +244,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
created_at=coerce_iso(existing_record.get("created_at", "")),
updated_at=coerce_iso(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
@@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = {
"step": -1,
"source": "input",
@@ -281,8 +280,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=str(now),
updated_at=str(now),
created_at=now,
updated_at=now,
metadata=body.metadata,
)
@@ -307,8 +306,11 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
ThreadResponse(
thread_id=r["thread_id"],
status=r.get("status", "idle"),
created_at=r.get("created_at", ""),
updated_at=r.get("updated_at", ""),
# ``coerce_iso`` heals legacy unix-second values that
# ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
# SQL-backed rows already arrive as ISO strings and pass through.
created_at=coerce_iso(r.get("created_at", "")),
updated_at=coerce_iso(r.get("updated_at", "")),
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {},
interrupts={},
@@ -340,8 +342,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
return ThreadResponse(
thread_id=thread_id,
status=record.get("status", "idle"),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
created_at=coerce_iso(record.get("created_at", "")),
updated_at=coerce_iso(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
)
@@ -381,8 +383,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
record = {
"thread_id": thread_id,
"status": "idle",
"created_at": ckpt_meta.get("created_at", ""),
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
"created_at": coerce_iso(ckpt_meta.get("created_at", "")),
"updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
@@ -396,8 +398,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
created_at=coerce_iso(record.get("created_at", "")),
updated_at=coerce_iso(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
@@ -448,10 +450,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
values=values,
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
created_at=coerce_iso(metadata.get("created_at", "")),
tasks=tasks,
)
@@ -501,7 +503,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = time.time()
metadata["updated_at"] = now_iso()
if body.as_node:
metadata["source"] = "update"
@@ -542,7 +544,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
created_at=coerce_iso(metadata.get("created_at", "")),
)
@@ -609,7 +611,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
parent_checkpoint_id=parent_id,
metadata=user_meta,
values=values,
created_at=str(metadata.get("created_at", "")),
created_at=coerce_iso(metadata.get("created_at", "")),
next=next_tasks,
)
)
+44 -14
View File
@@ -5,7 +5,7 @@ import os
import stat
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
@@ -15,12 +15,15 @@ from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
UnsafeUploadPathError,
claim_unique_filename,
delete_file_safe,
enrich_file_listing,
ensure_uploads_dir,
get_uploads_dir,
list_files_in_dir,
normalize_filename,
open_upload_file_no_symlink,
upload_artifact_url,
upload_virtual_path,
)
@@ -42,6 +45,7 @@ class UploadResponse(BaseModel):
success: bool
files: list[dict[str, str]]
message: str
skipped_files: list[str] = Field(default_factory=list)
class UploadLimits(BaseModel):
@@ -116,17 +120,18 @@ def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
async def _write_upload_file_streaming(
async def _write_upload_file_with_limits(
file: UploadFile,
file_path: os.PathLike[str] | str,
*,
uploads_dir: os.PathLike[str] | str,
display_filename: str,
max_single_file_size: int,
max_total_size: int,
total_size: int,
) -> tuple[int, int]:
) -> tuple[os.PathLike[str] | str, int, int]:
file_size = 0
with open(file_path, "wb") as output:
file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
try:
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
file_size += len(chunk)
total_size += len(chunk)
@@ -134,8 +139,17 @@ async def _write_upload_file_streaming(
raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
if total_size > max_total_size:
raise HTTPException(status_code=413, detail="Total upload size too large")
output.write(chunk)
return file_size, total_size
fh.write(chunk)
except Exception:
fh.close()
try:
os.unlink(file_path)
except FileNotFoundError:
pass
raise
else:
fh.close()
return file_path, file_size, total_size
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
@@ -177,7 +191,12 @@ async def upload_files(
uploaded_files = []
written_paths = []
sandbox_sync_targets = []
skipped_files = []
total_size = 0
# Track filenames within this request so duplicate form parts do not
# silently truncate each other. Existing uploads keep the historical
# overwrite behavior for a single replacement upload.
seen_filenames: set[str] = set()
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
@@ -194,22 +213,22 @@ async def upload_files(
continue
try:
safe_filename = normalize_filename(file.filename)
original_filename = normalize_filename(file.filename)
safe_filename = claim_unique_filename(original_filename, seen_filenames)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
try:
file_path = uploads_dir / safe_filename
written_paths.append(file_path)
file_size, total_size = await _write_upload_file_streaming(
file_path, file_size, total_size = await _write_upload_file_with_limits(
file,
file_path,
uploads_dir=uploads_dir,
display_filename=safe_filename,
max_single_file_size=limits.max_file_size,
max_total_size=limits.max_total_size,
total_size=total_size,
)
written_paths.append(file_path)
virtual_path = upload_virtual_path(safe_filename)
@@ -223,6 +242,8 @@ async def upload_files(
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
if safe_filename != original_filename:
file_info["original_filename"] = original_filename
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
@@ -246,6 +267,10 @@ async def upload_files(
except HTTPException as e:
_cleanup_uploaded_paths(written_paths)
raise e
except UnsafeUploadPathError as e:
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
skipped_files.append(safe_filename)
continue
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
_cleanup_uploaded_paths(written_paths)
@@ -256,10 +281,15 @@ async def upload_files(
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
message = f"Successfully uploaded {len(uploaded_files)} file(s)"
if skipped_files:
message += f"; skipped {len(skipped_files)} unsafe file(s)"
return UploadResponse(
success=True,
success=not skipped_files,
files=uploaded_files,
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
message=message,
skipped_files=skipped_files,
)
+19
View File
@@ -136,6 +136,24 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
runtime_context.setdefault(key, context[key])
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
"""Stamp the authenticated user into the run context for background tools.
Tool execution may happen after the request handler has returned, so tools
that persist user-scoped files should not rely only on ambient ContextVars.
The value comes from server-side auth state, never from client context.
"""
user = getattr(request.state, "user", None)
user_id = getattr(user, "id", None)
if user_id is None:
return
runtime_context = config.setdefault("context", {})
if isinstance(runtime_context, dict):
runtime_context["user_id"] = str(user_id)
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
@@ -288,6 +306,7 @@ async def start_run(
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
stream_modes = normalize_stream_modes(body.stream_mode)
+20
View File
@@ -79,7 +79,9 @@ async def main():
from langgraph.runtime import Runtime
from deerflow.agents import make_lead_agent
from deerflow.config.paths import get_paths
from deerflow.mcp import initialize_mcp_tools
from deerflow.runtime.user_context import get_effective_user_id
# Initialize MCP tools at startup
try:
@@ -113,6 +115,8 @@ async def main():
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
print("=" * 50)
seen_artifacts: set[str] = set()
while True:
try:
if session:
@@ -134,6 +138,22 @@ async def main():
last_message = result["messages"][-1]
print(f"\nAgent: {last_message.content}")
# Show files presented to the user this turn (new artifacts only)
artifacts = result.get("artifacts") or []
new_artifacts = [p for p in artifacts if p not in seen_artifacts]
if new_artifacts:
thread_id = config["configurable"]["thread_id"]
user_id = get_effective_user_id()
paths = get_paths()
print("\n[Presented files]")
for virtual in new_artifacts:
try:
physical = paths.resolve_virtual_path(thread_id, virtual, user_id=user_id)
print(f" - {virtual}\n{physical}")
except ValueError as exc:
print(f" - {virtual} (failed to resolve physical path: {exc})")
seen_artifacts.update(new_artifacts)
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
@@ -173,7 +173,7 @@ def _assemble_from_features(
9. MemoryMiddleware (memory feature)
10. ViewImageMiddleware (vision feature)
11. SubagentLimitMiddleware (subagent feature)
12. LoopDetectionMiddleware (always)
12. LoopDetectionMiddleware (loop_detection feature)
13. ClarificationMiddleware (always last)
Two-phase ordering:
@@ -272,10 +272,15 @@ def _assemble_from_features(
extra_tools.append(task_tool)
# --- [12] LoopDetection (always) ---
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
# --- [12] LoopDetection ---
if feat.loop_detection is not False:
if isinstance(feat.loop_detection, AgentMiddleware):
chain.append(feat.loop_detection)
else:
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.loop_detection_config import LoopDetectionConfig
chain.append(LoopDetectionMiddleware())
chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
# --- [13] Clarification (always last among built-ins) ---
chain.append(ClarificationMiddleware())
@@ -31,6 +31,7 @@ class RuntimeFeatures:
vision: bool | AgentMiddleware = False
auto_title: bool | AgentMiddleware = False
guardrail: Literal[False] | AgentMiddleware = False
loop_detection: bool | AgentMiddleware = True
# ---------------------------------------------------------------------------
@@ -20,6 +20,8 @@ from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.models import create_chat_model
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
logger = logging.getLogger(__name__)
@@ -256,6 +258,12 @@ def _build_middlewares(
resolved_app_config = app_config or get_app_config()
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
# Always inject current date (and optionally memory) as <system-reminder> into the
# first HumanMessage to keep the system prompt fully static for prefix-cache reuse.
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None:
@@ -297,7 +305,9 @@ def _build_middlewares(
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
middlewares.append(LoopDetectionMiddleware())
loop_detection_config = resolved_app_config.loop_detection
if loop_detection_config.enabled:
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
# Inject custom middlewares before ClarificationMiddleware
if custom_middlewares:
@@ -308,6 +318,28 @@ def _build_middlewares(
return middlewares
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
if is_bootstrap:
return {"bootstrap"}
if agent_config and agent_config.skills is not None:
return set(agent_config.skills)
return None
def _load_enabled_skills_for_tool_policy(available_skills: set[str] | None, *, app_config: AppConfig) -> list[Skill]:
try:
from deerflow.agents.lead_agent.prompt import get_enabled_skills_for_config
skills = get_enabled_skills_for_config(app_config)
except Exception:
logger.exception("Failed to load skills for allowed-tools policy")
raise
if available_skills is None:
return skills
return [skill for skill in skills if skill.name in available_skills]
def make_lead_agent(config: RunnableConfig):
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
runtime_config = _get_runtime_config(config)
@@ -318,7 +350,7 @@ def make_lead_agent(config: RunnableConfig):
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
from deerflow.tools.builtins import setup_agent, update_agent
cfg = _get_runtime_config(config)
resolved_app_config = app_config
@@ -333,6 +365,7 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
agent_name = validate_agent_name(cfg.get("agent_name"))
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
available_skills = _available_skill_names(agent_config, is_bootstrap)
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
agent_model_name = agent_config.model if agent_config and agent_config.model else None
@@ -371,15 +404,18 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
"available_skills": sorted(available_skills) if available_skills is not None else None,
}
)
skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config)
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
@@ -390,15 +426,14 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
state_schema=ThreadState,
)
# Custom agents can update their own SOUL.md / config via update_agent.
# The default agent (no agent_name) does not see this tool.
extra_tools = [update_agent] if agent_name else []
# Default lead agent (unchanged behavior)
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
tools=get_available_tools(
model_name=model_name,
groups=agent_config.tool_groups if agent_config else None,
subagent_enabled=subagent_enabled,
app_config=resolved_app_config,
),
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING
@@ -20,6 +19,7 @@ logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
_enabled_skills_lock = threading.Lock()
_enabled_skills_cache: list[Skill] | None = None
_enabled_skills_by_config_cache: dict[int, tuple[object, list[Skill]]] = {}
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event()
@@ -84,6 +84,7 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_by_config_cache.clear()
_enabled_skills_refresh_version += 1
_enabled_skills_refresh_event.clear()
if _enabled_skills_refresh_active:
@@ -107,6 +108,15 @@ def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_W
def _get_enabled_skills():
return get_cached_enabled_skills()
def get_cached_enabled_skills() -> list[Skill]:
"""Return the cached enabled-skills list, kicking off a background refresh on miss.
Safe to call from request paths: never blocks on disk I/O. Returns an empty
list on cache miss; the next call will see the warmed result.
"""
with _enabled_skills_lock:
cached = _enabled_skills_cache
@@ -117,17 +127,29 @@ def _get_enabled_skills():
return []
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
def get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
"""Return enabled skills using the caller's config source.
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
cache so the skill list and skill paths are resolved from the same config
object. This keeps request-scoped config injection consistent even while the
release branch still supports global fallback paths.
When a concrete ``app_config`` is supplied, cache the loaded skills by that
config object's identity so request-scoped config injection still resolves
skill paths from the matching config without rescanning storage on every
agent factory call.
"""
if app_config is None:
return _get_enabled_skills()
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
cache_key = id(app_config)
with _enabled_skills_lock:
cached = _enabled_skills_by_config_cache.get(cache_key)
if cached is not None:
cached_config, cached_skills = cached
if cached_config is app_config:
return list(cached_skills)
skills = list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
with _enabled_skills_lock:
_enabled_skills_by_config_cache[cache_key] = (app_config, skills)
return list(skills)
def _skill_mutability_label(category: SkillCategory | str) -> str:
@@ -344,8 +366,7 @@ You are {agent_name}, an open-source super agent.
</role>
{soul}
{memory_context}
{self_update_section}
<thinking_style>
- Think concisely and strategically about the user's request BEFORE taking action
- Break down the task: What is clear? What is ambiguous? What is missing?
@@ -604,7 +625,7 @@ You have access to skills that provide optimized workflows for specific tasks. E
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills_for_config(app_config)
skills = get_enabled_skills_for_config(app_config)
if app_config is None:
try:
@@ -643,6 +664,26 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def _build_self_update_section(agent_name: str | None) -> str:
"""Prompt block that teaches the custom agent to persist self-updates via update_agent."""
if not agent_name:
return ""
return f"""<self_update>
You are running as the custom agent **{agent_name}** with a persisted SOUL.md and config.yaml.
When the user asks you to update your own description, personality, behaviour, skill set, tool groups, or default model,
you MUST persist the change with the `update_agent` tool. Do NOT use `bash`, `write_file`, or any sandbox tool to edit
SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace and the changes will be lost on the next turn.
Rules:
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
- Only pass the fields that should change. Omit the others to preserve them.
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
</self_update>
"""
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
@@ -732,9 +773,6 @@ def apply_prompt_template(
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name, app_config=app_config)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
@@ -768,17 +806,18 @@ def apply_prompt_template(
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory
prompt = SYSTEM_PROMPT_TEMPLATE.format(
# Build and return the fully static system prompt.
# Memory and current date are injected per-turn via DynamicContextMiddleware
# as a <system-reminder> in the first HumanMessage, keeping this prompt
# identical across users and sessions for maximum prefix-cache reuse.
return SYSTEM_PROMPT_TEMPLATE.format(
agent_name=agent_name or "DeerFlow 2.0",
soul=get_agent_soul(agent_name),
self_update_section=_build_self_update_section(agent_name),
skills_section=skills_section,
deferred_tools_section=deferred_tools_section,
memory_context=memory_context,
subagent_section=subagent_section,
subagent_reminder=subagent_reminder,
subagent_thinking=subagent_thinking,
acp_section=acp_and_mounts_section,
)
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
@@ -0,0 +1,204 @@
"""Middleware to inject dynamic context (memory, current date) as a system-reminder.
The system prompt is kept fully static for maximum prefix-cache reuse across users
and sessions. The current date is always injected. Per-user memory is also injected
when ``memory.injection_enabled`` is True in the app config. Both are delivered once
per conversation as a dedicated <system-reminder> HumanMessage inserted before the
first user message (frozen-snapshot pattern).
When a conversation spans midnight the middleware detects the date change and injects
a lightweight date-update reminder as a separate HumanMessage before the current turn.
This correction is persisted so subsequent turns on the new day see a consistent history
and do not re-inject.
Reminder format:
<system-reminder>
<memory>...</memory>
<current_date>2026-05-08, Friday</current_date>
</system-reminder>
Date-update format:
<system-reminder>
<current_date>2026-05-09, Saturday</current_date>
</system-reminder>
"""
from __future__ import annotations
import logging
import re
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, override
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
_SUMMARY_MESSAGE_NAME = "summary"
def _extract_date(content: str) -> str | None:
"""Return the first <current_date> value found in *content*, or None."""
m = _DATE_RE.search(content)
return m.group(1) if m else None
def is_dynamic_context_reminder(message: object) -> bool:
"""Return whether *message* is a hidden dynamic-context reminder."""
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY))
def _last_injected_date(messages: list) -> str | None:
"""Scan messages in reverse and return the most recently injected date.
Detection uses the ``dynamic_context_reminder`` additional_kwargs flag rather
than content substring matching, so user messages containing ``<system-reminder>``
are not mistakenly treated as injected reminders.
"""
for msg in reversed(messages):
if is_dynamic_context_reminder(msg):
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
return _extract_date(content_str)
return None
def _is_user_injection_target(message: object) -> bool:
"""Return whether *message* can receive a dynamic-context reminder."""
return isinstance(message, HumanMessage) and not is_dynamic_context_reminder(message) and message.name != _SUMMARY_MESSAGE_NAME
class DynamicContextMiddleware(AgentMiddleware):
"""Inject memory and current date into HumanMessages as a <system-reminder>.
First turn
----------
Prepends a full system-reminder (memory + date) to the first HumanMessage and
persists it (same message ID). The first message is then frozen for the whole
session — its content never changes again, so the prefix cache can hit on every
subsequent turn.
Midnight crossing
-----------------
If the conversation spans midnight, the current date differs from the date that
was injected earlier. In that case a lightweight date-update reminder is prepended
to the **current** (last) HumanMessage and persisted. Subsequent turns on the new
day see the corrected date in history and skip re-injection.
"""
def __init__(self, agent_name: str | None = None, *, app_config: AppConfig | None = None):
super().__init__()
self._agent_name = agent_name
self._app_config = app_config
def _build_full_reminder(self) -> str:
from deerflow.agents.lead_agent.prompt import _get_memory_context
# Memory injection is gated by injection_enabled; date is always included.
injection_enabled = self._app_config.memory.injection_enabled if self._app_config else True
memory_context = _get_memory_context(self._agent_name, app_config=self._app_config) if injection_enabled else ""
current_date = datetime.now().strftime("%Y-%m-%d, %A")
lines: list[str] = ["<system-reminder>"]
if memory_context:
lines.append(memory_context.strip())
lines.append("") # blank line separating memory from date
lines.append(f"<current_date>{current_date}</current_date>")
lines.append("</system-reminder>")
return "\n".join(lines)
def _build_date_update_reminder(self) -> str:
current_date = datetime.now().strftime("%Y-%m-%d, %A")
return "\n".join(
[
"<system-reminder>",
f"<current_date>{current_date}</current_date>",
"</system-reminder>",
]
)
@staticmethod
def _make_reminder_and_user_messages(original: HumanMessage, reminder_content: str) -> tuple[HumanMessage, HumanMessage]:
"""Return (reminder_msg, user_msg) using the ID-swap technique.
reminder_msg takes the original message's ID so that add_messages replaces it
in-place (preserving position). user_msg carries the original content with a
derived ``{id}__user`` ID and is appended immediately after by add_messages.
If the original message has no ID a stable UUID is generated so the derived
``{id}__user`` ID never collapses to the ambiguous ``None__user`` string.
"""
stable_id = original.id or str(uuid.uuid4())
reminder_msg = HumanMessage(
content=reminder_content,
id=stable_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
user_msg = HumanMessage(
content=original.content,
id=f"{stable_id}__user",
name=original.name,
additional_kwargs=original.additional_kwargs,
)
return reminder_msg, user_msg
def _inject(self, state) -> dict | None:
messages = list(state.get("messages", []))
if not messages:
return None
current_date = datetime.now().strftime("%Y-%m-%d, %A")
last_date = _last_injected_date(messages)
logger.debug(
"DynamicContextMiddleware._inject: msg_count=%d last_date=%r current_date=%r",
len(messages),
last_date,
current_date,
)
if last_date is None:
# ── First turn: inject full reminder as a separate HumanMessage ─────
first_idx = next((i for i, m in enumerate(messages) if _is_user_injection_target(m)), None)
if first_idx is None:
return None
full_reminder = self._build_full_reminder()
logger.info(
"DynamicContextMiddleware: injecting full reminder (len=%d, has_memory=%s) into first HumanMessage id=%r",
len(full_reminder),
"<memory>" in full_reminder,
messages[first_idx].id,
)
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[first_idx], full_reminder)
return {"messages": [reminder_msg, user_msg]}
if last_date == current_date:
# ── Same day: nothing to do ──────────────────────────────────────────
return None
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
last_human_idx = next((i for i in reversed(range(len(messages))) if _is_user_injection_target(messages[i])), None)
if last_human_idx is None:
return None
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[last_human_idx], self._build_date_update_reminder())
logger.info("DynamicContextMiddleware: midnight crossing detected — injected date update before current turn")
return {"messages": [reminder_msg, user_msg]}
@override
def before_agent(self, state, runtime: Runtime) -> dict | None:
return self._inject(state)
@override
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
return self._inject(state)
@@ -12,19 +12,23 @@ Detection strategy:
response so the agent is forced to produce a final text answer.
"""
from __future__ import annotations
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import override
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
if TYPE_CHECKING:
from deerflow.config.loop_detection_config import LoopDetectionConfig
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
@@ -140,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
construct via :meth:`from_config` to ensure values pass Pydantic validation.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
@@ -155,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
tool_freq_overrides: Per-tool overrides for frequency thresholds,
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
that specific tool. Tools not listed here fall back to the global
thresholds. Useful for raising limits on intentionally
high-frequency tools (e.g. ``bash`` in batch pipelines) without
weakening protection on all other tools. Default: ``None``
(no overrides).
"""
def __init__(
@@ -165,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
):
super().__init__()
self.warn_threshold = warn_threshold
@@ -173,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
@classmethod
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
"""Construct from a Pydantic-validated config, trusting its validation."""
return cls(
warn_threshold=config.warn_threshold,
hard_limit=config.hard_limit,
window_size=config.window_size,
max_tracked_threads=config.max_tracked_threads,
tool_freq_warn=config.tool_freq_warn,
tool_freq_hard_limit=config.tool_freq_hard_limit,
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
@@ -280,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
freq[name] += 1
tc_count = freq[name]
if tc_count >= self.tool_freq_hard_limit:
if name in self._tool_freq_overrides:
eff_warn, eff_hard = self._tool_freq_overrides[name]
else:
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
if tc_count >= eff_hard:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
@@ -291,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= self.tool_freq_warn:
if tc_count >= eff_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)
@@ -356,13 +389,30 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return {"messages": [stripped_msg]}
if warning:
# Inject as HumanMessage instead of SystemMessage to avoid
# Anthropic's "multiple non-consecutive system messages" error.
# Anthropic models require system messages only at the start of
# the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299.
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
# WORKAROUND for v2.0-m1 — see #2724.
#
# Append the warning to the AIMessage content instead of
# injecting a separate HumanMessage. Inserting any non-tool
# message between an AIMessage(tool_calls=...) and its
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
# validation ("tool_call_ids did not have response messages")
# because the tools node has not run yet at after_model time.
# tool_calls are preserved so the tools node still executes.
#
# This is a temporary mitigation: mutating an existing
# AIMessage to carry framework-authored text leaks loop-warning
# text into downstream consumers (MemoryMiddleware fact
# extraction, TitleMiddleware, telemetry, model replay) as if
# the model said it. The proper fix is to defer warning
# injection from after_model to wrap_model_call so every prior
# ToolMessage is already in the request — see RFC #2517 (which
# lists "loop intervention does not leave invalid
# tool-call/tool-message state" as acceptance criteria) and
# the prototype on `fix/loop-detection-tool-call-pairing`.
messages = state.get("messages", [])
last_msg = messages[-1]
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
return {"messages": [patched_msg]}
return None
@@ -7,6 +7,7 @@ from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
return {"messages": [updated_msg]}
@override
@@ -14,6 +14,9 @@ from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
logger = logging.getLogger(__name__)
@@ -78,10 +81,7 @@ def _clone_ai_message(
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
return message.model_copy(update=update)
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
@dataclass
@@ -136,6 +136,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@@ -161,6 +162,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@@ -180,6 +182,24 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
def _preserve_dynamic_context_reminders(
self,
messages_to_summarize: list[AnyMessage],
preserved_messages: list[AnyMessage],
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Keep hidden dynamic-context reminders out of summary compression.
These reminders carry the current date and optional memory. If summarization
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
for the first user message and inject the reminder in the wrong place.
"""
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
if not reminders:
return messages_to_summarize, preserved_messages
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
return remaining, reminders + preserved_messages
def _partition_with_skill_rescue(
self,
messages: list[AnyMessage],
@@ -9,6 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
@@ -61,6 +62,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return ""
@staticmethod
def _is_user_message_for_title(message: object) -> bool:
return getattr(message, "type", None) == "human" and not is_dynamic_context_reminder(message)
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = self._get_title_config()
@@ -77,7 +82,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return False
# Count user and assistant messages
user_messages = [m for m in messages if m.type == "human"]
user_messages = [m for m in messages if self._is_user_message_for_title(m)]
assistant_messages = [m for m in messages if m.type == "ai"]
# Generate title after first complete exchange
@@ -91,7 +96,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
config = self._get_title_config()
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
user_msg_content = next((m.content for m in messages if self._is_user_message_for_title(m)), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
@@ -1,37 +1,303 @@
"""Middleware for logging LLM token usage."""
"""Middleware for logging token usage and annotating step attribution."""
from __future__ import annotations
import logging
from typing import override
from collections import defaultdict
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
def _string_arg(value: Any) -> str | None:
if isinstance(value, str):
normalized = value.strip()
return normalized or None
return None
def _normalize_todos(value: Any) -> list[Todo]:
if not isinstance(value, list):
return []
normalized: list[Todo] = []
for item in value:
if not isinstance(item, dict):
continue
todo: Todo = {}
content = _string_arg(item.get("content"))
status = item.get("status")
if content is not None:
todo["content"] = content
if status in {"pending", "in_progress", "completed"}:
todo["status"] = status
normalized.append(todo)
return normalized
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
status = current.get("status")
previous_content = previous.get("content") if previous else None
current_content = current.get("content")
if previous is None:
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
if previous_content != current_content:
return "todo_update"
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
# This is the single source of truth for precise write_todos token
# attribution. The frontend intentionally falls back to a generic
# "Update to-do list" label when this metadata is missing or malformed.
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
matched_previous_indices: set[int] = set()
for index, todo in enumerate(previous_todos):
content = todo.get("content")
if isinstance(content, str) and content:
previous_by_content[content].append((index, todo))
actions: list[dict[str, Any]] = []
for index, todo in enumerate(next_todos):
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
previous_match: Todo | None = None
content_matches = previous_by_content.get(content)
if content_matches:
while content_matches and content_matches[0][0] in matched_previous_indices:
content_matches.pop(0)
if content_matches:
previous_index, previous_match = content_matches.pop(0)
matched_previous_indices.add(previous_index)
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
previous_match = previous_todos[index]
matched_previous_indices.add(index)
if previous_match is not None:
previous_content = previous_match.get("content")
previous_status = previous_match.get("status")
if previous_content == content and previous_status == todo.get("status"):
continue
actions.append(
{
"kind": _todo_action_kind(previous_match, todo),
"content": content,
}
)
for index, todo in enumerate(previous_todos):
if index in matched_previous_indices:
continue
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
actions.append(
{
"kind": "todo_remove",
"content": content,
}
)
return actions
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
name = _string_arg(tool_call.get("name")) or "unknown"
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
tool_call_id = _string_arg(tool_call.get("id"))
if name == "write_todos":
next_todos = _normalize_todos(args.get("todos"))
actions = _build_todo_actions(todos, next_todos)
if not actions:
return [
{
"kind": "tool",
"tool_name": name,
"tool_call_id": tool_call_id,
}
]
return [
{
**action,
"tool_call_id": tool_call_id,
}
for action in actions
]
if name == "task":
return [
{
"kind": "subagent",
"description": _string_arg(args.get("description")),
"subagent_type": _string_arg(args.get("subagent_type")),
"tool_call_id": tool_call_id,
}
]
if name in {"web_search", "image_search"}:
query = _string_arg(args.get("query"))
return [
{
"kind": "search",
"tool_name": name,
"query": query,
"tool_call_id": tool_call_id,
}
]
if name == "present_files":
return [
{
"kind": "present_files",
"tool_call_id": tool_call_id,
}
]
if name == "ask_clarification":
return [
{
"kind": "clarification",
"tool_call_id": tool_call_id,
}
]
return [
{
"kind": "tool",
"tool_name": name,
"description": _string_arg(args.get("description")),
"tool_call_id": tool_call_id,
}
]
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
if actions:
first_kind = actions[0].get("kind")
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
return "todo_update"
if len(actions) == 1 and first_kind == "subagent":
return "subagent_dispatch"
return "tool_batch"
if message.content:
return "final_answer"
return "thinking"
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
current_todos = list(todos)
for raw_tool_call in tool_calls:
if not isinstance(raw_tool_call, dict):
continue
described_actions = _describe_tool_call(raw_tool_call, current_todos)
actions.extend(described_actions)
if raw_tool_call.get("name") == "write_todos":
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
current_todos = _normalize_todos(args.get("todos"))
tool_call_ids: list[str] = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
continue
tool_call_id = _string_arg(tool_call.get("id"))
if tool_call_id is not None:
tool_call_ids.append(tool_call_id)
return {
# Schema changes should remain additive where possible so older
# frontends can ignore unknown fields and fall back safely.
"version": 1,
"kind": _infer_step_kind(message, actions),
"shared_attribution": len(actions) > 1,
"tool_call_ids": tool_call_ids,
"actions": actions,
}
class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model response usage_metadata."""
"""Logs token usage from model responses and annotates the AI step."""
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
def _log_usage(self, state: AgentState) -> None:
def _apply(self, state: AgentState) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
usage = getattr(last, "usage_metadata", None)
if usage:
input_token_details = usage.get("input_token_details") or {}
output_token_details = usage.get("output_token_details") or {}
detail_parts = []
if input_token_details:
detail_parts.append(f"input_token_details={input_token_details}")
if output_token_details:
detail_parts.append(f"output_token_details={output_token_details}")
detail_suffix = f" {' '.join(detail_parts)}" if detail_parts else ""
logger.info(
"LLM token usage: input=%s output=%s total=%s",
"LLM token usage: input=%s output=%s total=%s%s",
usage.get("input_tokens", "?"),
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
detail_suffix,
)
return None
todos = state.get("todos") or []
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
@@ -0,0 +1,50 @@
"""Helpers for keeping AIMessage tool-call metadata consistent."""
from __future__ import annotations
from typing import Any
from langchain_core.messages import AIMessage
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
if not isinstance(raw_tool_call, dict):
return None
raw_id = raw_tool_call.get("id")
return raw_id if isinstance(raw_id, str) and raw_id else None
def clone_ai_message_with_tool_calls(
message: AIMessage,
tool_calls: list[dict[str, Any]],
*,
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
raw_tool_calls = additional_kwargs.get("tool_calls")
if isinstance(raw_tool_calls, list):
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
if synced_raw_tool_calls:
additional_kwargs["tool_calls"] = synced_raw_tool_calls
else:
additional_kwargs.pop("tool_calls", None)
if not tool_calls:
additional_kwargs.pop("function_call", None)
update["additional_kwargs"] = additional_kwargs
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
response_metadata["finish_reason"] = "stop"
update["response_metadata"] = response_metadata
return message.model_copy(update=update)
+98 -19
View File
@@ -264,25 +264,35 @@ class DeerFlowClient:
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
"""Copy message additional_kwargs when present."""
additional_kwargs = getattr(msg, "additional_kwargs", None)
if isinstance(additional_kwargs, dict) and additional_kwargs:
return dict(additional_kwargs)
return None
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event."""
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage:
data["usage_metadata"] = usage
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI tool-calls event."""
return StreamEvent(
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
},
)
data: dict[str, Any] = {
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
}
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
@@ -307,19 +317,30 @@ class DeerFlowClient:
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
if getattr(msg, "usage_metadata", None):
d["usage_metadata"] = msg.usage_metadata
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, ToolMessage):
return {
d = {
"type": "tool",
"content": DeerFlowClient._extract_text(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": getattr(msg, "id", None),
}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, HumanMessage):
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, SystemMessage):
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
@staticmethod
@@ -542,6 +563,7 @@ class DeerFlowClient:
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
"""
@@ -564,6 +586,7 @@ class DeerFlowClient:
# in both the final ``messages`` chunk and the values snapshot —
# count it only on whichever arrives first.
counted_usage_ids: set[str] = set()
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
@@ -593,6 +616,20 @@ class DeerFlowClient:
"total_tokens": total_tokens,
}
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
if not additional_kwargs:
return None
if not msg_id:
return additional_kwargs
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
if not delta:
return None
sent.update(delta)
return delta
for item in self._agent.stream(
state,
config=config,
@@ -620,17 +657,31 @@ class DeerFlowClient:
if isinstance(msg_chunk, AIMessage):
text = self._extract_text(msg_chunk.content)
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
sent_additional_kwargs = False
if text:
if msg_id:
streamed_ids.add(msg_id)
yield self._ai_text_event(msg_id, text, counted_usage)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
if msg_chunk.tool_calls:
if msg_id:
streamed_ids.add(msg_id)
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg_chunk.tool_calls,
additional_kwargs_delta,
)
elif isinstance(msg_chunk, ToolMessage):
if msg_id:
@@ -653,17 +704,45 @@ class DeerFlowClient:
if msg_id and msg_id in streamed_ids:
if isinstance(msg, AIMessage):
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
additional_kwargs = self._serialize_additional_kwargs(msg)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
if additional_kwargs_delta:
# Metadata-only follow-up: ``messages-tuple`` has no
# dedicated attribution event, so clients should
# merge this empty-content AI event by message id
# and ignore it for text rendering.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
continue
if isinstance(msg, AIMessage):
counted_usage = _account_usage(msg_id, msg.usage_metadata)
additional_kwargs = self._serialize_additional_kwargs(msg)
sent_additional_kwargs = False
if msg.tool_calls:
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg.tool_calls,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
text = self._extract_text(msg.content)
if text:
yield self._ai_text_event(msg_id, text, counted_usage)
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
elif msg_id:
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
if not additional_kwargs_delta:
continue
# See the metadata-only follow-up convention above.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
elif isinstance(msg, ToolMessage):
yield self._tool_message_event(msg)
@@ -80,6 +80,7 @@ class AioSandboxProvider(SandboxProvider):
port: 8080 # Base port for local containers
container_prefix: deer-flow-sandbox
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
auto_restart: true # Restart crashed containers automatically
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
mounts: # Volume mounts for local containers
- host_path: /path/on/host
@@ -164,12 +165,14 @@ class AioSandboxProvider(SandboxProvider):
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
auto_restart = getattr(sandbox_config, "auto_restart", True)
return {
"image": sandbox_config.image or DEFAULT_IMAGE,
"port": sandbox_config.port or DEFAULT_PORT,
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
"auto_restart": auto_restart,
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
"mounts": sandbox_config.mounts or [],
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
@@ -608,18 +611,58 @@ class AioSandboxProvider(SandboxProvider):
def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox by ID. Updates last activity timestamp.
When ``auto_restart`` is enabled (the default), the container's liveness
is verified on each lookup. If the underlying container has crashed, the
sandbox is evicted from all caches so that the next ``acquire()`` call will
transparently create a fresh container.
Args:
sandbox_id: The ID of the sandbox.
Returns:
The sandbox instance if found, None otherwise.
The sandbox instance if found and alive, None otherwise.
"""
with self._lock:
sandbox = self._sandboxes.get(sandbox_id)
if sandbox is not None:
self._last_activity[sandbox_id] = time.time()
if sandbox is None:
return None
self._last_activity[sandbox_id] = time.time()
auto_restart = self._config.get("auto_restart", True)
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
if not info:
return sandbox
if self._backend.is_alive(info):
return sandbox
info_to_destroy = None
with self._lock:
current_sandbox = self._sandboxes.get(sandbox_id)
current_info = self._sandbox_infos.get(sandbox_id)
if current_sandbox is None:
return None
if current_info is not info:
self._last_activity[sandbox_id] = time.time()
return current_sandbox
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
self._sandboxes.pop(sandbox_id, None)
self._sandbox_infos.pop(sandbox_id, None)
self._last_activity.pop(sandbox_id, None)
self._warm_pool.pop(sandbox_id, None)
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids:
del self._thread_sandboxes[tid]
info_to_destroy = info
if info_to_destroy:
try:
self._backend.destroy(info_to_destroy)
except Exception as e:
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
return None
def release(self, sandbox_id: str) -> None:
"""Release a sandbox from active use into the warm pool.
@@ -84,8 +84,52 @@ class RemoteSandboxBackend(SandboxBackend):
"""
return self._provisioner_discover(sandbox_id)
def list_running(self) -> list[SandboxInfo]:
"""Return all sandboxes currently managed by the provisioner.
Calls ``GET /api/sandboxes`` so that ``AioSandboxProvider._reconcile_orphans()``
can adopt pods that were created by a previous process and were never
explicitly destroyed.
Without this, a process restart silently orphans all existing k8s Pods
they stay running forever because the idle checker only
tracks in-process state.
"""
return self._provisioner_list()
# ── Provisioner API calls ─────────────────────────────────────────────
def _provisioner_list(self) -> list[SandboxInfo]:
"""GET /api/sandboxes → list all running sandboxes."""
try:
resp = requests.get(f"{self._provisioner_url}/api/sandboxes", timeout=10)
resp.raise_for_status()
data = resp.json()
if not isinstance(data, dict):
logger.warning("Provisioner list_running returned non-dict payload: %r", type(data))
return []
sandboxes = data.get("sandboxes", [])
if not isinstance(sandboxes, list):
logger.warning("Provisioner list_running returned non-list sandboxes: %r", type(sandboxes))
return []
infos: list[SandboxInfo] = []
for sandbox in sandboxes:
if not isinstance(sandbox, dict):
logger.warning("Provisioner list_running entry is not a dict: %r", type(sandbox))
continue
sandbox_id = sandbox.get("sandbox_id")
sandbox_url = sandbox.get("sandbox_url")
if isinstance(sandbox_id, str) and sandbox_id and isinstance(sandbox_url, str) and sandbox_url:
infos.append(SandboxInfo(sandbox_id=sandbox_id, sandbox_url=sandbox_url))
logger.info("Provisioner list_running: %d sandbox(es) found", len(infos))
return infos
except requests.RequestException as exc:
logger.warning("Provisioner list_running failed: %s", exc)
return []
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""POST /api/sandboxes → create Pod + Service."""
try:
@@ -0,0 +1,3 @@
from .tools import web_search_tool
__all__ = ["web_search_tool"]
@@ -0,0 +1,95 @@
"""
Web Search Tool - Search the web using Serper (Google Search API).
Serper provides real-time Google Search results via a JSON API.
An API key is required. Sign up at https://serper.dev to get one.
"""
import json
import logging
import os
import httpx
from langchain.tools import tool
from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
_SERPER_ENDPOINT = "https://google.serper.dev/search"
_api_key_warned = False
def _get_api_key() -> str | None:
config = get_app_config().get_tool_config("web_search")
if config is not None:
api_key = config.model_extra.get("api_key")
if isinstance(api_key, str) and api_key.strip():
return api_key
return os.getenv("SERPER_API_KEY")
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str, max_results: int = 5) -> str:
"""Search the web for information using Google Search via Serper.
Args:
query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of search results to return. Default is 5.
"""
global _api_key_warned
config = get_app_config().get_tool_config("web_search")
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
api_key = _get_api_key()
if not api_key:
if not _api_key_warned:
_api_key_warned = True
logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev")
return json.dumps(
{"error": "SERPER_API_KEY is not configured", "query": query},
ensure_ascii=False,
)
headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
payload = {"q": query, "num": max_results}
try:
with httpx.Client(timeout=30) as client:
response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}")
return json.dumps(
{"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
ensure_ascii=False,
)
except Exception as e:
logger.error(f"Serper search failed: {type(e).__name__}: {e}")
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
organic = data.get("organic", [])
if not organic:
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
normalized_results = [
{
"title": r.get("title", ""),
"url": r.get("link", ""),
"content": r.get("snippet", ""),
}
for r in organic[:max_results]
]
output = {
"query": query,
"total_results": len(normalized_results),
"results": normalized_results,
}
return json.dumps(output, indent=2, ensure_ascii=False)
@@ -1,5 +1,6 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .loop_detection_config import LoopDetectionConfig
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
@@ -20,6 +21,7 @@ __all__ = [
"SkillsConfig",
"ExtensionsConfig",
"get_extensions_config",
"LoopDetectionConfig",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
@@ -1,13 +1,22 @@
"""Configuration and loaders for custom agents."""
"""Configuration and loaders for custom agents.
Custom agents are stored per-user under ``{base_dir}/users/{user_id}/agents/{name}/``.
A legacy shared layout at ``{base_dir}/agents/{name}/`` is still readable so that
installations that pre-date user isolation continue to work until they run the
``scripts/migrate_user_isolation.py`` migration. New writes always target the
per-user layout.
"""
import logging
import re
from pathlib import Path
from typing import Any
import yaml
from pydantic import BaseModel
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@@ -40,14 +49,47 @@ class AgentConfig(BaseModel):
skills: list[str] | None = None
def load_agent_config(name: str | None) -> AgentConfig | None:
def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
"""Return the on-disk directory for an agent, preferring the per-user layout.
Resolution order:
1. ``{base_dir}/users/{user_id}/agents/{name}/`` (per-user, current layout).
2. ``{base_dir}/agents/{name}/`` (legacy shared layout read-only fallback).
If neither exists, the per-user path is returned so callers that intend to
create the agent write into the new layout.
Args:
name: Validated agent name.
user_id: Owner of the agent. Defaults to the effective user from the
request context (or ``"default"`` in no-auth mode).
"""
paths = get_paths()
effective_user = user_id or get_effective_user_id()
user_path = paths.user_agent_dir(effective_user, name)
if user_path.exists():
return user_path
legacy_path = paths.agent_dir(name)
if legacy_path.exists():
return legacy_path
return user_path
def load_agent_config(name: str | None, *, user_id: str | None = None) -> AgentConfig | None:
"""Load the custom or default agent's config from its directory.
Reads from the per-user layout first; falls back to the legacy shared layout
for installations that have not yet been migrated.
Args:
name: The agent name.
user_id: Owner of the agent. Defaults to the effective user from the
current request context.
Returns:
AgentConfig instance.
AgentConfig instance, or ``None`` if ``name`` is ``None``.
Raises:
FileNotFoundError: If the agent directory or config.yaml does not exist.
@@ -58,7 +100,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
return None
name = validate_agent_name(name)
agent_dir = get_paths().agent_dir(name)
agent_dir = resolve_agent_dir(name, user_id=user_id)
config_file = agent_dir / "config.yaml"
if not agent_dir.exists():
@@ -84,7 +126,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
return AgentConfig(**data)
def load_agent_soul(agent_name: str | None) -> str | None:
def load_agent_soul(agent_name: str | None, *, user_id: str | None = None) -> str | None:
"""Read the SOUL.md file for a custom agent, if it exists.
SOUL.md defines the agent's personality, values, and behavioral guardrails.
@@ -92,11 +134,16 @@ def load_agent_soul(agent_name: str | None) -> str | None:
Args:
agent_name: The name of the agent or None for the default agent.
user_id: Owner of the agent. Defaults to the effective user from the
current request context.
Returns:
The SOUL.md content as a string, or None if the file does not exist.
"""
agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
if agent_name:
agent_dir = resolve_agent_dir(agent_name, user_id=user_id)
else:
agent_dir = get_paths().base_dir
soul_path = agent_dir / SOUL_FILENAME
if not soul_path.exists():
return None
@@ -104,32 +151,50 @@ def load_agent_soul(agent_name: str | None) -> str | None:
return content or None
def list_custom_agents() -> list[AgentConfig]:
def list_custom_agents(*, user_id: str | None = None) -> list[AgentConfig]:
"""Scan the agents directory and return all valid custom agents.
Returns the union of agents in the per-user layout and the legacy shared
layout, so that pre-migration installations remain visible until they are
migrated. Per-user entries shadow legacy entries with the same name.
Args:
user_id: Owner whose agents to list. Defaults to the effective user
from the current request context.
Returns:
List of AgentConfig for each valid agent directory found.
"""
agents_dir = get_paths().agents_dir
if not agents_dir.exists():
return []
paths = get_paths()
effective_user = user_id or get_effective_user_id()
seen: set[str] = set()
agents: list[AgentConfig] = []
for entry in sorted(agents_dir.iterdir()):
if not entry.is_dir():
user_root = paths.user_agents_dir(effective_user)
legacy_root = paths.agents_dir
for root in (user_root, legacy_root):
if not root.exists():
continue
for entry in sorted(root.iterdir()):
if not entry.is_dir():
continue
if entry.name in seen:
continue
config_file = entry / "config.yaml"
if not config_file.exists():
logger.debug(f"Skipping {entry.name}: no config.yaml")
continue
config_file = entry / "config.yaml"
if not config_file.exists():
logger.debug(f"Skipping {entry.name}: no config.yaml")
continue
try:
agent_cfg = load_agent_config(entry.name)
agents.append(agent_cfg)
except Exception as e:
logger.warning(f"Skipping agent '{entry.name}': {e}")
try:
agent_cfg = load_agent_config(entry.name, user_id=effective_user)
if agent_cfg is None:
continue
agents.append(agent_cfg)
seen.add(entry.name)
except Exception as e:
logger.warning(f"Skipping agent '{entry.name}': {e}")
agents.sort(key=lambda a: a.name)
return agents
@@ -1,5 +1,6 @@
import logging
import os
from collections.abc import Mapping
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
@@ -14,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
@@ -99,6 +101,7 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -157,56 +160,54 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data)
cls._apply_database_defaults(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump()
result = cls.model_validate(config_data)
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
cls._apply_singleton_configs(result, acp_agents)
return result
@classmethod
def _validate_acp_agents(
cls,
config_data: Mapping[str, Mapping[str, object]] | None,
) -> dict[str, ACPAgentConfig]:
if config_data is None:
config_data = {}
return {name: ACPAgentConfig(**cfg) for name, cfg in config_data.items()}
@classmethod
def _apply_singleton_configs(cls, config: Self, acp_agents: dict[str, ACPAgentConfig]) -> None:
from deerflow.config.checkpointer_config import get_checkpointer_config
previous_checkpointer_config = get_checkpointer_config()
load_title_config_from_dict(config.title.model_dump())
load_summarization_config_from_dict(config.summarization.model_dump())
load_memory_config_from_dict(config.memory.model_dump())
load_agents_api_config_from_dict(config.agents_api.model_dump())
load_subagents_config_from_dict(config.subagents.model_dump())
load_tool_search_config_from_dict(config.tool_search.model_dump())
load_guardrails_config_from_dict(config.guardrails.model_dump())
load_checkpointer_config_from_dict(config.checkpointer.model_dump() if config.checkpointer is not None else None)
load_stream_bridge_config_from_dict(config.stream_bridge.model_dump() if config.stream_bridge is not None else None)
load_acp_config_from_dict({name: agent.model_dump() for name, agent in acp_agents.items()})
if previous_checkpointer_config != config.checkpointer:
# These runtime singletons derive their backend from checkpointer config.
# Keep imports local to avoid cycles: both providers import get_app_config.
from deerflow.runtime.checkpointer import reset_checkpointer
from deerflow.runtime.store import reset_store
reset_checkpointer()
reset_store()
@classmethod
def _apply_database_defaults(cls, config_data: dict[str, Any]) -> None:
"""Apply config.yaml defaults for persistence when the section is absent."""
@@ -14,12 +14,13 @@ class CheckpointerConfig(BaseModel):
description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). "
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
"'postgres' persists to PostgreSQL (install with deerflow-harness[postgres])."
)
connection_string: str | None = Field(
default=None,
description="Connection string for sqlite (file path) or postgres (DSN). "
"Required for sqlite and postgres types. "
"Optional for sqlite and defaults to 'store.db' when omitted. "
"Required for postgres. "
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
@@ -40,7 +41,10 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
if config_dict is None:
_checkpointer_config = None
return
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -0,0 +1,73 @@
"""Configuration for loop detection middleware."""
from pydantic import BaseModel, Field, model_validator
class ToolFreqOverride(BaseModel):
"""Per-tool frequency threshold override.
Can be higher or lower than the global defaults. Commonly used to raise
thresholds for high-frequency tools like bash in batch workflows (e.g.
RNA-seq pipelines) without weakening protection on every other tool.
"""
warn: int = Field(ge=1)
hard_limit: int = Field(ge=1)
@model_validator(mode="after")
def _validate(self) -> "ToolFreqOverride":
if self.hard_limit < self.warn:
raise ValueError("hard_limit must be >= warn")
return self
class LoopDetectionConfig(BaseModel):
"""Configuration for repetitive tool-call loop detection."""
enabled: bool = Field(
default=True,
description="Whether to enable repetitive tool-call loop detection",
)
warn_threshold: int = Field(
default=3,
ge=1,
description="Number of identical tool-call sets before injecting a warning",
)
hard_limit: int = Field(
default=5,
ge=1,
description="Number of identical tool-call sets before forcing a stop",
)
window_size: int = Field(
default=20,
ge=1,
description="Number of recent tool-call sets to track per thread",
)
max_tracked_threads: int = Field(
default=100,
ge=1,
description="Maximum number of thread histories to keep in memory",
)
tool_freq_warn: int = Field(
default=30,
ge=1,
description="Number of calls to the same tool type before injecting a frequency warning",
)
tool_freq_hard_limit: int = Field(
default=50,
ge=1,
description="Number of calls to the same tool type before forcing a stop",
)
tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
default_factory=dict,
description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
)
@model_validator(mode="after")
def validate_thresholds(self) -> "LoopDetectionConfig":
"""Ensure hard stop cannot happen before the warning threshold."""
if self.hard_limit < self.warn_threshold:
raise ValueError("hard_limit must be greater than or equal to warn_threshold")
if self.tool_freq_hard_limit < self.tool_freq_warn:
raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
return self
@@ -132,15 +132,20 @@ class Paths:
@property
def agents_dir(self) -> Path:
"""Root directory for all custom agents: `{base_dir}/agents/`."""
"""Legacy root for shared (pre user-isolation) custom agents: `{base_dir}/agents/`.
New code should use :meth:`user_agents_dir` instead. This property remains
only as a read-side fallback for installations that have not yet run the
``migrate_user_isolation.py`` script.
"""
return self.base_dir / "agents"
def agent_dir(self, name: str) -> Path:
"""Directory for a specific agent: `{base_dir}/agents/{name}/`."""
"""Legacy per-agent directory (no user isolation): `{base_dir}/agents/{name}/`."""
return self.agents_dir / name.lower()
def agent_memory_file(self, name: str) -> Path:
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
"""Legacy per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
return self.agent_dir(name) / "memory.json"
def user_dir(self, user_id: str) -> Path:
@@ -151,9 +156,17 @@ class Paths:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json"
def user_agents_dir(self, user_id: str) -> Path:
"""Per-user root for that user's custom agents: `{base_dir}/users/{user_id}/agents/`."""
return self.user_dir(user_id) / "agents"
def user_agent_dir(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent directory: `{base_dir}/users/{user_id}/agents/{name}/`."""
return self.user_agents_dir(user_id) / agent_name.lower()
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
return self.user_agent_dir(user_id, agent_name) / "memory.json"
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
@@ -23,6 +23,9 @@ class SandboxConfig(BaseModel):
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
container_prefix: Prefix for container names (default: deer-flow-sandbox)
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
on the next acquire. Set to false to disable.
mounts: List of volume mounts to share directories with the container
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
"""
@@ -55,6 +58,10 @@ class SandboxConfig(BaseModel):
default=None,
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
)
auto_restart: bool = Field(
default=True,
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
)
mounts: list[VolumeMountConfig] = Field(
default_factory=list,
description="List of volume mounts to share directories between host and container",
@@ -6,6 +6,13 @@ from pydantic import BaseModel, Field
from deerflow.config.runtime_paths import project_root, resolve_path
def _legacy_skills_candidates() -> tuple[Path, ...]:
"""Return source-tree skills locations for monorepo compatibility."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
return (repo_root / "skills",)
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
@@ -15,7 +22,7 @@ class SkillsConfig(BaseModel):
)
path: str | None = Field(
default=None,
description="Path to skills directory. If not specified, defaults to skills under the caller project root.",
description=("Path to skills directory. If not specified, defaults to `skills` under the caller project root, falling back to the legacy repo-root location for monorepo compatibility."),
)
container_path: str = Field(
default="/mnt/skills",
@@ -26,15 +33,30 @@ class SkillsConfig(BaseModel):
"""
Get the resolved skills directory path.
Returns:
Path to the skills directory
Resolution order:
1. Explicit ``path`` field
2. ``DEER_FLOW_SKILLS_PATH`` environment variable
3. ``skills`` under the caller project root (``project_root()``)
4. Legacy repo-root candidates for monorepo compatibility (``_legacy_skills_candidates``)
When none of (3) or (4) exist on disk, the project-root default is returned so callers
can still surface a stable "no skills" location without raising.
"""
if self.path:
# Use configured path (can be absolute or relative to project root)
return resolve_path(self.path)
if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"):
return resolve_path(env_path)
return project_root() / "skills"
project_default = project_root() / "skills"
if project_default.is_dir():
return project_default
for candidate in _legacy_skills_candidates():
if candidate.is_dir():
return candidate
return project_default
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
"""
@@ -40,7 +40,10 @@ def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
def load_stream_bridge_config_from_dict(config_dict: dict | None) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
if config_dict is None:
_stream_bridge_config = None
return
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -179,9 +179,3 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
overrides_summary or "none",
custom_agents_names or "none",
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -4,4 +4,4 @@ from pydantic import BaseModel, Field
class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking."""
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
enabled: bool = Field(default=True, description="Enable token usage tracking middleware")
@@ -196,6 +196,10 @@ class ClaudeChatModel(ChatAnthropic):
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
placed on the *last* eligible blocks because later breakpoints cover a
larger prefix and yield better cache hit rates.
The system prompt is expected to be fully static (no per-user memory or
current date). Dynamic context is injected per-turn via
DynamicContextMiddleware as a <system-reminder> in the first HumanMessage.
"""
MAX_CACHE_BREAKPOINTS = 4
@@ -27,6 +27,34 @@ from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli
logger = logging.getLogger(__name__)
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
def _build_usage_metadata(oai_usage: dict) -> dict:
"""Convert Codex/Responses API usage dict to LangChain usage_metadata format.
Maps OpenAI Responses API token usage fields to the dict structure that
LangChain AIMessage.usage_metadata expects. This avoids depending on
langchain_openai private helpers like ``_create_usage_metadata_responses``.
"""
input_tokens = oai_usage.get("input_tokens", 0)
output_tokens = oai_usage.get("output_tokens", 0)
total_tokens = oai_usage.get("total_tokens", input_tokens + output_tokens)
metadata: dict = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}
input_details = oai_usage.get("input_tokens_details") or {}
output_details = oai_usage.get("output_tokens_details") or {}
cache_read = input_details.get("cached_tokens")
if cache_read is not None:
metadata["input_token_details"] = {"cache_read": cache_read}
reasoning = output_details.get("reasoning_tokens")
if reasoning is not None:
metadata["output_token_details"] = {"reasoning": reasoning}
return metadata
MAX_RETRIES = 3
@@ -346,6 +374,7 @@ class CodexChatModel(BaseChatModel):
)
usage = response.get("usage", {})
usage_metadata = _build_usage_metadata(usage) if usage else None
additional_kwargs = {}
if reasoning_content:
additional_kwargs["reasoning_content"] = reasoning_content
@@ -355,6 +384,7 @@ class CodexChatModel(BaseChatModel):
tool_calls=tool_calls if tool_calls else [],
invalid_tool_calls=invalid_tool_calls,
additional_kwargs=additional_kwargs,
usage_metadata=usage_metadata,
response_metadata={
"model": response.get("model", self.model),
"usage": usage,
@@ -81,7 +81,16 @@ async def init_engine(
try:
import asyncpg # noqa: F401
except ImportError:
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
raise ImportError(
"database.backend is set to 'postgres' but asyncpg is not installed.\n"
"Install it with:\n"
" cd backend && uv sync --all-packages --extra postgres\n"
"On the next `make dev` the postgres extra is auto-detected from\n"
"config.yaml (database.backend: postgres) and reinstalled, so it\n"
"will not be wiped again. Set UV_EXTRAS=postgres in .env to opt in\n"
"explicitly. Or switch to backend: sqlite in config.yaml for\n"
"single-node deployment."
) from None
if backend == "sqlite":
import os
@@ -7,13 +7,13 @@ router for thread records.
from __future__ import annotations
import time
from typing import Any
from langgraph.store.base import BaseStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
from deerflow.utils.time import coerce_iso, now_iso
THREADS_NS: tuple[str, ...] = ("threads",)
@@ -48,7 +48,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
metadata: dict | None = None,
) -> dict:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
now = time.time()
now = now_iso()
record: dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
@@ -106,7 +106,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
if record is None:
return
record["display_name"] = display_name
record["updated_at"] = time.time()
record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -114,7 +114,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
if record is None:
return
record["status"] = status
record["updated_at"] = time.time()
record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -124,7 +124,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
merged = dict(record.get("metadata") or {})
merged.update(metadata)
record["metadata"] = merged
record["updated_at"] = time.time()
record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@@ -144,6 +144,8 @@ class MemoryThreadMetaStore(ThreadMetaStore):
"display_name": val.get("display_name"),
"status": val.get("status", "idle"),
"metadata": val.get("metadata", {}),
"created_at": str(val.get("created_at", "")),
"updated_at": str(val.get("updated_at", "")),
# ``coerce_iso`` heals legacy unix-second values written by
# earlier Gateway versions that called ``str(time.time())``.
"created_at": coerce_iso(val.get("created_at", "")),
"updated_at": coerce_iso(val.get("updated_at", "")),
}
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_INSTALL = (
"langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
)
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
@@ -9,6 +9,7 @@ from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -33,20 +34,21 @@ class DbRunEventStore(RunEventStore):
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
d.pop("id", None)
# Restore dict content that was JSON-serialized on write
# Restore structured content that was JSON-serialized on write.
raw = d.get("content", "")
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
metadata = d.get("metadata", {})
if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")):
try:
d["content"] = json.loads(raw)
except (json.JSONDecodeError, ValueError):
# Content looked like JSON (content_is_dict flag) but failed to parse;
# Content looked like JSON but failed to parse;
# keep the raw string as-is.
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
return d
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]:
if category == "trace":
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False)
encoded = text.encode("utf-8")
if len(encoded) > self._max_trace_content:
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
@@ -54,6 +56,18 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]:
metadata = metadata or {}
if isinstance(content, str):
return content, metadata
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**metadata, "content_is_json": True}
if isinstance(content, dict):
metadata["content_is_dict"] = True
return db_content, metadata
@staticmethod
def _user_id_from_context() -> str | None:
"""Soft read of user_id from contextvar for write paths.
@@ -82,11 +96,7 @@ class DbRunEventStore(RunEventStore):
the initial ``human_message`` event (once per run).
"""
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
db_content, metadata = self._content_to_db(content, metadata)
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
@@ -128,11 +138,7 @@ class DbRunEventStore(RunEventStore):
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
db_content, metadata = self._content_to_db(content, metadata)
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
@@ -6,9 +6,10 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from deerflow.utils.time import now_iso as _now_iso
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
@@ -17,10 +18,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _now_iso() -> str:
return datetime.now(UTC).isoformat()
@dataclass
class RunRecord:
"""Mutable record for a single run."""
@@ -23,6 +23,8 @@ from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast
from langgraph.checkpoint.base import empty_checkpoint
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
@@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint(
if checkpoint_to_restore.get("id") is None:
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
return
restore_marker = _new_checkpoint_marker()
checkpoint_to_restore = {
**checkpoint_to_restore,
"id": restore_marker["id"],
"ts": restore_marker["ts"],
}
metadata = pre_run_snapshot.get("metadata", {})
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
@@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint(
)
def _new_checkpoint_marker() -> dict[str, str]:
marker = empty_checkpoint()
return {"id": marker["id"], "ts": marker["ts"]}
def _lg_mode_to_sse_event(mode: str) -> str:
"""Map LangGraph internal stream_mode name to SSE event name.
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_STORE_INSTALL = (
"langgraph-checkpoint-postgres is required for the PostgreSQL store. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
)
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
@@ -42,6 +42,13 @@ class LocalSandbox(Sandbox):
"""Return whether the selected shell is cmd.exe."""
return LocalSandbox._shell_name(shell) in {"cmd", "cmd.exe"}
@staticmethod
def _is_msys_shell(shell: str) -> bool:
"""Return whether the selected shell is a Git Bash/MSYS shell."""
normalized = shell.replace("\\", "/").lower()
shell_name = LocalSandbox._shell_name(shell)
return shell_name in {"sh.exe", "bash.exe"} and any(part in normalized for part in ("/git/", "/mingw", "/msys"))
@staticmethod
def _find_first_available_shell(candidates: tuple[str, ...]) -> str | None:
"""Return the first executable shell path or command found from candidates."""
@@ -303,12 +310,19 @@ class LocalSandbox(Sandbox):
shell = self._get_shell()
if os.name == "nt":
env = None
if self._is_powershell(shell):
args = [shell, "-NoProfile", "-Command", resolved_command]
elif self._is_cmd_shell(shell):
args = [shell, "/c", resolved_command]
else:
args = [shell, "-c", resolved_command]
if self._is_msys_shell(shell):
env = {
**os.environ,
"MSYS_NO_PATHCONV": "1",
"MSYS2_ARG_CONV_EXCL": "*",
}
result = subprocess.run(
args,
@@ -316,6 +330,7 @@ class LocalSandbox(Sandbox):
capture_output=True,
text=True,
timeout=600,
env=env,
)
else:
args = [shell, "-c", resolved_command]
@@ -3,10 +3,9 @@ import re
import shlex
from pathlib import Path
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from langchain.tools import tool
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
@@ -19,6 +18,7 @@ from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.tools.types import Runtime
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
@@ -419,7 +419,7 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
"""Sanitize an error message to avoid leaking host filesystem paths.
In local-sandbox mode, resolved host paths in the error string are masked
@@ -994,7 +994,7 @@ def _apply_cwd_prefix(command: str, thread_data: ThreadDataState | None) -> str:
return command
def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None:
def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None:
"""Extract thread_data from runtime state."""
if runtime is None:
return None
@@ -1003,7 +1003,7 @@ def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> Threa
return runtime.state.get("thread_data")
def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool:
def is_local_sandbox(runtime: Runtime | None) -> bool:
"""Check if the current sandbox is a local sandbox.
Path replacement is only needed for local sandbox since aio sandbox
@@ -1019,7 +1019,7 @@ def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool
return sandbox_state.get("sandbox_id") == "local"
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox:
"""Extract sandbox instance from tool runtime.
DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support.
@@ -1048,7 +1048,7 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
return sandbox
def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
"""Ensure sandbox is initialized, acquiring lazily if needed.
On first call, acquires a sandbox from the provider and stores it in runtime state.
@@ -1107,7 +1107,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
return sandbox
def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState] | None) -> None:
def ensure_thread_directories_exist(runtime: Runtime | None) -> None:
"""Ensure thread data directories (workspace, uploads, outputs) exist.
This function is called lazily when any sandbox tool is first used.
@@ -1221,7 +1221,7 @@ def _truncate_ls_output(output: str, max_chars: int) -> str:
@tool("bash", parse_docstring=True)
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
def bash_tool(runtime: Runtime, description: str, command: str) -> str:
"""Execute a bash command in a Linux environment.
@@ -1270,7 +1270,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
@tool("ls", parse_docstring=True)
def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str:
def ls_tool(runtime: Runtime, description: str, path: str) -> str:
"""List the contents of a directory up to 2 levels deep in tree format.
Args:
@@ -1318,7 +1318,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
@tool("glob", parse_docstring=True)
def glob_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
pattern: str,
path: str,
@@ -1368,7 +1368,7 @@ def glob_tool(
@tool("grep", parse_docstring=True)
def grep_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
pattern: str,
path: str,
@@ -1438,7 +1438,7 @@ def grep_tool(
@tool("read_file", parse_docstring=True)
def read_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
path: str,
start_line: int | None = None,
@@ -1493,7 +1493,7 @@ def read_file_tool(
@tool("write_file", parse_docstring=True)
def write_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
path: str,
content: str,
@@ -1533,7 +1533,7 @@ def write_file_tool(
@tool("str_replace", parse_docstring=True)
def str_replace_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
path: str,
old_str: str,
@@ -9,6 +9,29 @@ from .types import SKILL_MD_FILE, Skill, SkillCategory
logger = logging.getLogger(__name__)
def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None:
"""Parse the optional allowed-tools frontmatter field.
Returns None when the field is omitted. Returns a list when the field is a
YAML sequence of strings, including an empty list for explicit no-tool
skills. Raises ValueError for malformed values.
"""
if raw is None:
return None
if not isinstance(raw, list):
raise ValueError(f"allowed-tools in {skill_file} must be a list of strings")
allowed_tools: list[str] = []
for item in raw:
if not isinstance(item, str):
raise ValueError(f"allowed-tools in {skill_file} must contain only strings")
tool_name = item.strip()
if not tool_name:
raise ValueError(f"allowed-tools in {skill_file} cannot contain empty tool names")
allowed_tools.append(tool_name)
return allowed_tools
def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: Path | None = None) -> Skill | None:
"""Parse a SKILL.md file and extract metadata.
@@ -64,6 +87,12 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
if license_text is not None:
license_text = str(license_text).strip() or None
try:
allowed_tools = parse_allowed_tools(metadata.get("allowed-tools"), skill_file)
except ValueError as exc:
logger.error("Invalid allowed-tools in %s: %s", skill_file, exc)
return None
return Skill(
name=name,
description=description,
@@ -72,6 +101,7 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
skill_file=skill_file,
relative_path=relative_path or Path(skill_file.parent.name),
category=category,
allowed_tools=allowed_tools,
enabled=True, # Actual state comes from the extensions config file.
)
@@ -0,0 +1,44 @@
import logging
from typing import Protocol
from deerflow.skills.types import Skill
logger = logging.getLogger(__name__)
class NamedTool(Protocol):
name: str
def allowed_tool_names_for_skills(skills: list[Skill]) -> set[str] | None:
"""Return the union of explicit skill allowed-tools declarations.
None means legacy allow-all behavior. It is returned only when no loaded
skill declares allowed-tools. Once any skill declares the field, legacy
skills without the field contribute no tools instead of disabling the
explicit restrictions from other skills.
"""
if not skills:
return None
allowed: set[str] = set()
has_explicit_declaration = False
for skill in skills:
if skill.allowed_tools is None:
continue
has_explicit_declaration = True
if not skill.allowed_tools:
logger.info("Skill %s declared empty allowed-tools", skill.name)
allowed.update(skill.allowed_tools)
if not has_explicit_declaration:
return None
return allowed
def filter_tools_by_skill_allowed_tools[ToolT: NamedTool](tools: list[ToolT], skills: list[Skill]) -> list[ToolT]:
allowed = allowed_tool_names_for_skills(skills)
if allowed is None:
return tools
return [tool for tool in tools if tool.name in allowed]
@@ -27,6 +27,7 @@ class Skill:
skill_file: Path
relative_path: Path # Relative path from category root to skill directory
category: SkillCategory # 'public' or 'custom'
allowed_tools: list[str] | None = None
enabled: bool = False # Whether this skill is enabled
@property
@@ -8,6 +8,7 @@ from pathlib import Path
import yaml
from deerflow.skills.parser import parse_allowed_tools
from deerflow.skills.types import SKILL_MD_FILE
# Allowed properties in SKILL.md frontmatter
@@ -84,4 +85,9 @@ def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]
if len(description) > 1024:
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters.", None
try:
parse_allowed_tools(frontmatter.get("allowed-tools"), skill_md)
except ValueError as e:
return False, str(e).replace(str(skill_md), SKILL_MD_FILE), None
return True, "Skill is valid!", name
@@ -23,6 +23,8 @@ from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadSt
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
logger = logging.getLogger(__name__)
@@ -260,16 +262,16 @@ class SubagentExecutor:
# Generate trace_id if not provided (for top-level calls)
self.trace_id = trace_id or str(uuid.uuid4())[:8]
# Filter tools based on config
self.tools = _filter_tools(
self._base_tools = _filter_tools(
tools,
config.tools,
config.disallowed_tools,
)
self.tools = self._base_tools
logger.info(f"[trace={self.trace_id}] SubagentExecutor initialized: {config.name} with {len(self.tools)} tools")
def _create_agent(self):
def _create_agent(self, tools: list[BaseTool] | None = None):
"""Create the agent instance."""
app_config = self.app_config or get_app_config()
if self.model_name is None:
@@ -283,26 +285,14 @@ class SubagentExecutor:
return create_agent(
model=model,
tools=self.tools,
tools=tools if tools is not None else self.tools,
middleware=middlewares,
system_prompt=self.config.system_prompt,
state_schema=ThreadState,
)
async def _load_skill_messages(self) -> list[SystemMessage]:
"""Load skill content as conversation items based on config.skills.
Aligned with Codex's pattern: each subagent loads its own skills
per-session and injects them as conversation items (developer messages),
not as system prompt text. The config.skills whitelist controls which
skills are loaded:
- None: load all enabled skills
- []: no skills
- ["skill-a", "skill-b"]: only these skills
Returns:
List of SystemMessages containing skill content.
"""
async def _load_skills(self) -> list[Skill]:
"""Load enabled skill metadata based on config.skills."""
if self.config.skills is not None and len(self.config.skills) == 0:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
return []
@@ -316,8 +306,8 @@ class SubagentExecutor:
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
except Exception:
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
return []
logger.exception(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}")
raise
if not all_skills:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
@@ -326,10 +316,26 @@ class SubagentExecutor:
# Filter by config.skills whitelist
if self.config.skills is not None:
allowed = set(self.config.skills)
skills = [s for s in all_skills if s.name in allowed]
else:
skills = all_skills
return [s for s in all_skills if s.name in allowed]
return all_skills
def _apply_skill_allowed_tools(self, skills: list[Skill]) -> list[BaseTool]:
return filter_tools_by_skill_allowed_tools(self._base_tools, skills)
async def _load_skill_messages(self, skills: list[Skill]) -> list[SystemMessage]:
"""Load skill content as conversation items based on config.skills.
Aligned with Codex's pattern: each subagent loads its own skills
per-session and injects them as conversation items (developer messages),
not as system prompt text. The config.skills whitelist controls which
skills are loaded:
- None: load all enabled skills
- []: no skills
- ["skill-a", "skill-b"]: only these skills
Returns:
List of SystemMessages containing skill content.
"""
if not skills:
return []
@@ -347,19 +353,21 @@ class SubagentExecutor:
return messages
async def _build_initial_state(self, task: str) -> dict[str, Any]:
async def _build_initial_state(self, task: str) -> tuple[dict[str, Any], list[BaseTool]]:
"""Build the initial state for agent execution.
Args:
task: The task description.
Returns:
Initial state dictionary.
Initial state dictionary and tools filtered by loaded skill metadata.
"""
# Load skills as conversation items (Codex pattern)
skill_messages = await self._load_skill_messages()
skills = await self._load_skills()
filtered_tools = self._apply_skill_allowed_tools(skills)
skill_messages = await self._load_skill_messages(skills)
messages: list = []
messages: list[Any] = []
# Skill content injected as developer/system messages before the task
messages.extend(skill_messages)
# Then the actual task
@@ -375,7 +383,7 @@ class SubagentExecutor:
if self.thread_data is not None:
state["thread_data"] = self.thread_data
return state
return state, filtered_tools
async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute a task asynchronously.
@@ -405,8 +413,8 @@ class SubagentExecutor:
result.ai_messages = ai_messages
try:
agent = self._create_agent()
state = await self._build_initial_state(task)
state, filtered_tools = await self._build_initial_state(task)
agent = self._create_agent(filtered_tools)
# Build config with thread_id for sandbox access and recursion limit
run_config: RunnableConfig = {
@@ -2,10 +2,12 @@ from .clarification_tool import ask_clarification_tool
from .present_file_tool import present_file_tool
from .setup_agent_tool import setup_agent
from .task_tool import task_tool
from .update_agent_tool import update_agent
from .view_image_tool import view_image_tool
__all__ = [
"setup_agent",
"update_agent",
"present_file_tool",
"ask_clarification_tool",
"view_image_tool",
@@ -1,20 +1,19 @@
from pathlib import Path
from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain.tools import InjectedToolCallId, tool
from langchain_core.messages import ToolMessage
from langgraph.config import get_config
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
def _get_thread_id(runtime: Runtime) -> str | None:
"""Resolve the current thread id from runtime context or RunnableConfig."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
@@ -32,7 +31,7 @@ def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
def _normalize_presented_filepath(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
filepath: str,
) -> str:
"""Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.
@@ -83,7 +82,7 @@ def _normalize_presented_filepath(
@tool("present_files", parse_docstring=True)
def present_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
filepaths: list[str],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
@@ -3,20 +3,28 @@ import logging
import yaml
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolRuntime
from langgraph.types import Command
from deerflow.config.agents_config import validate_agent_name
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _get_runtime_user_id(runtime: Runtime) -> str:
context_user_id = runtime.context.get("user_id") if runtime.context else None
if context_user_id:
return str(context_user_id)
return get_effective_user_id()
@tool
def setup_agent(
soul: str,
description: str,
runtime: ToolRuntime,
runtime: Runtime,
skills: list[str] | None = None,
) -> Command:
"""Setup the custom DeerFlow agent.
@@ -34,7 +42,14 @@ def setup_agent(
try:
agent_name = validate_agent_name(agent_name)
paths = get_paths()
agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
if agent_name:
# Custom agents are persisted under the current user's bucket so
# different users do not see each other's agents.
user_id = _get_runtime_user_id(runtime)
agent_dir = paths.user_agent_dir(user_id, agent_name)
else:
# Default agent (no agent_name): SOUL.md lives at the global base dir.
agent_dir = paths.base_dir
is_new_dir = not agent_dir.exists()
agent_dir.mkdir(parents=True, exist_ok=True)
@@ -6,11 +6,9 @@ import uuid
from dataclasses import replace
from typing import TYPE_CHECKING, Annotated, Any, cast
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain.tools import InjectedToolCallId, tool
from langgraph.config import get_stream_writer
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config import get_app_config
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
@@ -21,6 +19,7 @@ from deerflow.subagents.executor import (
get_background_task_result,
request_cancel_background_task,
)
from deerflow.tools.types import Runtime
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
@@ -50,12 +49,11 @@ def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -
@tool("task", parse_docstring=True)
async def task_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
prompt: str,
subagent_type: str,
tool_call_id: Annotated[str, InjectedToolCallId],
max_turns: int | None = None,
) -> str:
"""Delegate a task to a specialized subagent that runs in its own context.
@@ -91,7 +89,6 @@ async def task_tool(
description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
"""
runtime_app_config = _get_runtime_app_config(runtime)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
@@ -113,9 +110,6 @@ async def task_tool(
# each subagent loads its own skills based on config, injected as conversation items).
# No longer appended to system_prompt here.
if max_turns is not None:
overrides["max_turns"] = max_turns
# Extract parent context from runtime
sandbox_state = None
thread_data = None
@@ -0,0 +1,241 @@
"""update_agent tool — let a custom agent persist updates to its own SOUL.md / config.
Bound to the lead agent only when ``runtime.context['agent_name']`` is set
(i.e. inside an existing custom agent's chat). The default agent does not see
this tool, and the bootstrap flow continues to use ``setup_agent`` for the
initial creation handshake.
The tool writes back to ``{base_dir}/users/{user_id}/agents/{agent_name}/{config.yaml,SOUL.md}``
so an agent created by one user is never visible to (or mutable by) another.
Writes are staged into temp files first; both files are renamed into place only
after both temp files are successfully written, so a partial failure cannot leave
config.yaml updated while SOUL.md still holds stale content.
"""
from __future__ import annotations
import logging
import tempfile
from pathlib import Path
from typing import Any
import yaml
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _stage_temp(path: Path, text: str) -> Path:
"""Write ``text`` into a sibling temp file and return its path.
The caller is responsible for ``Path.replace``-ing the temp into the target
once every staged file is ready, or for unlinking it on failure.
"""
path.parent.mkdir(parents=True, exist_ok=True)
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=path.parent,
suffix=".tmp",
delete=False,
encoding="utf-8",
)
try:
fd.write(text)
fd.flush()
fd.close()
return Path(fd.name)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def _cleanup_temps(temps: list[Path]) -> None:
"""Best-effort removal of staged temp files."""
for tmp in temps:
try:
tmp.unlink(missing_ok=True)
except OSError:
logger.debug("Failed to clean up temp file %s", tmp, exc_info=True)
@tool
def update_agent(
runtime: Runtime,
soul: str | None = None,
description: str | None = None,
skills: list[str] | None = None,
tool_groups: list[str] | None = None,
model: str | None = None,
) -> Command:
"""Persist updates to the current custom agent's SOUL.md and config.yaml.
Use this when the user asks to refine the agent's identity, description,
skill whitelist, tool-group whitelist, or default model. Only the fields
you explicitly pass are updated; omitted fields keep their existing values.
Pass ``soul`` as the FULL replacement SOUL.md content there is no patch
semantics, so always start from the current SOUL and apply your edits.
Pass ``skills=[]`` to disable all skills for this agent. Omit ``skills``
entirely to keep the existing whitelist.
Args:
soul: Optional full replacement SOUL.md content.
description: Optional new one-line description.
skills: Optional skill whitelist. ``[]`` = no skills, omit = unchanged.
tool_groups: Optional tool-group whitelist. ``[]`` = empty, omit = unchanged.
model: Optional model override (must match a configured model name).
Returns:
Command with a ToolMessage describing the result. Changes take effect
on the next user turn (when the lead agent is rebuilt with the fresh
SOUL.md and config.yaml).
"""
tool_call_id = runtime.tool_call_id
agent_name_raw: str | None = runtime.context.get("agent_name") if runtime.context else None
def _err(message: str) -> Command:
return Command(update={"messages": [ToolMessage(content=f"Error: {message}", tool_call_id=tool_call_id)]})
if soul is None and description is None and skills is None and tool_groups is None and model is None:
return _err("No fields provided. Pass at least one of: soul, description, skills, tool_groups, model.")
try:
agent_name = validate_agent_name(agent_name_raw)
except ValueError as e:
return _err(str(e))
if not agent_name:
return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.")
# Resolve the active user so that updates only affect this user's agent.
# ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context
# is set (matching how memory and thread storage behave).
user_id = get_effective_user_id()
# Reject an unknown ``model`` *before* touching the filesystem. Otherwise
# ``_resolve_model_name`` silently falls back to the default at runtime
# and the user sees confusing repeated warnings on every later turn.
if model is not None and get_app_config().get_model_config(model) is None:
return _err(f"Unknown model '{model}'. Pass a model name that exists in config.yaml's models section.")
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, agent_name)
if not agent_dir.exists() and paths.agent_dir(agent_name).exists():
return _err(f"Agent '{agent_name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating.")
try:
existing_cfg = load_agent_config(agent_name, user_id=user_id)
except FileNotFoundError:
return _err(f"Agent '{agent_name}' does not exist for the current user. Use setup_agent to create a new agent first.")
except ValueError as e:
return _err(f"Agent '{agent_name}' has an unreadable config: {e}")
if existing_cfg is None:
return _err(f"Agent '{agent_name}' could not be loaded.")
updated_fields: list[str] = []
# Force the on-disk ``name`` to match the directory we are writing into,
# even if ``existing_cfg.name`` had drifted (e.g. from manual yaml edits).
config_data: dict[str, Any] = {"name": agent_name}
new_description = description if description is not None else existing_cfg.description
config_data["description"] = new_description
if description is not None and description != existing_cfg.description:
updated_fields.append("description")
new_model = model if model is not None else existing_cfg.model
if new_model is not None:
config_data["model"] = new_model
if model is not None and model != existing_cfg.model:
updated_fields.append("model")
new_tool_groups = tool_groups if tool_groups is not None else existing_cfg.tool_groups
if new_tool_groups is not None:
config_data["tool_groups"] = new_tool_groups
if tool_groups is not None and tool_groups != existing_cfg.tool_groups:
updated_fields.append("tool_groups")
new_skills = skills if skills is not None else existing_cfg.skills
if new_skills is not None:
config_data["skills"] = new_skills
if skills is not None and skills != existing_cfg.skills:
updated_fields.append("skills")
config_changed = bool({"description", "model", "tool_groups", "skills"} & set(updated_fields))
# Stage every file we intend to rewrite into a temp sibling. Only after
# *all* temp files exist do we rename them into place — so a failure on
# SOUL.md cannot leave config.yaml already replaced.
pending: list[tuple[Path, Path]] = []
staged_temps: list[Path] = []
try:
agent_dir.mkdir(parents=True, exist_ok=True)
if config_changed:
yaml_text = yaml.dump(config_data, default_flow_style=False, allow_unicode=True, sort_keys=False)
config_target = agent_dir / "config.yaml"
config_tmp = _stage_temp(config_target, yaml_text)
staged_temps.append(config_tmp)
pending.append((config_tmp, config_target))
if soul is not None:
soul_target = agent_dir / "SOUL.md"
soul_tmp = _stage_temp(soul_target, soul)
staged_temps.append(soul_tmp)
pending.append((soul_tmp, soul_target))
updated_fields.append("soul")
# Commit phase. ``Path.replace`` is atomic per file on POSIX/NTFS and
# the staging step above means any earlier failure has already been
# reported. The remaining failure mode is a crash *between* two
# ``replace`` calls, which is reported via the partial-write error
# branch below so the caller knows which files are now on disk.
committed: list[Path] = []
try:
for tmp, target in pending:
tmp.replace(target)
committed.append(target)
except Exception as e:
_cleanup_temps([t for t, _ in pending if t not in committed])
if committed:
logger.error(
"[update_agent] Partial write for agent '%s' (user=%s): committed=%s, failed during rename: %s",
agent_name,
user_id,
[p.name for p in committed],
e,
exc_info=True,
)
return _err(f"Partial update for agent '{agent_name}': {[p.name for p in committed]} were updated, but the rest failed ({e}). Re-run update_agent to retry the remaining fields.")
raise
except Exception as e:
_cleanup_temps(staged_temps)
logger.error("[update_agent] Failed to update agent '%s' (user=%s): %s", agent_name, user_id, e, exc_info=True)
return _err(f"Failed to update agent '{agent_name}': {e}")
if not updated_fields:
return Command(update={"messages": [ToolMessage(content=f"No changes applied to agent '{agent_name}'. The provided values matched the existing config.", tool_call_id=tool_call_id)]})
logger.info("[update_agent] Updated agent '%s' (user=%s) fields: %s", agent_name, user_id, updated_fields)
return Command(
update={
"messages": [
ToolMessage(
content=(f"Agent '{agent_name}' updated successfully. Changed: {', '.join(updated_fields)}. The new configuration takes effect on the next user turn."),
tool_call_id=tool_call_id,
)
]
}
)
@@ -3,13 +3,13 @@ import mimetypes
from pathlib import Path
from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain.tools import InjectedToolCallId, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.tools.types import Runtime
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
f"{VIRTUAL_PATH_PREFIX}/workspace",
@@ -48,7 +48,7 @@ def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None)
@tool("view_image", parse_docstring=True)
def view_image_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
image_path: str,
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
@@ -7,16 +7,15 @@ import logging
from typing import Any
from weakref import WeakValueDictionary
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from langchain.tools import tool
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.agents.thread_state import ThreadState
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.skills.types import SKILL_MD_FILE
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
@@ -31,7 +30,7 @@ def _get_lock(name: str) -> asyncio.Lock:
return lock
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
def _get_thread_id(runtime: Runtime | None) -> str | None:
if runtime is None:
return None
if runtime.context and runtime.context.get("thread_id"):
@@ -65,7 +64,7 @@ async def _to_thread(func, /, *args, **kwargs):
async def _skill_manage_impl(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
action: str,
name: str,
content: str | None = None,
@@ -204,7 +203,7 @@ async def _skill_manage_impl(
@tool("skill_manage", parse_docstring=True)
async def skill_manage_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
action: str,
name: str,
content: str | None = None,
@@ -0,0 +1,11 @@
from typing import Any
from langchain.tools import ToolRuntime
from deerflow.agents.thread_state import ThreadState
# Concrete runtime type used by all DeerFlow tools.
# Using dict[str, Any] for the context parameter instead of the unbound ContextT
# TypeVar prevents PydanticSerializationUnexpectedValue warnings when LangChain
# calls model_dump() on a tool's auto-generated args_schema.
Runtime = ToolRuntime[dict[str, Any], ThreadState]
@@ -4,8 +4,10 @@ Pure business logic — no FastAPI/HTTP dependencies.
Both Gateway and Client delegate to these functions.
"""
import errno
import os
import re
import stat
from pathlib import Path
from urllib.parse import quote
@@ -17,6 +19,10 @@ class PathTraversalError(ValueError):
"""Raised when a path escapes its allowed base directory."""
class UnsafeUploadPathError(ValueError):
"""Raised when an upload destination is not a safe regular file path."""
# thread_id must be alphanumeric, hyphens, underscores, or dots only.
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$")
@@ -109,6 +115,108 @@ def validate_path_traversal(path: Path, base: Path) -> None:
raise PathTraversalError("Path traversal detected") from None
def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
"""Open an upload destination for safe streaming writes.
Upload directories may be mounted into local sandboxes. A sandbox process can
therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
follows that link and can overwrite files outside the uploads directory with
gateway privileges. This helper rejects symlink destinations using ``O_NOFOLLOW``
on POSIX. On Windows (which lacks ``O_NOFOLLOW``), it uses dual ``lstat`` checks
and ``fstat`` validation after ``open()`` to reduce the TOCTOU window; this does
not eliminate all races but makes exploitation significantly harder. Path-traversal
validation prevents escapes from *base_dir* in both cases.
"""
safe_name = normalize_filename(filename)
dest = base_dir / safe_name
try:
st = os.lstat(dest)
except FileNotFoundError:
st = None
if st is not None and not stat.S_ISREG(st.st_mode):
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
validate_path_traversal(dest, base_dir)
has_nofollow = hasattr(os, "O_NOFOLLOW")
if has_nofollow:
# POSIX: O_NOFOLLOW makes open() fail with ELOOP if dest is a symlink.
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
if hasattr(os, "O_NONBLOCK"):
flags |= os.O_NONBLOCK
try:
fd = os.open(dest, flags, 0o600)
except OSError as exc:
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
raise
try:
opened_stat = os.fstat(fd)
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
os.ftruncate(fd, 0)
fh = os.fdopen(fd, "wb")
fd = -1
finally:
if fd >= 0:
os.close(fd)
return dest, fh
# Windows: no O_NOFOLLOW available. Uses a second lstat immediately before open()
# to narrow the TOCTOU window, then fstat after open() as a further defence.
# Note: a narrow race window remains between the pre-open lstat and open(); the
# path-traversal check mitigates escapes from base_dir but cannot prevent an
# attacker who can atomically replace dest with a symlink after the check.
if st is not None and st.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
flags = os.O_WRONLY | os.O_CREAT
if hasattr(os, "O_BINARY"):
flags |= os.O_BINARY
try:
pre_open_st = os.lstat(dest)
except FileNotFoundError:
pre_open_st = None
if pre_open_st is not None and not stat.S_ISREG(pre_open_st.st_mode):
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
if pre_open_st is not None and pre_open_st.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
try:
fd = os.open(dest, flags, 0o600)
except OSError as exc:
if exc.errno in {errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
raise
try:
opened_stat = os.fstat(fd)
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
os.ftruncate(fd, 0)
fh = os.fdopen(fd, "wb")
fd = -1
finally:
if fd >= 0:
os.close(fd)
return dest, fh
def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
"""Write upload bytes without following a pre-existing destination symlink."""
dest, fh = open_upload_file_no_symlink(base_dir, filename)
with fh:
fh.write(data)
return dest
def list_files_in_dir(directory: Path) -> dict:
"""List files (not directories) in *directory*.
@@ -0,0 +1,75 @@
"""ISO 8601 timestamp helpers for the Gateway and embedded runtime.
DeerFlow stores and serializes thread/run timestamps as ISO 8601 UTC
strings to match the LangGraph Platform schema (see
``langgraph_sdk.schema.Thread``, where ``created_at`` / ``updated_at``
are ``datetime`` and JSON-encode to ISO 8601). All timestamp generation
should funnel through :func:`now_iso` so the wire format stays
consistent across endpoints, the embedded ``RunManager``, and the
checkpoint metadata written by the Gateway.
:func:`coerce_iso` provides a forward-compatible read path for legacy
records that historically stored ``str(time.time())`` floats.
"""
from __future__ import annotations
import re
from datetime import UTC, datetime
__all__ = ["coerce_iso", "now_iso"]
_UNIX_TIMESTAMP_PATTERN = re.compile(r"^\d{10}(?:\.\d+)?$")
"""Matches the unix-timestamp string shape historically written by
``str(time.time())`` (10-digit seconds with optional fractional part).
The 10-digit anchor avoids accidentally rewriting ISO years like
``"2026"`` and stays valid until the year 2286.
"""
def now_iso() -> str:
"""Return the current UTC time as an ISO 8601 string.
Example: ``"2026-04-27T03:19:46.511479+00:00"``.
"""
return datetime.now(UTC).isoformat()
def coerce_iso(value: object) -> str:
"""Best-effort coerce a stored timestamp to an ISO 8601 string.
Translates legacy unix-timestamp floats / strings written by older
DeerFlow versions into ISO without a one-shot migration. ISO strings
pass through unchanged; ``datetime`` instances are normalised to UTC
(tz-naive values are assumed to be UTC) and emitted via
``isoformat()`` so the wire format always uses the ``T`` separator;
empty values become ``""``; unrecognised values are stringified as a
last resort.
"""
if value is None or value == "":
return ""
if isinstance(value, bool):
# ``bool`` is a subclass of ``int`` — treat as garbage, not 0/1.
return str(value)
if isinstance(value, datetime):
# ``datetime`` must be handled before the ``int``/``float`` check;
# str(datetime) would produce ``"YYYY-MM-DD HH:MM:SS+00:00"``
# (space separator), which breaks strict ISO 8601 consumers.
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
else:
value = value.astimezone(UTC)
return value.isoformat()
if isinstance(value, (int, float)):
try:
return datetime.fromtimestamp(float(value), UTC).isoformat()
except (ValueError, OverflowError, OSError):
return str(value)
if isinstance(value, str):
if _UNIX_TIMESTAMP_PATTERN.match(value):
try:
return datetime.fromtimestamp(float(value), UTC).isoformat()
except (ValueError, OverflowError, OSError):
return value
return value
return str(value)
+1 -2
View File
@@ -8,7 +8,7 @@ dependencies = [
"deerflow-harness",
"fastapi>=0.115.0",
"httpx>=0.28.0",
"python-multipart>=0.0.26",
"python-multipart>=0.0.27",
"sse-starlette>=2.1.0",
"uvicorn[standard]>=0.34.0",
"lark-oapi>=1.4.0",
@@ -47,4 +47,3 @@ members = ["packages/harness"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
+86 -3
View File
@@ -1,7 +1,7 @@
"""One-time migration: move legacy thread dirs and memory into per-user layout.
Usage:
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run] [--user-id USER_ID]
The script is idempotent re-running it after a successful migration is a no-op.
"""
@@ -69,6 +69,67 @@ def migrate_thread_dirs(
return report
def migrate_agents(
paths: Paths,
user_id: str = "default",
*,
dry_run: bool = False,
) -> list[dict]:
"""Move legacy custom-agent directories into per-user layout.
Legacy layout: ``{base_dir}/agents/{name}/``
Per-user layout: ``{base_dir}/users/{user_id}/agents/{name}/``
Pre-existing per-user agents take precedence: if a destination already
exists for an agent name, the legacy copy is moved to
``{base_dir}/migration-conflicts/agents/{name}/`` for manual review.
Args:
paths: Paths instance.
user_id: Target user to receive the legacy agents (defaults to
``"default"``, matching ``DEFAULT_USER_ID`` for no-auth setups).
dry_run: If True, only log what would happen.
Returns:
List of migration report entries, one per legacy agent directory found.
"""
report: list[dict] = []
legacy_agents = paths.agents_dir
if not legacy_agents.exists():
logger.info("No legacy agents directory found — nothing to migrate.")
return report
for agent_dir in sorted(legacy_agents.iterdir()):
if not agent_dir.is_dir():
continue
agent_name = agent_dir.name
dest = paths.user_agent_dir(user_id, agent_name)
entry = {"agent": agent_name, "user_id": user_id, "action": ""}
if dest.exists():
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / agent_name
entry["action"] = f"conflict -> {conflicts_dir}"
if not dry_run:
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(agent_dir), str(conflicts_dir))
logger.warning("Conflict for agent %s: moved legacy copy to %s", agent_name, conflicts_dir)
else:
entry["action"] = f"moved -> {dest}"
if not dry_run:
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(agent_dir), str(dest))
logger.info("Migrated agent %s -> user %s", agent_name, user_id)
report.append(entry)
# Clean up empty legacy agents dir
if not dry_run and legacy_agents.exists() and not any(legacy_agents.iterdir()):
legacy_agents.rmdir()
return report
def migrate_memory(
paths: Paths,
user_id: str = "default",
@@ -127,6 +188,12 @@ def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
def main() -> None:
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
parser.add_argument(
"--user-id",
default="default",
metavar="USER_ID",
help=("User ID to claim un-owned legacy data (global memory.json and legacy custom agents). Defaults to 'default'. In multi-user installs, set this to the operator account that should inherit those legacy artifacts."),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -134,26 +201,42 @@ def main() -> None:
paths = get_paths()
logger.info("Base directory: %s", paths.base_dir)
logger.info("Dry run: %s", args.dry_run)
logger.info("Claiming un-owned legacy data for user_id=%s", args.user_id)
owner_map = _build_owner_map_from_db(paths)
logger.info("Found %d thread ownership records in DB", len(owner_map))
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
migrate_memory(paths, user_id=args.user_id, dry_run=args.dry_run)
agent_report = migrate_agents(paths, user_id=args.user_id, dry_run=args.dry_run)
if report:
logger.info("Migration report:")
logger.info("Thread migration report:")
for entry in report:
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
else:
logger.info("No threads to migrate.")
if agent_report:
logger.info("Agent migration report:")
for entry in agent_report:
logger.info(" agent=%s user=%s action=%s", entry["agent"], entry["user_id"], entry["action"])
else:
logger.info("No agents to migrate.")
unowned = [e for e in report if e["user_id"] == "default"]
if unowned:
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
for e in unowned:
logger.warning(" %s", e["thread_id"])
if agent_report:
logger.warning(
"%d legacy agent(s) were assigned to '%s'. If those agents belonged to other users, move them manually under {base_dir}/users/<user_id>/agents/.",
len(agent_report),
args.user_id,
)
if __name__ == "__main__":
main()
@@ -0,0 +1,210 @@
"""Tests for AioSandboxProvider auto-restart of crashed containers."""
import importlib
import threading
from unittest.mock import MagicMock, patch
def _import_provider():
return importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
def _make_provider(*, auto_restart=True, alive=True):
"""Build a minimal AioSandboxProvider with a mock backend.
Args:
auto_restart: Value for the auto_restart config key.
alive: Whether the mock backend reports containers as alive.
"""
mod = _import_provider()
with patch.object(mod.AioSandboxProvider, "_start_idle_checker"):
provider = mod.AioSandboxProvider.__new__(mod.AioSandboxProvider)
provider._config = {"auto_restart": auto_restart}
provider._lock = threading.Lock()
provider._sandboxes = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
provider._warm_pool = {}
provider._shutdown_called = False
provider._idle_checker_stop = threading.Event()
backend = MagicMock()
backend.is_alive.return_value = alive
provider._backend = backend
return provider, backend
def _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1"):
"""Insert a sandbox into the provider's caches as if it were acquired."""
sandbox = MagicMock()
info = MagicMock()
provider._sandboxes[sandbox_id] = sandbox
provider._sandbox_infos[sandbox_id] = info
provider._last_activity[sandbox_id] = 0.0
if thread_id:
provider._thread_sandboxes[thread_id] = sandbox_id
return sandbox, info
# ── get() returns sandbox when container is alive ──────────────────────────
def test_get_returns_sandbox_when_container_alive():
"""When auto_restart is on and the container is alive, get() returns the sandbox."""
provider, backend = _make_provider(auto_restart=True, alive=True)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_called_once()
def test_get_returns_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() skips the health check entirely."""
provider, backend = _make_provider(auto_restart=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_not_called()
# ── get() evicts dead sandbox when auto_restart is on ──────────────────────
def test_get_evicts_dead_sandbox_when_auto_restart_enabled():
"""When the container is dead and auto_restart is on, get() returns None and cleans caches."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_, info = _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1")
result = provider.get("dead-beef")
assert result is None
assert "dead-beef" not in provider._sandboxes
assert "dead-beef" not in provider._sandbox_infos
assert "dead-beef" not in provider._last_activity
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once_with(info)
def test_get_returns_dead_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() returns the cached sandbox even if the container is dead."""
provider, backend = _make_provider(auto_restart=False, alive=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
# Caches are untouched
assert "dead-beef" in provider._sandboxes
def test_get_eviction_cleans_multiple_thread_mappings():
"""A sandbox mapped to multiple thread IDs has all mappings cleaned on eviction."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="t-a")
# Manually add a second thread mapping to the same sandbox
provider._thread_sandboxes["t-b"] = "sid-1"
result = provider.get("sid-1")
assert result is None
assert "t-a" not in provider._thread_sandboxes
assert "t-b" not in provider._thread_sandboxes
# ── get() does not check health for unknown sandbox IDs ────────────────────
def test_get_returns_none_for_unknown_id():
"""If the sandbox_id is not in cache, get() returns None without checking health."""
provider, backend = _make_provider(auto_restart=True, alive=True)
result = provider.get("nonexistent")
assert result is None
backend.is_alive.assert_not_called()
# ── get() handles missing sandbox_info gracefully ──────────────────────────
def test_get_handles_missing_info_gracefully():
"""If sandbox is cached but info is missing, get() skips the health check."""
provider, backend = _make_provider(auto_restart=True, alive=False)
sandbox = MagicMock()
provider._sandboxes["sid-x"] = sandbox
provider._sandbox_infos.pop("sid-x", None) # Ensure no info
provider._last_activity["sid-x"] = 0.0
result = provider.get("sid-x")
# No info → cannot call is_alive → sandbox returned as-is
assert result is sandbox
backend.is_alive.assert_not_called()
def test_get_liveness_check_runs_outside_provider_lock():
"""get() should not hold the provider lock while checking backend liveness."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-locked", thread_id="thread-1")
def _assert_lock_not_held(_):
assert not provider._lock.locked()
return False
backend.is_alive.side_effect = _assert_lock_not_held
assert provider.get("sid-locked") is None
def test_get_still_evicts_when_backend_destroy_fails():
"""Cleanup errors should not keep stale sandbox state in memory."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-fail", thread_id="thread-1")
backend.destroy.side_effect = RuntimeError("boom")
assert provider.get("sid-fail") is None
assert "sid-fail" not in provider._sandboxes
assert "sid-fail" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once()
# ── Integration: eviction clears caches for recreation ─────────────────────
def test_eviction_clears_all_caches_for_recreation():
"""After eviction, all caches are clean so _acquire_internal can recreate.
This verifies the preconditions for transparent restart: when get() evicts
a dead sandbox, the next _acquire_internal call will find no cached entry,
no warm-pool entry, and fall through to _create_sandbox.
"""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="thread-1")
# Before eviction: caches populated
assert "sid-1" in provider._sandboxes
assert "sid-1" in provider._sandbox_infos
assert "thread-1" in provider._thread_sandboxes
# get() detects the dead container and evicts
assert provider.get("sid-1") is None
# After eviction: all caches clean
assert "sid-1" not in provider._sandboxes
assert "sid-1" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
assert "sid-1" not in provider._warm_pool
# _acquire_internal for the same thread would find nothing cached
# and generate the deterministic ID, then discover fails (container
# is gone), falling through to _create_sandbox — a fresh start.
+213 -1
View File
@@ -4,10 +4,40 @@ import json
import os
from pathlib import Path
import pytest
import yaml
from pydantic import ValidationError
from deerflow.config.agents_api_config import get_agents_api_config
import deerflow.config.app_config as app_config_module
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import get_agents_api_config, load_agents_api_config_from_dict
from deerflow.config.app_config import AppConfig, get_app_config, reset_app_config
from deerflow.config.checkpointer_config import get_checkpointer_config, load_checkpointer_config_from_dict
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict
from deerflow.config.memory_config import get_memory_config, load_memory_config_from_dict
from deerflow.config.stream_bridge_config import get_stream_bridge_config, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import get_subagents_app_config, load_subagents_config_from_dict
from deerflow.config.summarization_config import get_summarization_config, load_summarization_config_from_dict
from deerflow.config.title_config import get_title_config, load_title_config_from_dict
from deerflow.config.tool_search_config import get_tool_search_config, load_tool_search_config_from_dict
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.store import get_store, reset_store
def _reset_config_singletons() -> None:
load_title_config_from_dict({})
load_summarization_config_from_dict({})
load_memory_config_from_dict({})
load_agents_api_config_from_dict({})
load_subagents_config_from_dict({})
load_tool_search_config_from_dict({})
load_guardrails_config_from_dict({})
load_checkpointer_config_from_dict(None)
load_stream_bridge_config_from_dict(None)
load_acp_config_from_dict({})
reset_checkpointer()
reset_store()
reset_app_config()
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
@@ -53,6 +83,23 @@ def _write_config_with_agents_api(
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_config_with_sections(path: Path, sections: dict | None = None) -> None:
config = {
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [
{
"name": "first-model",
"use": "langchain_openai:ChatOpenAI",
"model": "gpt-test",
}
],
}
if sections:
config.update(sections)
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@@ -175,3 +222,168 @@ def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path,
assert get_agents_api_config().enabled is False
finally:
reset_app_config()
def test_get_app_config_resets_singleton_configs_when_sections_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False, "max_words": 3},
"summarization": {"enabled": True},
"memory": {"enabled": False, "max_facts": 50},
"subagents": {"timeout_seconds": 42, "agents": {"reviewer": {"max_turns": 2}}},
"tool_search": {"enabled": True},
"guardrails": {"enabled": True, "fail_closed": False},
"checkpointer": {"type": "memory"},
"stream_bridge": {"type": "memory", "queue_maxsize": 12},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
get_app_config()
assert get_title_config().enabled is False
assert get_summarization_config().enabled is True
assert get_memory_config().enabled is False
assert get_subagents_app_config().timeout_seconds == 42
assert get_tool_search_config().enabled is True
assert get_guardrails_config().enabled is True
assert get_checkpointer_config() is not None
assert get_stream_bridge_config() is not None
_write_config_with_sections(config_path)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_title_config().enabled is True
assert get_summarization_config().enabled is False
assert get_memory_config().enabled is True
assert get_subagents_app_config().timeout_seconds == 900
assert get_tool_search_config().enabled is False
assert get_guardrails_config().enabled is False
assert get_checkpointer_config() is None
assert get_stream_bridge_config() is None
finally:
_reset_config_singletons()
def test_get_app_config_resets_persistence_runtime_singletons_when_checkpointer_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(config_path, {"checkpointer": {"type": "memory"}})
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_checkpointer()
reset_store()
reset_app_config()
try:
get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(config_path)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_checkpointer_config() is None
assert get_checkpointer() is not initial_checkpointer
assert get_store() is not initial_store
finally:
_reset_config_singletons()
def test_get_app_config_keeps_persistence_runtime_singletons_when_checkpointer_unchanged(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False},
"checkpointer": {"type": "memory"},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
_reset_config_singletons()
try:
get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(
config_path,
{
"title": {"enabled": True},
"checkpointer": {"type": "memory"},
},
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_checkpointer() is initial_checkpointer
assert get_store() is initial_store
finally:
_reset_config_singletons()
def test_get_app_config_does_not_mutate_singletons_when_reload_validation_fails(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False},
"tool_search": {"enabled": True},
"checkpointer": {"type": "memory"},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
_reset_config_singletons()
try:
previous_app_config = get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(
config_path,
{
"title": False,
"tool_search": False,
"checkpointer": {"type": "memory"},
},
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
with pytest.raises(ValidationError):
get_app_config()
assert app_config_module._app_config is previous_app_config
assert get_title_config().enabled is False
assert get_tool_search_config().enabled is True
assert get_checkpointer_config() is not None
assert get_checkpointer() is initial_checkpointer
assert get_store() is initial_store
finally:
_reset_config_singletons()
+105 -1
View File
@@ -3,11 +3,12 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
from app.channels.base import Channel
from app.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage, ResolvedAttachment
def _run(coro):
@@ -248,6 +249,109 @@ class TestResolveAttachments:
assert result[0].filename == "data.csv"
# ---------------------------------------------------------------------------
# Inbound file ingestion tests
# ---------------------------------------------------------------------------
class TestInboundFileIngestion:
def test_rejects_preexisting_symlink_destination(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
outside_file = tmp_path / "outside-created.txt"
(uploads_dir / "victim.txt").symlink_to(outside_file)
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"attacker data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == []
assert not outside_file.exists()
assert (uploads_dir / "victim.txt").is_symlink()
def test_rejects_dangling_symlink_destination(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
missing_target = tmp_path / "missing-created.txt"
(uploads_dir / "victim.txt").symlink_to(missing_target)
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"attacker data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == []
assert not missing_target.exists()
assert (uploads_dir / "victim.txt").is_symlink()
def test_hardlinked_existing_file_is_not_overwritten(self, tmp_path):
from app.channels import manager
uploads_dir = tmp_path / "uploads"
uploads_dir.mkdir()
outside_file = tmp_path / "outside-created.txt"
outside_file.write_text("protected", encoding="utf-8")
os.link(outside_file, uploads_dir / "victim.txt")
msg = InboundMessage(
channel_name="test-channel",
chat_id="chat-1",
user_id="user-1",
text="see attachment",
files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
)
async def fake_reader(file_info, client):
return b"new attachment data"
with (
patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
):
result = _run(manager._ingest_inbound_files("thread-1", msg))
assert result == [
{
"filename": "victim_1.txt",
"size": len(b"new attachment data"),
"path": "/mnt/user-data/uploads/victim_1.txt",
"is_image": False,
}
]
assert outside_file.read_text(encoding="utf-8") == "protected"
assert (uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"
assert (uploads_dir / "victim_1.txt").read_bytes() == b"new attachment data"
# ---------------------------------------------------------------------------
# Channel base class _on_outbound with attachments
# ---------------------------------------------------------------------------
+199
View File
@@ -372,6 +372,37 @@ class TestExtractResponseText:
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
assert _extract_response_text(result) == ""
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
"""Loop-detection warning text on a tool-calling AI message is middleware-authored."""
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "search the repo"},
{
"type": "ai",
"content": "[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == ""
def test_preserves_visible_text_when_stripping_loop_warning(self):
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "prepare the report"},
{
"type": "ai",
"content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == "Here is the report."
# ---------------------------------------------------------------------------
# ChannelManager tests
@@ -435,6 +466,47 @@ class TestChannelManager:
assert headers["Cookie"] == f"csrf_token={csrf_token}"
assert headers["X-DeerFlow-Internal-Token"]
def test_fetch_gateway_includes_internal_auth_headers(self, monkeypatch):
from app.channels.manager import ChannelManager
class MockResponse:
def raise_for_status(self):
return None
def json(self):
return {"models": [{"name": "default"}]}
class MockAsyncClient:
def __init__(self, *args, **kwargs):
return None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, **kwargs):
calls.append({"url": url, **kwargs})
return MockResponse()
calls = []
monkeypatch.setattr("app.channels.manager.httpx.AsyncClient", MockAsyncClient)
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store, gateway_url="http://gateway:8001")
reply = await manager._fetch_gateway("/api/models", "models")
assert reply == "Available models:\n• default"
assert calls[0]["url"] == "http://gateway:8001/api/models"
assert calls[0]["timeout"] == 10
assert calls[0]["headers"]["X-DeerFlow-Internal-Token"]
_run(go())
def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
from app.channels.manager import ChannelManager
@@ -530,6 +602,8 @@ class TestChannelManager:
assert call_args[0][0] == "test-thread-123" # thread_id
assert call_args[0][1] == "lead_agent" # assistant_id
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert len(outbound_received) == 1
assert outbound_received[0].text == "Hello from agent!"
@@ -661,12 +735,135 @@ class TestChannelManager:
call_args = mock_client.runs.wait.call_args
assert call_args[0][1] == "lead_agent"
assert call_args[1]["config"]["recursion_limit"] == 55
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert call_args[1]["context"]["thinking_enabled"] is False
assert call_args[1]["context"]["subagent_enabled"] is True
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
_run(go())
def test_clarification_follow_up_preserves_history(self):
"""Conversation should continue after ask_clarification instead of resetting history."""
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received = []
async def capture_outbound(msg):
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
del assistant_id, context # unused in this test, kept for signature parity
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
key = (thread_id, str(checkpoint_ns))
history = history_by_checkpoint.setdefault(key, [])
human_text = input["messages"][0]["content"]
history.append(human_text)
if len(history) == 1:
return {
"messages": [
{"type": "human", "content": history[0]},
{
"type": "ai",
"content": "",
"tool_calls": [
{
"name": "ask_clarification",
"args": {"question": "Which environment should I use?"},
}
],
},
{
"type": "tool",
"name": "ask_clarification",
"content": "Which environment should I use?",
},
]
}
if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod":
return {
"messages": [
{"type": "human", "content": history[0]},
{
"type": "ai",
"content": "",
"tool_calls": [
{
"name": "ask_clarification",
"args": {"question": "Which environment should I use?"},
}
],
},
{
"type": "tool",
"name": "ask_clarification",
"content": "Which environment should I use?",
},
{"type": "human", "content": history[1]},
{"type": "ai", "content": "Got it. I will deploy to prod."},
]
}
return {
"messages": [
{"type": "human", "content": history[-1]},
{"type": "ai", "content": "History missing; clarification repeated."},
]
}
mock_client = MagicMock()
mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
mock_client.runs.wait = AsyncMock(side_effect=_runs_wait)
manager._client = mock_client
await manager.start()
await bus.publish_inbound(
InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="Deploy my app",
)
)
await _wait_for(lambda: len(outbound_received) >= 1)
await bus.publish_inbound(
InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="prod",
)
)
await _wait_for(lambda: len(outbound_received) >= 2)
await manager.stop()
assert outbound_received[0].text == "Which environment should I use?"
assert outbound_received[1].text == "Got it. I will deploy to prod."
assert mock_client.runs.wait.call_count == 2
first_call = mock_client.runs.wait.call_args_list[0]
second_call = mock_client.runs.wait.call_args_list[1]
assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
_run(go())
def test_handle_chat_uses_user_session_overrides(self):
from app.channels.manager import ChannelManager
@@ -1343,6 +1540,8 @@ class TestChannelManager:
call_args = mock_client.runs.stream.call_args
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert call_args[1]["context"]["is_bootstrap"] is True
# Final message should be published
+41 -1
View File
@@ -1,6 +1,8 @@
"""Unit tests for checkpointer config and singleton factory."""
"""Unit tests for checkpointer config, packaging metadata, and factories."""
import sys
import tomllib
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -13,6 +15,8 @@ from deerflow.config.checkpointer_config import (
set_checkpointer_config,
)
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
@pytest.fixture(autouse=True)
@@ -67,6 +71,42 @@ class TestCheckpointerConfig:
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
def test_connection_string_description_matches_runtime_defaults(self):
description = CheckpointerConfig.model_fields["connection_string"].description
assert description is not None
assert "Optional for sqlite" in description
assert "defaults to 'store.db'" in description
assert "Required for postgres" in description
class TestHarnessPackaging:
def test_pyproject_declares_postgres_extra(self):
pyproject_path = Path(__file__).resolve().parents[1] / "packages" / "harness" / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert "postgres" in optional_dependencies
assert optional_dependencies["postgres"] == [
"asyncpg>=0.29",
"langgraph-checkpoint-postgres>=3.0.5",
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
assert "deerflow-harness[postgres]" in POSTGRES_STORE_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_STORE_INSTALL
# ---------------------------------------------------------------------------
# Factory tests
+79
View File
@@ -437,6 +437,85 @@ class TestStream:
call_kwargs = agent.stream.call_args.kwargs
assert "messages" in call_kwargs["stream_mode"]
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
assert any(event.data.get("content") == "Hello!" for event in ai_events)
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
attribution = {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"reasoning_content": "Thinking first.",
"token_usage_attribution": attribution,
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
(
"messages",
(
AIMessageChunk(
content="Hello!",
id="ai-1",
additional_kwargs={"reasoning_content": "Thinking first."},
),
{},
),
),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
assert metadata_events[1].data["content"] == ""
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
def test_chat_accumulates_streamed_deltas(self, client):
"""chat() concatenates per-id deltas from messages mode."""
agent = MagicMock()
@@ -0,0 +1,53 @@
"""Tests for DeerFlowClient message serialization helpers."""
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.client import DeerFlowClient
def test_serialize_ai_message_preserves_additional_kwargs():
message = AIMessage(
content="done",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized["type"] == "ai"
assert serialized["usage_metadata"] == {
"input_tokens": 12,
"output_tokens": 3,
"total_tokens": 15,
}
assert serialized["additional_kwargs"] == {
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
}
def test_serialize_human_message_preserves_additional_kwargs():
message = HumanMessage(
content="hello",
additional_kwargs={"files": [{"name": "diagram.png"}]},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized == {
"type": "human",
"content": "hello",
"id": None,
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
}
+30
View File
@@ -82,6 +82,36 @@ def test_parse_response_text_content():
assert result.generations[0].message.content == "Hello world"
def test_parse_response_populates_usage_metadata():
model = _make_model()
response = {
"output": [
{
"type": "message",
"content": [{"type": "output_text", "text": "Hello world"}],
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
"input_tokens_details": {"cached_tokens": 3},
"output_tokens_details": {"reasoning_tokens": 2},
},
"model": "gpt-5.4",
}
result = model._parse_response(response)
meta = result.generations[0].message.usage_metadata
assert meta is not None
assert meta["input_tokens"] == 10
assert meta["output_tokens"] == 5
assert meta["total_tokens"] == 15
assert meta["input_token_details"]["cache_read"] == 3
assert meta["output_token_details"]["reasoning"] == 2
def test_parse_response_reasoning_content():
model = _make_model()
response = {
@@ -192,6 +192,7 @@ def test_agent_features_defaults():
assert f.vision is False
assert f.auto_title is False
assert f.guardrail is False
assert f.loop_detection is True
# ---------------------------------------------------------------------------
@@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent):
assert loop_idx == clar_idx - 1
# ---------------------------------------------------------------------------
# 30b. loop_detection=False skips LoopDetectionMiddleware
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_disabled(mock_create_agent):
mock_create_agent.return_value = MagicMock()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=False),
)
call_kwargs = mock_create_agent.call_args[1]
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
assert "LoopDetectionMiddleware" not in mw_types
# ---------------------------------------------------------------------------
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_custom_middleware(mock_create_agent):
from langchain.agents.middleware import AgentMiddleware as AM
mock_create_agent.return_value = MagicMock()
class MyLoopDetection(AM):
pass
custom = MyLoopDetection()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
)
call_kwargs = mock_create_agent.call_args[1]
middleware = call_kwargs["middleware"]
assert custom in middleware
mw_types = [type(m).__name__ for m in middleware]
# Default LoopDetectionMiddleware must not also appear.
assert "LoopDetectionMiddleware" not in mw_types
# Custom replacement still sits immediately before ClarificationMiddleware.
assert mw_types[-1] == "ClarificationMiddleware"
assert mw_types[-2] == "MyLoopDetection"
# ---------------------------------------------------------------------------
# 31. plan_mode=True adds TodoMiddleware
# ---------------------------------------------------------------------------
+2
View File
@@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
_clear_claude_code_env(monkeypatch)
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
monkeypatch.setenv("HOME", str(tmp_path))
cred_dir = tmp_path / "claude-creds-dir"
cred_dir.mkdir()
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))
+235
View File
@@ -0,0 +1,235 @@
"""Tests for CSRF middleware."""
from fastapi import FastAPI
from starlette.testclient import TestClient
from app.gateway.csrf_middleware import CSRFMiddleware
def _make_app() -> FastAPI:
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/auth/login/local")
async def login_local():
return {"ok": True}
@app.post("/api/v1/auth/register")
async def register():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def protected_mutation():
return {"ok": True}
return app
def test_auth_post_rejects_cross_origin_browser_request():
"""CSRF-exempt auth routes must not accept hostile browser origins.
Login/register endpoints intentionally skip the double-submit token because
first-time callers do not have a token yet. They still set an auth session,
so a hostile cross-site form POST must be rejected to avoid login CSRF /
session fixation.
"""
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_allows_same_origin_browser_request():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_rejects_malformed_origin_with_path():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example/path"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_rejects_malformed_origin_with_invalid_port():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:bad"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_allows_same_origin_default_port_equivalence():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:443"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "deerflow.example, internal:8000",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin_with_non_default_port():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "http://localhost:2026",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "localhost:2026",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_rfc_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"Forwarded": "proto=https;host=deerflow.example",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
assert "secure" in response.headers["set-cookie"].lower()
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/register",
headers={"Origin": "https://app.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_sets_strict_samesite_csrf_cookie():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
set_cookie = response.headers["set-cookie"].lower()
assert "csrf_token=" in set_cookie
assert "samesite=strict" in set_cookie
assert "secure" in set_cookie
def test_auth_post_without_origin_still_allows_non_browser_clients():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post("/api/v1/auth/login/local")
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_non_auth_mutation_still_requires_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/threads/abc/runs/stream",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
def test_non_auth_mutation_allows_valid_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "known-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "known-token",
},
)
assert response.status_code == 200
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "cookie-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "header-token",
},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch."
+16 -2
View File
@@ -537,7 +537,10 @@ class TestAgentsAPI:
def test_create_persists_files_on_disk(self, agent_client, tmp_path):
agent_client.post("/api/agents", json={"name": "disk-check", "soul": "disk soul"})
agent_dir = tmp_path / "agents" / "disk-check"
# tests/conftest.py installs an autouse fixture that sets the
# contextvar to "test-user-autouse", so the agent is persisted under
# users/test-user-autouse/agents/ rather than the legacy shared dir.
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "disk-check"
assert agent_dir.exists()
assert (agent_dir / "config.yaml").exists()
assert (agent_dir / "SOUL.md").exists()
@@ -545,12 +548,23 @@ class TestAgentsAPI:
def test_delete_removes_files_from_disk(self, agent_client, tmp_path):
agent_client.post("/api/agents", json={"name": "remove-me", "soul": "bye"})
agent_dir = tmp_path / "agents" / "remove-me"
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "remove-me"
assert agent_dir.exists()
agent_client.delete("/api/agents/remove-me")
assert not agent_dir.exists()
def test_create_rejects_legacy_name_collision(self, agent_client, tmp_path):
"""An unmigrated legacy agent must still block name collision so that
running the migration script later won't shadow the legacy entry."""
legacy_dir = tmp_path / "agents" / "legacy-agent"
legacy_dir.mkdir(parents=True)
(legacy_dir / "config.yaml").write_text("name: legacy-agent\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
response = agent_client.post("/api/agents", json={"name": "legacy-agent", "soul": "x"})
assert response.status_code == 409
# ===========================================================================
# 9. Gateway API User Profile endpoints
+201
View File
@@ -0,0 +1,201 @@
"""Unit tests for scripts/detect_uv_extras.py.
The detector resolves uv extras for `make dev` so that postgres (and any
future opt-in extras) are not wiped on every restart see Issue #2754.
"""
from __future__ import annotations
import importlib.util
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
DETECT_SCRIPT_PATH = REPO_ROOT / "scripts" / "detect_uv_extras.py"
spec = importlib.util.spec_from_file_location("deerflow_detect_uv_extras", DETECT_SCRIPT_PATH)
assert spec is not None and spec.loader is not None
detect = importlib.util.module_from_spec(spec)
spec.loader.exec_module(detect)
@pytest.fixture
def isolated_cwd(tmp_path, monkeypatch):
"""Isolate `find_config_file()` from the real repo by chdir + clearing env."""
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("UV_EXTRAS", raising=False)
monkeypatch.delenv("DEER_FLOW_CONFIG_PATH", raising=False)
return tmp_path
def test_parse_env_extras_supports_comma_and_whitespace():
assert detect.parse_env_extras("postgres") == ["postgres"]
assert detect.parse_env_extras("postgres,ollama") == ["postgres", "ollama"]
assert detect.parse_env_extras("postgres ollama") == ["postgres", "ollama"]
assert detect.parse_env_extras(" postgres , ollama ,") == ["postgres", "ollama"]
assert detect.parse_env_extras("") == []
def test_parse_env_extras_drops_shell_metacharacters(capsys):
"""A `.env` value containing shell injection bait must not pass through.
The whitelist guarantees the *bytes* that reach `uv sync` cannot include
shell metacharacters. Any name that looks identifier-like still survives
(uv itself will reject unknown extras with its own error), but `;`, `&`,
backticks, parentheses, slashes, etc. are stripped.
"""
# Pure-metacharacter inputs collapse to empty.
assert detect.parse_env_extras(";") == []
assert detect.parse_env_extras("$(whoami)") == []
assert detect.parse_env_extras("`echo bad`") == []
assert detect.parse_env_extras("postgres;evil") == [] # single token, contains `;`
# Splitting on whitespace yields ['rm'] which is identifier-shaped, but the
# destructive bits (`;`, `-rf`, `/`) are dropped.
assert detect.parse_env_extras("; rm -rf /") == ["rm"]
err = capsys.readouterr().err
assert "ignoring invalid UV_EXTRAS entry" in err
assert "';'" in err # confirms the dangerous token was reported and dropped
def test_parse_env_extras_rejects_leading_digits_and_punctuation():
"""Names must start with a letter — pyproject extras follow this shape."""
assert detect.parse_env_extras("1postgres") == []
assert detect.parse_env_extras("-postgres") == []
# Hyphens and underscores inside the name are fine.
assert detect.parse_env_extras("post_gres") == ["post_gres"]
assert detect.parse_env_extras("post-gres") == ["post-gres"]
def test_format_flags_emits_one_flag_per_extra():
assert detect.format_flags([]) == ""
assert detect.format_flags(["postgres"]) == "--extra postgres"
assert detect.format_flags(["postgres", "ollama"]) == "--extra postgres --extra ollama"
def test_strip_comment_preserves_quoted_hash():
assert detect._strip_comment("backend: postgres # trailing") == "backend: postgres"
assert detect._strip_comment('name: "value#with-hash"') == 'name: "value#with-hash"'
assert detect._strip_comment("# whole line comment") == ""
def test_section_value_finds_nested_key():
yaml_lines = [
"database:",
" backend: postgres",
" postgres_url: $DATABASE_URL",
"",
"checkpointer:",
" type: sqlite",
]
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
assert detect.section_value(yaml_lines, "checkpointer", "type") == "sqlite"
assert detect.section_value(yaml_lines, "database", "missing") is None
assert detect.section_value(yaml_lines, "absent_section", "anything") is None
def test_section_value_ignores_commented_lines():
yaml_lines = [
"# database:",
"# backend: postgres",
"database:",
" backend: sqlite",
]
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
def test_section_value_strips_quotes():
yaml_lines = [
"database:",
' backend: "postgres"',
]
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
def test_section_value_does_not_descend_into_grandchildren():
yaml_lines = [
"database:",
" backend: sqlite",
" nested:",
" backend: postgres",
]
# Only the immediate child level counts — keeps the parser predictable.
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
def test_detect_from_config_postgres_via_database(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("database:\n backend: postgres\n postgres_url: $DATABASE_URL\n")
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_postgres_via_checkpointer(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("checkpointer:\n type: postgres\n connection_string: postgresql://localhost/db\n")
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_sqlite_returns_no_extras(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("database:\n backend: sqlite\n sqlite_dir: .deer-flow/data\n")
assert detect.detect_from_config(cfg) == []
def test_detect_from_config_dedupes_when_both_present(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("checkpointer:\n type: postgres\ndatabase:\n backend: postgres\n")
# Sorted unique extras, no double-counting.
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_missing_file_returns_empty(tmp_path):
assert detect.detect_from_config(tmp_path / "does-not-exist.yaml") == []
def test_resolve_extras_env_overrides_config(isolated_cwd, monkeypatch):
cfg = isolated_cwd / "config.yaml"
cfg.write_text("database:\n backend: sqlite\n")
monkeypatch.setenv("UV_EXTRAS", "postgres")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_env_supports_multiple(isolated_cwd, monkeypatch):
monkeypatch.setenv("UV_EXTRAS", "postgres,ollama")
assert detect.resolve_extras() == ["postgres", "ollama"]
def test_resolve_extras_falls_back_to_config(isolated_cwd):
(isolated_cwd / "config.yaml").write_text("database:\n backend: postgres\n")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_respects_explicit_config_path(tmp_path, monkeypatch):
monkeypatch.delenv("UV_EXTRAS", raising=False)
elsewhere = tmp_path / "elsewhere.yaml"
elsewhere.write_text("database:\n backend: postgres\n")
monkeypatch.chdir(tmp_path)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(elsewhere))
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_no_config_no_env(isolated_cwd):
assert detect.resolve_extras() == []
def test_resolve_extras_finds_backend_subdir_config(isolated_cwd):
sub = isolated_cwd / "backend"
sub.mkdir()
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_root_config_takes_precedence(isolated_cwd):
(isolated_cwd / "config.yaml").write_text("database:\n backend: sqlite\n")
sub = isolated_cwd / "backend"
sub.mkdir()
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
# Root config.yaml is checked first, matching the precedence in serve.sh.
assert detect.resolve_extras() == []
+102
View File
@@ -0,0 +1,102 @@
"""Unit tests for docker/dev-entrypoint.sh (UV_EXTRAS validation + parsing).
Exercises the script via its `--print-extras` dry-run hook so we don't actually
launch uvicorn or hit /app/logs. Together with test_detect_uv_extras.py these
cover both the local make-dev path and the docker-compose-dev path with the
same shape see PR #2767 / Issue #2754.
"""
from __future__ import annotations
import os
import subprocess
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
ENTRYPOINT = REPO_ROOT / "docker" / "dev-entrypoint.sh"
def _run(uv_extras: str | None) -> subprocess.CompletedProcess[str]:
"""Invoke `dev-entrypoint.sh --print-extras` with UV_EXTRAS set."""
env = os.environ.copy()
env.pop("UV_EXTRAS", None)
if uv_extras is not None:
env["UV_EXTRAS"] = uv_extras
return subprocess.run(
["sh", str(ENTRYPOINT), "--print-extras"],
env=env,
capture_output=True,
text=True,
check=False,
)
def test_entrypoint_script_exists_and_is_posix_sh():
assert ENTRYPOINT.is_file()
# Catch syntax errors before runtime — `sh -n` is a parse-only check.
proc = subprocess.run(["sh", "-n", str(ENTRYPOINT)], capture_output=True, text=True, check=False)
assert proc.returncode == 0, proc.stderr
def test_no_uv_extras_yields_empty_flags():
proc = _run(None)
assert proc.returncode == 0
assert proc.stdout.strip() == ""
def test_single_extra():
proc = _run("postgres")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres"
def test_multi_extra_comma_separated():
proc = _run("postgres,ollama")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_multi_extra_whitespace_separated():
proc = _run("postgres ollama")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_multi_extra_mixed_separators():
proc = _run(" postgres , ollama ,")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_empty_string_yields_empty_flags():
proc = _run("")
assert proc.returncode == 0
assert proc.stdout.strip() == ""
@pytest.mark.parametrize(
"bad_value",
[
"; rm -rf /", # the canonical injection attempt
"$(whoami)", # command substitution
"`echo bad`", # backticks
"postgres;evil", # mixed legal+illegal in a single token
"1postgres", # leading digit
"-postgres", # leading hyphen
"post gres extra/path", # contains slash
],
)
def test_metacharacters_abort_with_nonzero_exit(bad_value):
proc = _run(bad_value)
assert proc.returncode != 0, f"expected abort for {bad_value!r}, got 0"
assert "is invalid" in proc.stderr
assert proc.stdout.strip() == ""
def test_underscores_and_hyphens_in_name_are_allowed():
"""Mirrors uv's accepted shape for `[project.optional-dependencies]` keys."""
proc = _run("post_gres,post-gres")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra post_gres --extra post-gres"
@@ -0,0 +1,336 @@
"""Tests for DynamicContextMiddleware.
Verifies that memory and current date are injected as a <system-reminder> into
the first HumanMessage exactly once per session (frozen-snapshot pattern).
"""
from types import SimpleNamespace
from unittest import mock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.dynamic_context_middleware import (
_DYNAMIC_CONTEXT_REMINDER_KEY,
DynamicContextMiddleware,
)
_SYSTEM_REMINDER_TAG = "<system-reminder>"
def _make_middleware(**kwargs) -> DynamicContextMiddleware:
return DynamicContextMiddleware(**kwargs)
def _fake_runtime():
return SimpleNamespace(context={})
def _reminder_msg(content: str, msg_id: str) -> HumanMessage:
"""Build a reminder HumanMessage the way the middleware would produce it."""
return HumanMessage(
content=content,
id=msg_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
# ---------------------------------------------------------------------------
# Basic injection
# ---------------------------------------------------------------------------
def test_injects_system_reminder_into_first_human_message():
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
updated_msgs = result["messages"]
assert len(updated_msgs) == 2
reminder_msg = updated_msgs[0]
assert isinstance(reminder_msg, HumanMessage)
assert reminder_msg.id == "msg-1" # takes the original ID (position swap)
assert reminder_msg.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in reminder_msg.content
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_msg.content
assert "Hello" not in reminder_msg.content # reminder only — no user text
user_msg = updated_msgs[1]
assert isinstance(user_msg, HumanMessage)
assert user_msg.id == "msg-1__user" # derived ID
assert user_msg.content == "Hello"
def test_memory_included_when_present():
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hi", id="msg-1")]}
with (
mock.patch(
"deerflow.agents.lead_agent.prompt._get_memory_context",
return_value="<memory>\nUser prefers Python.\n</memory>",
),
mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt,
):
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
# Reminder is the first returned message; user query is the second
reminder_content = result["messages"][0].content
assert "User prefers Python." in reminder_content
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_content
assert result["messages"][1].content == "Hi"
# ---------------------------------------------------------------------------
# Frozen-snapshot: no re-injection within a session
# ---------------------------------------------------------------------------
def test_skips_injection_if_already_present():
"""Second turn: separate reminder message already present → no update."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Hi there"),
HumanMessage(content="Follow-up", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is None # no update needed
def test_injects_only_into_first_human_message_not_later_ones():
"""Reminder targets the first HumanMessage; subsequent messages are not touched."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="First", id="msg-1"),
AIMessage(content="Reply"),
HumanMessage(content="Second", id="msg-2"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
# Only the two injected messages are returned (reminder + original first query)
assert len(msgs) == 2
assert msgs[0].id == "msg-1" # reminder takes first message's ID
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
assert msgs[1].id == "msg-1__user" # original content with derived ID
assert msgs[1].content == "First"
# "Second" (msg-2) is not in the returned update — it is left unchanged
assert all(m.id != "msg-2" for m in msgs)
def test_summary_human_message_is_not_used_as_injection_target():
"""After summarization, the synthetic summary HumanMessage is not a user turn."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="Here is a summary of the conversation to date:\n\n...", id="summary-1", name="summary"),
AIMessage(content="Earlier reply"),
HumanMessage(content="Follow-up", id="msg-2"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
assert msgs[0].id == "msg-2"
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert msgs[1].id == "msg-2__user"
assert msgs[1].content == "Follow-up"
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
def test_no_messages_returns_none():
mw = _make_middleware()
result = mw.before_agent({"messages": []}, _fake_runtime())
assert result is None
def test_no_human_message_returns_none():
mw = _make_middleware()
state = {"messages": [AIMessage(content="assistant only")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""):
result = mw.before_agent(state, _fake_runtime())
assert result is None
def test_list_content_message_handled_as_separate_reminder():
"""List-content (e.g. multi-modal) messages remain intact; reminder is a separate message."""
mw = _make_middleware()
original_content = [{"type": "text", "text": "Hello"}]
state = {"messages": [HumanMessage(content=original_content, id="msg-1")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
# Reminder is a plain string message with the flag set
assert isinstance(msgs[0].content, str)
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
# Original list-content message is untouched
assert msgs[1].content == original_content
def test_reminder_uses_original_id_user_message_uses_derived_id():
"""Reminder takes original ID (position swap); user message gets {id}__user."""
mw = _make_middleware()
original_id = "original-id-abc"
state = {"messages": [HumanMessage(content="Hello", id=original_id)]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result["messages"][0].id == original_id
assert result["messages"][1].id == f"{original_id}__user"
def test_message_without_id_gets_stable_uuid():
"""If the original HumanMessage has no ID, a UUID is generated and used consistently."""
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hello", id=None)]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
reminder_id = result["messages"][0].id
user_id = result["messages"][1].id
assert reminder_id is not None
assert reminder_id != "None"
assert user_id == f"{reminder_id}__user"
def test_user_message_containing_system_reminder_tag_does_not_prevent_injection():
"""A user message containing '<system-reminder>' must not be mistaken for a reminder."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="What is <system-reminder>?", id="msg-1"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
# Injection must happen — the user message does NOT carry the reminder flag
assert result is not None
assert result["messages"][0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
# ---------------------------------------------------------------------------
# Midnight crossing
# ---------------------------------------------------------------------------
def test_midnight_crossing_injects_date_update_as_separate_message():
"""When the date has changed, a separate date-update reminder is injected before
the current turn's HumanMessage using the ID-swap technique."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Response"),
HumanMessage(content="Good morning", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
# Date-update reminder takes the current message's ID
assert msgs[0].id == "msg-2"
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
assert "<current_date>2026-05-09, Saturday</current_date>" in msgs[0].content
assert "Good morning" not in msgs[0].content # reminder only
# Original user text appended with derived ID
assert msgs[1].id == "msg-2__user"
assert msgs[1].content == "Good morning"
def test_midnight_crossing_id_swap():
"""Date-update reminder uses original ID; user message uses {id}__user."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Next day message", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result["messages"][0].id == "msg-2"
assert result["messages"][1].id == "msg-2__user"
def test_no_second_midnight_injection_once_date_updated():
"""After a midnight update is persisted, the same-day path skips re-injection."""
mw = _make_middleware()
date_update_content = "<system-reminder>\n<current_date>2026-05-09, Saturday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(
"<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
"msg-1",
),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Response"),
_reminder_msg(date_update_content, "msg-2"),
HumanMessage(content="Good morning", id="msg-2__user"),
AIMessage(content="Good morning!"),
HumanMessage(content="Third turn", id="msg-3"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result is None # same day as last injected date → no update
@@ -50,7 +50,7 @@ def test_nginx_routes_official_langgraph_prefix_to_gateway_api():
assert "/api/langgraph-compat" not in content
assert "proxy_pass http://langgraph" not in content
assert "rewrite ^/api/langgraph/(.*) /api/$1 break;" in content
assert "proxy_pass http://gateway" in content
assert "proxy_pass http://gateway" in content or "proxy_pass http://$gateway_upstream" in content
def test_frontend_rewrites_langgraph_prefix_to_gateway():
+15
View File
@@ -324,6 +324,21 @@ def test_context_does_not_override_existing_configurable():
assert config["configurable"]["subagent_enabled"] is True
def test_inject_authenticated_user_context_overrides_client_user_id():
"""Run context should carry the authenticated user, not client-supplied user_id."""
from types import SimpleNamespace
from app.gateway.services import build_run_config, inject_authenticated_user_context
config = build_run_config("thread-1", None, None)
config["context"] = {"user_id": "spoofed-client"}
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="auth-user-42")))
inject_authenticated_user_context(config, request)
assert config["context"]["user_id"] == "auth-user-42"
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
@@ -8,17 +8,20 @@ from unittest.mock import MagicMock
import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
return AppConfig(
models=models,
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
loop_detection=loop_detection or LoopDetectionConfig(),
)
@@ -340,6 +343,59 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
assert middlewares[0] == "base-middleware"
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(
warn_threshold=7,
hard_limit=9,
window_size=30,
max_tracked_threads=40,
tool_freq_warn=50,
tool_freq_hard_limit=60,
),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
assert loop_detection.warn_threshold == 7
assert loop_detection.hard_limit == 9
assert loop_detection.window_size == 30
assert loop_detection.max_tracked_threads == 40
assert loop_detection.tool_freq_warn == 50
assert loop_detection.tool_freq_hard_limit == 60
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(enabled=False),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
+70 -3
View File
@@ -1,22 +1,37 @@
import threading
from types import SimpleNamespace
from typing import cast
import anyio
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
from deerflow.skills.types import Skill
from deerflow.skills.types import Skill, SkillCategory
def _set_skills_cache_state(*, skills=None, active=False, version=0):
prompt_module._get_cached_skills_prompt_section.cache_clear()
with prompt_module._enabled_skills_lock:
prompt_module._enabled_skills_cache = skills
prompt_module._enabled_skills_by_config_cache.clear()
prompt_module._enabled_skills_refresh_active = active
prompt_module._enabled_skills_refresh_version = version
prompt_module._enabled_skills_refresh_event.clear()
def test_build_self_update_section_empty_for_default_agent():
assert prompt_module._build_self_update_section(None) == ""
def test_build_self_update_section_present_for_custom_agent():
section = prompt_module._build_self_update_section("my-agent")
assert "<self_update>" in section
assert "my-agent" in section
assert "update_agent" in section
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
@@ -220,7 +235,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category="custom",
category=SkillCategory.CUSTOM,
enabled=True,
)
@@ -240,6 +255,58 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
_set_skills_cache_state()
def test_explicit_config_enabled_skills_are_cached_by_config_identity(monkeypatch, tmp_path):
def make_skill(name: str) -> Skill:
skill_dir = tmp_path / name
return Skill(
name=name,
description=f"Description for {name}",
license="MIT",
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category=SkillCategory.CUSTOM,
enabled=True,
)
config = cast(
AppConfig,
cast(
object,
SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
),
),
)
load_count = 0
def fake_get_or_new_skill_storage(**kwargs):
nonlocal load_count
assert kwargs == {"app_config": config}
def load_skills(*, enabled_only):
nonlocal load_count
load_count += 1
assert enabled_only is True
return [make_skill("cached-skill")]
return SimpleNamespace(load_skills=load_skills)
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fake_get_or_new_skill_storage)
_set_skills_cache_state()
try:
first = prompt_module.get_skills_prompt_section(app_config=config)
second = prompt_module.get_skills_prompt_section(app_config=config)
assert "cached-skill" in first
assert "cached-skill" in second
assert load_count == 1
finally:
_set_skills_cache_state()
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
started = threading.Event()
release = threading.Event()
@@ -257,7 +324,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category="custom",
category=SkillCategory.CUSTOM,
enabled=True,
)
+111 -1
View File
@@ -6,7 +6,12 @@ from deerflow.config.agents_config import AgentConfig
from deerflow.skills.types import Skill
def _make_skill(name: str) -> Skill:
class NamedTool:
def __init__(self, name: str):
self.name = name
def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
return Skill(
name=name,
description=f"Description for {name}",
@@ -15,6 +20,7 @@ def _make_skill(name: str) -> Skill:
skill_file=Path(f"/tmp/{name}/SKILL.md"),
relative_path=Path(name),
category="public",
allowed_tools=allowed_tools,
enabled=True,
)
@@ -132,6 +138,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
@@ -164,3 +171,106 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert captured_skills[-1] == {"skill1"}
def test_make_lead_agent_filters_tools_from_available_skills(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted", "legacy"]))
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("restricted", ["read_file"]), _make_skill("legacy", None)])
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
def test_make_lead_agent_all_legacy_skills_preserve_all_tools(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("legacy", None)])
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["bash", "read_file", "update_agent"]
def test_make_lead_agent_enforces_allowed_tools_when_skill_cache_is_cold(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.lead_agent import prompt as prompt_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
mock_storage = SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("restricted", ["read_file"])])
with prompt_module._enabled_skills_lock:
prompt_module._enabled_skills_cache = None
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None, **kwargs: mock_storage)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
def test_make_lead_agent_fails_closed_when_skill_policy_load_fails(monkeypatch):
from unittest.mock import MagicMock
import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.lead_agent import prompt as prompt_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
create_agent_mock = MagicMock()
monkeypatch.setattr(lead_agent_module, "create_agent", create_agent_mock)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
def fail_storage(*args, **kwargs):
raise RuntimeError("skill storage unavailable")
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fail_storage)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
with pytest.raises(RuntimeError, match="skill storage unavailable"):
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
create_agent_mock.assert_not_called()
@@ -105,6 +105,7 @@ def test_execute_command_uses_powershell_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": None,
},
)
]
@@ -118,6 +119,7 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
monkeypatch.setattr(local_sandbox.os, "name", "nt")
monkeypatch.setattr(local_sandbox.os, "environ", {"PATH": r"C:\Program Files\Git\bin"})
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\Program Files\Git\bin\sh.exe"))
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
@@ -132,11 +134,33 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": {
"PATH": r"C:\Program Files\Git\bin",
"MSYS_NO_PATHCONV": "1",
"MSYS2_ARG_CONV_EXCL": "*",
},
},
)
]
def test_execute_command_does_not_set_msys_env_for_non_msys_posix_shell_on_windows(monkeypatch):
calls: list[tuple[object, dict]] = []
def fake_run(*args, **kwargs):
calls.append((args[0], kwargs))
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
monkeypatch.setattr(local_sandbox.os, "name", "nt")
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\tools\busybox\sh.exe"))
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
output = LocalSandbox("t").execute_command("echo /mnt/skills/demo")
assert output == "ok"
assert calls[0][1]["env"] is None
def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
calls: list[tuple[object, dict]] = []
@@ -159,6 +183,7 @@ def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": None,
},
)
]
@@ -0,0 +1,72 @@
"""Tests for loop detection configuration."""
import pytest
from deerflow.config.loop_detection_config import LoopDetectionConfig
class TestLoopDetectionConfig:
def test_defaults_match_middleware_defaults(self):
config = LoopDetectionConfig()
assert config.enabled is True
assert config.warn_threshold == 3
assert config.hard_limit == 5
assert config.window_size == 20
assert config.max_tracked_threads == 100
assert config.tool_freq_warn == 30
assert config.tool_freq_hard_limit == 50
def test_accepts_custom_values(self):
config = LoopDetectionConfig(
enabled=False,
warn_threshold=10,
hard_limit=20,
window_size=50,
max_tracked_threads=200,
tool_freq_warn=60,
tool_freq_hard_limit=80,
)
assert config.enabled is False
assert config.warn_threshold == 10
assert config.hard_limit == 20
assert config.window_size == 50
assert config.max_tracked_threads == 200
assert config.tool_freq_warn == 60
assert config.tool_freq_hard_limit == 80
def test_rejects_zero_thresholds(self):
with pytest.raises(ValueError):
LoopDetectionConfig(warn_threshold=0)
with pytest.raises(ValueError):
LoopDetectionConfig(hard_limit=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_warn=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_hard_limit=0)
def test_rejects_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
def test_tool_freq_override_valid(self):
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
override = config.tool_freq_overrides["bash"]
assert override.warn == 150
assert override.hard_limit == 300
def test_tool_freq_override_rejects_zero_warn(self):
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})
+112 -4
View File
@@ -3,7 +3,7 @@
import copy
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, SystemMessage
from deerflow.agents.middlewares.loop_detection_middleware import (
_HARD_STOP_MSG,
@@ -146,14 +146,42 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third identical call triggers warning
# Third identical call triggers warning. The warning is appended to
# the AIMessage content (tool_calls preserved) — never inserted as a
# separate HumanMessage between the AIMessage(tool_calls) and its
# ToolMessage responses, which would break OpenAI/Moonshot strict
# tool-call pairing validation.
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], HumanMessage)
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
assert "LOOP DETECTED" in msgs[0].content
def test_warn_does_not_break_tool_call_pairing(self):
"""Regression: the warn branch must NOT inject a non-tool message
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
request with 'tool_call_ids did not have response messages' if any
non-tool message is wedged between the AIMessage and its ToolMessage
responses. See #2029.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
def test_warn_only_injected_once(self):
"""Warning for the same hash should only be injected once per thread."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
@@ -483,7 +511,11 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg, HumanMessage)
# Warning is appended to the AIMessage content; tool_calls preserved
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
# validation does not break.
assert isinstance(msg, AIMessage)
assert msg.tool_calls
assert "read_file" in msg.content
assert "LOOP DETECTED" in msg.content
@@ -616,6 +648,37 @@ class TestToolFrequencyDetection:
assert result is not None
assert "read_file" in result["messages"][0].content
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
mw = LoopDetectionMiddleware(
tool_freq_warn=5,
tool_freq_hard_limit=10,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
for i in range(10):
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None, f"unexpected trigger on call {i + 1}"
def test_non_override_tool_falls_back_to_global(self):
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
mw = LoopDetectionMiddleware(
tool_freq_warn=3,
tool_freq_hard_limit=6,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls."""
mw = LoopDetectionMiddleware(
@@ -636,3 +699,48 @@ class TestToolFrequencyDetection:
msg = result["messages"][0]
assert isinstance(msg, AIMessage)
assert _HARD_STOP_MSG in msg.content
class TestFromConfig:
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
@staticmethod
def _config(**kwargs):
from deerflow.config.loop_detection_config import LoopDetectionConfig
return LoopDetectionConfig(**kwargs)
def test_scalar_fields_mapped(self):
config = self._config(
warn_threshold=4,
hard_limit=8,
window_size=15,
max_tracked_threads=50,
tool_freq_warn=20,
tool_freq_hard_limit=40,
)
mw = LoopDetectionMiddleware.from_config(config)
assert mw.warn_threshold == 4
assert mw.hard_limit == 8
assert mw.window_size == 15
assert mw.max_tracked_threads == 50
assert mw.tool_freq_warn == 20
assert mw.tool_freq_hard_limit == 40
def test_overrides_converted_to_tuples(self):
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
mw = LoopDetectionMiddleware.from_config(config)
assert mw._tool_freq_overrides == {"bash": (50, 100)}
def test_empty_overrides(self):
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
@@ -125,3 +125,68 @@ class TestMigrateMemory:
from scripts.migrate_user_isolation import migrate_memory
migrate_memory(paths, user_id="default") # should not raise
class TestMigrateAgents:
@staticmethod
def _seed_legacy_agent(paths: Paths, name: str, *, soul: str = "soul", description: str = "d") -> Path:
legacy_dir = paths.agents_dir / name
legacy_dir.mkdir(parents=True, exist_ok=True)
(legacy_dir / "config.yaml").write_text(f"name: {name}\ndescription: {description}\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text(soul, encoding="utf-8")
return legacy_dir
def test_moves_legacy_into_user_layout(self, base_dir: Path, paths: Paths):
self._seed_legacy_agent(paths, "agent-a", soul="soul-a")
self._seed_legacy_agent(paths, "agent-b", soul="soul-b")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert {entry["agent"] for entry in report} == {"agent-a", "agent-b"}
for entry in report:
assert entry["user_id"] == "default"
assert "moved -> " in entry["action"]
for name, soul in [("agent-a", "soul-a"), ("agent-b", "soul-b")]:
dest = paths.user_agent_dir("default", name)
assert dest.exists(), f"{name} should have moved into the per-user layout"
assert (dest / "SOUL.md").read_text() == soul
# Legacy agents/ root is cleaned up once empty.
assert not paths.agents_dir.exists()
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
legacy_dir = self._seed_legacy_agent(paths, "agent-a")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default", dry_run=True)
assert len(report) == 1
assert legacy_dir.exists(), "dry-run must not touch the filesystem"
assert not paths.user_agent_dir("default", "agent-a").exists()
def test_existing_destination_is_treated_as_conflict(self, base_dir: Path, paths: Paths):
self._seed_legacy_agent(paths, "agent-a", soul="legacy soul")
dest = paths.user_agent_dir("default", "agent-a")
dest.mkdir(parents=True)
(dest / "SOUL.md").write_text("preexisting", encoding="utf-8")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert report[0]["action"].startswith("conflict -> ")
# Per-user destination must be left untouched.
assert (dest / "SOUL.md").read_text() == "preexisting"
# Legacy copy lands under migration-conflicts/agents/.
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / "agent-a"
assert (conflicts_dir / "SOUL.md").read_text() == "legacy soul"
def test_no_legacy_dir_is_noop(self, base_dir: Path, paths: Paths):
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert report == []
@@ -50,6 +50,21 @@ class TestUserAgentMemoryFile:
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
class TestUserAgentDir:
def test_user_agents_dir(self, paths: Paths):
assert paths.user_agents_dir("alice") == paths.base_dir / "users" / "alice" / "agents"
def test_user_agent_dir(self, paths: Paths):
assert paths.user_agent_dir("alice", "code-reviewer") == paths.base_dir / "users" / "alice" / "agents" / "code-reviewer"
def test_user_agent_dir_lowercases_name(self, paths: Paths):
assert paths.user_agent_dir("alice", "CodeReviewer") == paths.base_dir / "users" / "alice" / "agents" / "codereviewer"
def test_user_agent_dir_validates_user_id(self, paths: Paths):
with pytest.raises(ValueError, match="Invalid user_id"):
paths.user_agent_dir("../escape", "myagent")
class TestUserThreadDir:
def test_user_thread_dir(self, paths: Paths):
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
+6 -9
View File
@@ -8,7 +8,9 @@ Tests:
5. Postgres missing-dep error message
"""
import sys
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
@@ -221,13 +223,8 @@ class TestEngineLifecycle:
"""If asyncpg is not installed, error message tells user what to do."""
from deerflow.persistence.engine import init_engine
try:
import asyncpg # noqa: F401
pytest.skip("asyncpg is installed -- cannot test missing-dep path")
except ImportError:
# asyncpg is not installed — this is the expected state for this test.
# We proceed to verify that init_engine raises an actionable ImportError.
pass # noqa: S110 — intentionally ignored
with pytest.raises(ImportError, match="uv sync --extra postgres"):
with (
patch.dict(sys.modules, {"asyncpg": None}),
pytest.raises(ImportError, match="uv sync --all-packages --extra postgres"),
):
await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
@@ -0,0 +1,293 @@
from __future__ import annotations
import pytest
import requests
from deerflow.community.aio_sandbox.remote_backend import RemoteSandboxBackend
from deerflow.community.aio_sandbox.sandbox_info import SandboxInfo
class _StubResponse:
def __init__(
self,
*,
status_code: int = 200,
payload: object | None = None,
json_exc: Exception | None = None,
):
self.status_code = status_code
self._payload = {} if payload is None else payload
self._json_exc = json_exc
self.ok = 200 <= status_code < 400
self.text = ""
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise requests.HTTPError(f"HTTP {self.status_code}")
def json(self) -> object:
if self._json_exc is not None:
raise self._json_exc
return self._payload
def test_list_running_delegates_to_provisioner_list(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
sandbox_info = SandboxInfo(sandbox_id="test-id", sandbox_url="http://localhost:8080")
def mock_list():
return [sandbox_info]
monkeypatch.setattr(backend, "_provisioner_list", mock_list)
assert backend.list_running() == [sandbox_info]
def test_provisioner_list_returns_sandbox_infos_and_filters_invalid_entries(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert timeout == 10
return _StubResponse(
payload={
"sandboxes": [
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
{"sandbox_id": "missing-url"},
{"sandbox_url": "http://k3s:31002"},
]
}
)
monkeypatch.setattr(requests, "get", mock_get)
infos = backend._provisioner_list()
assert len(infos) == 1
assert infos[0].sandbox_id == "abc123"
assert infos[0].sandbox_url == "http://k3s:31001"
def test_provisioner_list_returns_empty_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("network down")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_returns_empty_when_payload_is_not_dict(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload=[{"sandbox_id": "abc", "sandbox_url": "http://k3s:31001"}])
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_returns_empty_when_sandboxes_is_not_list(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload={"sandboxes": {"sandbox_id": "abc"}})
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_skips_non_dict_sandbox_entries(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(
payload={
"sandboxes": [
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
"bad-entry",
123,
None,
]
}
)
monkeypatch.setattr(requests, "get", mock_get)
infos = backend._provisioner_list()
assert len(infos) == 1
assert infos[0].sandbox_id == "abc123"
assert infos[0].sandbox_url == "http://k3s:31001"
def test_create_delegates_to_provisioner_create(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
def mock_create(thread_id: str, sandbox_id: str, extra_mounts=None):
assert thread_id == "thread-1"
assert sandbox_id == "abc123"
assert extra_mounts == [("/host", "/container", False)]
return expected
monkeypatch.setattr(backend, "_provisioner_create", mock_create)
result = backend.create("thread-1", "abc123", extra_mounts=[("/host", "/container", False)])
assert result == expected
def test_provisioner_create_returns_sandbox_info(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert json == {"sandbox_id": "abc123", "thread_id": "thread-1"}
assert timeout == 30
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
monkeypatch.setattr(requests, "post", mock_post)
info = backend._provisioner_create("thread-1", "abc123")
assert info.sandbox_id == "abc123"
assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "post", mock_post)
with pytest.raises(RuntimeError, match="Provisioner create failed"):
backend._provisioner_create("thread-1", "abc123")
def test_destroy_delegates_to_provisioner_destroy(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
called: list[str] = []
def mock_destroy(sandbox_id: str):
called.append(sandbox_id)
monkeypatch.setattr(backend, "_provisioner_destroy", mock_destroy)
backend.destroy(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
assert called == ["abc123"]
def test_provisioner_destroy_calls_delete(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_delete(url: str, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes/abc123"
assert timeout == 15
return _StubResponse(status_code=200)
monkeypatch.setattr(requests, "delete", mock_delete)
backend._provisioner_destroy("abc123")
def test_provisioner_destroy_swallows_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_delete(url: str, timeout: int):
raise requests.RequestException("network down")
monkeypatch.setattr(requests, "delete", mock_delete)
backend._provisioner_destroy("abc123")
def test_is_alive_delegates_to_provisioner_is_alive(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_is_alive(sandbox_id: str):
assert sandbox_id == "abc123"
return True
monkeypatch.setattr(backend, "_provisioner_is_alive", mock_is_alive)
alive = backend.is_alive(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
assert alive is True
def test_provisioner_is_alive_true_only_when_status_running(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get_running(url: str, timeout: int):
return _StubResponse(payload={"status": "Running"})
monkeypatch.setattr(requests, "get", mock_get_running)
assert backend._provisioner_is_alive("abc123") is True
def mock_get_pending(url: str, timeout: int):
return _StubResponse(payload={"status": "Pending"})
monkeypatch.setattr(requests, "get", mock_get_pending)
assert backend._provisioner_is_alive("abc123") is False
def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_is_alive("abc123") is False
def test_discover_delegates_to_provisioner_discover(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
def mock_discover(sandbox_id: str):
assert sandbox_id == "abc123"
return expected
monkeypatch.setattr(backend, "_provisioner_discover", mock_discover)
result = backend.discover("abc123")
assert result == expected
def test_provisioner_discover_returns_none_on_404(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(status_code=404)
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_discover("abc123") is None
def test_provisioner_discover_returns_info_on_success(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
monkeypatch.setattr(requests, "get", mock_get)
info = backend._provisioner_discover("abc123")
assert info is not None
assert info.sandbox_id == "abc123"
assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_discover_returns_none_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_discover("abc123") is None
+71
View File
@@ -310,6 +310,28 @@ class TestDbRunEventStore:
await close_engine()
@pytest.mark.anyio
async def test_structured_content_round_trips(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.test/a.png"}}]
record = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=content)
assert record["content"] == content
assert record["metadata"]["content_is_json"] is True
assert "content_is_dict" not in record["metadata"]
messages = await s.list_messages("t1")
assert messages[0]["content"] == content
assert messages[0]["metadata"]["content_is_json"] is True
await close_engine()
@pytest.mark.anyio
async def test_pagination(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
@@ -373,6 +395,55 @@ class TestDbRunEventStore:
assert seqs == list(range(1, 51))
await close_engine()
@pytest.mark.anyio
async def test_put_batch_accepts_structured_content(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = [{"messages": [{"type": "ai", "content": ""}]}]
results = await s.put_batch(
[
{
"thread_id": "t1",
"run_id": "r1",
"event_type": "run.end",
"category": "outputs",
"content": content,
}
]
)
assert results[0]["content"] == content
assert results[0]["metadata"]["content_is_json"] is True
events = await s.list_events("t1", "r1")
assert events[0]["content"] == content
assert events[0]["metadata"]["content_is_json"] is True
await close_engine()
@pytest.mark.anyio
async def test_dict_content_keeps_legacy_metadata_flag(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = {"status": "success"}
record = await s.put(thread_id="t1", run_id="r1", event_type="run.end", category="outputs", content=content)
assert record["content"] == content
assert record["metadata"]["content_is_json"] is True
assert record["metadata"]["content_is_dict"] is True
await close_engine()
# -- Factory tests --
+55
View File
@@ -166,6 +166,61 @@ class TestRunRepository:
assert row["total_tokens"] == 100
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion(
"success-run",
status="success",
total_input_tokens=70,
total_output_tokens=30,
total_tokens=100,
lead_agent_tokens=80,
subagent_tokens=15,
middleware_tokens=5,
)
await repo.put("error-run", thread_id="t1", status="running")
await repo.update_run_completion(
"error-run",
status="error",
total_input_tokens=20,
total_output_tokens=30,
total_tokens=50,
lead_agent_tokens=40,
subagent_tokens=10,
)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_completion(
"running-run",
status="running",
total_input_tokens=900,
total_output_tokens=99,
total_tokens=999,
lead_agent_tokens=999,
)
await repo.put("other-thread-run", thread_id="t2", status="running")
await repo.update_run_completion(
"other-thread-run",
status="success",
total_tokens=888,
lead_agent_tokens=888,
)
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg["total_tokens"] == 150
assert agg["total_input_tokens"] == 90
assert agg["total_output_tokens"] == 60
assert agg["total_runs"] == 2
assert agg["by_model"] == {"unknown": {"tokens": 150, "runs": 2}}
assert agg["by_caller"] == {
"lead_agent": 120,
"subagent": 25,
"middleware": 5,
}
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first."""
+61 -16
View File
@@ -3,6 +3,8 @@ from types import SimpleNamespace
from unittest.mock import AsyncMock, call
import pytest
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver
from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus
@@ -16,6 +18,14 @@ class FakeCheckpointer:
self.aput_writes = AsyncMock()
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
checkpoint = empty_checkpoint()
checkpoint["id"] = checkpoint_id
checkpoint["channel_values"] = {"messages": messages}
checkpoint["channel_versions"] = {"messages": version}
return checkpoint
def test_build_runtime_context_includes_app_config_when_present():
app_config = object()
@@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
)
checkpointer.adelete_thread.assert_not_awaited()
checkpointer.aput.assert_awaited_once_with(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
{
"id": "ckpt-1",
"channel_versions": {"messages": 3},
"channel_values": {"messages": ["before"]},
},
{"source": "input"},
{"messages": 3},
)
checkpointer.aput.assert_awaited_once()
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
assert restored_checkpoint["id"] != "ckpt-1"
assert "channel_versions" in restored_checkpoint
assert "channel_values" in restored_checkpoint
assert restored_checkpoint["channel_versions"] == {"messages": 3}
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
assert restored_metadata == {"source": "input"}
assert new_versions == {"messages": 3}
assert checkpointer.aput_writes.await_args_list == [
call(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
@@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
]
@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
checkpointer = InMemorySaver()
thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
before_checkpoint = _make_checkpoint("0001", ["before"], 1)
before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
after_checkpoint = _make_checkpoint("0002", ["after"], 2)
after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")
await _rollback_to_pre_run_checkpoint(
checkpointer=checkpointer,
thread_id="thread-1",
run_id="run-1",
pre_run_checkpoint_id="0001",
pre_run_snapshot={
"checkpoint_ns": "",
"checkpoint": before_checkpoint,
"metadata": {"step": 1},
"pending_writes": [("task-before", "messages", "pending-before")],
},
snapshot_capture_failed=False,
)
latest = checkpointer.get_tuple(thread_config)
assert latest is not None
assert latest.config["configurable"]["checkpoint_id"] != "0001"
assert latest.config["configurable"]["checkpoint_id"] != "0002"
assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
assert latest.pending_writes == [("task-before", "messages", "pending-before")]
assert ("task-after", "messages", "pending-after") not in latest.pending_writes
@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
checkpointer = FakeCheckpointer(put_result=None)
@@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
snapshot_capture_failed=False,
)
checkpointer.aput.assert_awaited_once_with(
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
{"id": "ckpt-1", "channel_versions": {}},
{},
{},
)
checkpointer.aput.assert_awaited_once()
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
assert restored_checkpoint["id"] != "ckpt-1"
assert restored_checkpoint["channel_versions"] == {}
assert restored_metadata == {}
assert new_versions == {}
@pytest.mark.anyio
+36
View File
@@ -7,6 +7,7 @@ import yaml
from deerflow.config import app_config as app_config_module
from deerflow.config import extensions_config as extensions_config_module
from deerflow.config import skills_config as skills_config_module
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.paths import Paths
@@ -35,6 +36,7 @@ def test_default_runtime_paths_resolve_from_current_project(tmp_path: Path, monk
encoding="utf-8",
)
(tmp_path / "extensions_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")
(tmp_path / "skills").mkdir()
assert AppConfig.resolve_config_path() == tmp_path / "config.yaml"
assert ExtensionsConfig.resolve_config_path() == tmp_path / "extensions_config.json"
@@ -121,6 +123,40 @@ def test_app_config_falls_back_to_legacy_when_project_root_lacks_config(tmp_path
assert AppConfig.resolve_config_path() == legacy_backend_config
def test_skills_config_falls_back_to_legacy_when_project_root_lacks_skills(tmp_path: Path, monkeypatch):
"""When DEER_FLOW_PROJECT_ROOT is unset and cwd has no `skills/`, the legacy
repo-root candidate must be used so monorepo runs (cwd=backend/) keep finding
`<repo>/skills` instead of `<repo>/backend/skills` (regression test for #2694)."""
_clear_path_env(monkeypatch)
cwd = tmp_path / "cwd"
cwd.mkdir()
monkeypatch.chdir(cwd)
legacy_skills = tmp_path / "legacy-repo" / "skills"
legacy_skills.mkdir(parents=True)
monkeypatch.setattr(
skills_config_module,
"_legacy_skills_candidates",
lambda: (legacy_skills,),
)
assert SkillsConfig().get_skills_path() == legacy_skills
def test_skills_config_returns_project_default_when_neither_exists(tmp_path: Path, monkeypatch):
"""When nothing exists, fall back to the project-root default path so callers
surface a stable empty location instead of silently picking a stale legacy dir."""
_clear_path_env(monkeypatch)
cwd = tmp_path / "cwd"
cwd.mkdir()
monkeypatch.chdir(cwd)
monkeypatch.setattr(skills_config_module, "_legacy_skills_candidates", lambda: ())
assert SkillsConfig().get_skills_path() == cwd / "skills"
def test_extensions_config_falls_back_to_legacy_when_project_root_lacks_file(tmp_path: Path, monkeypatch):
"""ExtensionsConfig should hit the legacy backend/repo-root locations when
the caller project root has no extensions_config.json/mcp_config.json."""
+308
View File
@@ -0,0 +1,308 @@
"""Unit tests for the Serper community web search tool."""
import json
from unittest.mock import MagicMock, patch
import httpx
import pytest
@pytest.fixture(autouse=True)
def reset_api_key_warned():
"""Reset the module-level warning flag before each test."""
import deerflow.community.serper.tools as serper_mod
serper_mod._api_key_warned = False
yield
serper_mod._api_key_warned = False
@pytest.fixture
def mock_config_with_key():
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "test-serper-key", "max_results": 5}
mock.return_value.get_tool_config.return_value = tool_config
yield mock
@pytest.fixture
def mock_config_no_key():
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {}
mock.return_value.get_tool_config.return_value = tool_config
yield mock
def _make_serper_response(organic: list) -> MagicMock:
mock_resp = MagicMock()
mock_resp.json.return_value = {"organic": organic}
mock_resp.raise_for_status = MagicMock()
return mock_resp
class TestGetApiKey:
def test_returns_config_key_when_present(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "from-config"}
mock.return_value.get_tool_config.return_value = tool_config
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "from-config"
def test_falls_back_to_env_when_config_key_empty(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": ""}
mock.return_value.get_tool_config.return_value = tool_config
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-key"
def test_falls_back_to_env_when_config_key_whitespace(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": " "}
mock.return_value.get_tool_config.return_value = tool_config
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-key"
def test_falls_back_to_env_when_config_key_null(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": None}
mock.return_value.get_tool_config.return_value = tool_config
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-key"
def test_falls_back_to_env_when_no_config(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
mock.return_value.get_tool_config.return_value = None
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only"}):
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() == "env-only"
def test_returns_none_when_no_key_anywhere(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
mock.return_value.get_tool_config.return_value = None
with patch.dict("os.environ", {}, clear=True):
import os
os.environ.pop("SERPER_API_KEY", None)
from deerflow.community.serper.tools import _get_api_key
assert _get_api_key() is None
class TestWebSearchTool:
def test_basic_search_returns_normalized_results(self, mock_config_with_key):
organic = [
{"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
{"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "python tutorial"})
parsed = json.loads(result)
assert parsed["query"] == "python tutorial"
assert parsed["total_results"] == 2
assert parsed["results"][0]["title"] == "Result 1"
assert parsed["results"][0]["url"] == "https://example.com/1"
assert parsed["results"][0]["content"] == "Snippet 1"
def test_respects_max_results_from_config(self, mock_config_with_key):
mock_config_with_key.return_value.get_tool_config.return_value.model_extra = {
"api_key": "test-key",
"max_results": 3,
}
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert parsed["total_results"] == 3
assert len(parsed["results"]) == 3
def test_max_results_parameter_accepted(self, mock_config_no_key):
"""Tool accepts max_results as a call parameter when config does not override it."""
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
mock_resp = _make_serper_response(organic)
with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test", "max_results": 2})
parsed = json.loads(result)
assert parsed["total_results"] == 2
def test_config_max_results_overrides_parameter(self):
"""Config max_results overrides the parameter passed at call time, matching ddg_search behaviour."""
with patch("deerflow.community.serper.tools.get_app_config") as mock:
tool_config = MagicMock()
tool_config.model_extra = {"api_key": "test-key", "max_results": 3}
mock.return_value.get_tool_config.return_value = tool_config
organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test", "max_results": 8})
parsed = json.loads(result)
assert parsed["total_results"] == 3
def test_empty_organic_returns_error_json(self, mock_config_with_key):
"""Empty organic list returns structured error, matching ddg_search convention."""
mock_resp = _make_serper_response([])
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "no results"})
parsed = json.loads(result)
assert "error" in parsed
assert parsed["error"] == "No results found"
assert parsed["query"] == "no results"
def test_missing_api_key_returns_error_json(self, mock_config_no_key):
with patch.dict("os.environ", {}, clear=True):
import os
os.environ.pop("SERPER_API_KEY", None)
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert "error" in parsed
assert "SERPER_API_KEY" in parsed["error"]
def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog):
import logging
with patch.dict("os.environ", {}, clear=True):
import os
os.environ.pop("SERPER_API_KEY", None)
from deerflow.community.serper.tools import web_search_tool
with caplog.at_level(logging.WARNING, logger="deerflow.community.serper.tools"):
web_search_tool.invoke({"query": "q1"})
web_search_tool.invoke({"query": "q2"})
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
assert len(warnings) == 1
def test_http_error_returns_structured_error(self, mock_config_with_key):
mock_error_response = MagicMock()
mock_error_response.status_code = 403
mock_error_response.text = "Forbidden"
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response)
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert "error" in parsed
assert "403" in parsed["error"]
def test_network_exception_returns_error_json(self, mock_config_with_key):
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = Exception("timeout")
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert "error" in parsed
def test_sends_correct_headers_and_payload(self, mock_config_with_key):
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
web_search_tool.invoke({"query": "hello world"})
call_kwargs = mock_post.call_args
headers = call_kwargs.kwargs["headers"]
payload = call_kwargs.kwargs["json"]
assert headers["X-API-KEY"] == "test-serper-key"
assert payload["q"] == "hello world"
assert payload["num"] == 5
def test_uses_env_key_when_config_absent(self):
with patch("deerflow.community.serper.tools.get_app_config") as mock:
mock.return_value.get_tool_config.return_value = None
with patch.dict("os.environ", {"SERPER_API_KEY": "env-only-key"}):
organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
web_search_tool.invoke({"query": "env key test"})
headers = mock_post.call_args.kwargs["headers"]
assert headers["X-API-KEY"] == "env-only-key"
def test_partial_fields_in_organic_result(self, mock_config_with_key):
"""Missing title/link/snippet should default to empty string."""
organic = [{}]
mock_resp = _make_serper_response(organic)
with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp
from deerflow.community.serper.tools import web_search_tool
result = web_search_tool.invoke({"query": "test"})
parsed = json.loads(result)
assert parsed["results"][0] == {"title": "", "url": "", "content": ""}

Some files were not shown because too many files have changed in this diff Show More