Compare commits

...

39 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] dad3997459 fix(sandbox): cleanup dead containers and avoid lock-held liveness checks
Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/96707445-0f8b-4901-8ef3-d8e5667f8a05

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>
2026-05-11 00:09:09 +00:00
Willem Jiang b67c2a4e56 fix(sandbox): auto-restart crashed containers transparently (#2788)
When a sandbox container crashes (e.g. due to an internal error), the
  agent enters a connection-refused loop because AioSandboxProvider.get()
  returns a cached but dead sandbox object. Add a liveness check in get()
  that detects crashed containers via backend.is_alive() and evicts them
  from all caches, allowing ensure_sandbox_initialized() to transparently
  recreate a fresh container on the next acquire().

  The behavior is controlled by a new  config option
  (default: true). Set to false to skip health checks and preserve the
  old behavior of returning stale cached sandboxes.

  Closes #2788
2026-05-10 22:53:58 +08:00
Xinmin Zeng 94da8f67d7 fix(scripts): preserve uv extras across make dev restarts (#2754) (#2767)
`make dev` ran `uv sync` unconditionally on every restart, wiping any
optional extras the user had installed manually with
`uv sync --all-packages --extra postgres`. The Docker image-build path
already solved this via the `UV_EXTRAS` build-arg in backend/Dockerfile;
the local serve.sh path and the docker-compose-dev startup command
were the remaining outliers.

`scripts/serve.sh` now resolves extras before `uv sync`:
  1. honors `UV_EXTRAS` (parity with backend/Dockerfile and
     docker/docker-compose.yaml — no new convention introduced);
  2. falls back to parsing config.yaml — `database.backend: postgres`
     or legacy `checkpointer.type: postgres` auto-pins
     `--extra postgres`, so the common case needs zero extra config.
  3. detector stderr is no longer suppressed, so whitelist warnings or
     crashes surface to the dev terminal (review feedback).

Detection lives in `scripts/detect_uv_extras.py` (stdlib-only — has to
run before the venv exists). Extra names are validated against
`^[A-Za-z][A-Za-z0-9_-]*$` so a stray shell metacharacter in `.env`
cannot reach `uv sync` downstream (defense in depth).

`docker/docker-compose-dev.yaml`'s startup command is now extracted to
`docker/dev-entrypoint.sh` (review feedback — the inline command had
grown to a ~350-char one-liner). The script:
  - parses comma/whitespace-separated UV_EXTRAS, applying the same
    `^[A-Za-z][A-Za-z0-9_-]*$` whitelist as the local detector;
  - emits one `--extra X` flag per token, so `UV_EXTRAS=postgres,ollama`
    works in Docker dev too (harmonized with local — review feedback);
  - calls `uv sync --all-packages` (PR #2584) so workspace member
    extras (deerflow-harness's postgres extra) are installed;
  - keeps the existing self-heal `(uv sync || (recreate venv && retry))`
    branch;
  - exposes `--print-extras` for dry-run testing.

The compose file mounts the script read-only at runtime, so script
edits take effect on `make docker-restart` without an image rebuild.

The `--no-sync` alternative (a separate suggestion in the issue thread)
was considered but rejected for dev paths because it would drop the
self-heal branch and the auto-pickup of new pyproject deps. `--no-sync`
is already in use for the production CMD (`backend/Dockerfile:101`)
where it's appropriate.

Updates the asyncpg-missing error message to include the
`--all-packages` flag (matching #2584) plus the persistent install flow,
and expands `config.example.yaml` so all three install paths
(local / docker dev / docker image build) are documented with their
multi-extra capabilities.

Tests:
  - `tests/test_detect_uv_extras.py` (21 tests) — local-path env parsing,
    YAML edge cases, env-vs-config precedence, whitelist rejection of
    shell metacharacters.
  - `tests/test_dev_entrypoint.py` (15 tests) — docker-path validation
    via `--print-extras`, multi-extra parsing, metacharacter abort.
  - `tests/test_persistence_scaffold.py` (22 tests, unchanged) — passes
    with the merged `--all-packages --extra postgres` error message.

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-10 22:28:29 +08:00
YuJitang 5127f08e1a enable token usage by default (#2841) 2026-05-10 22:00:57 +08:00
DanielWalnut dfa4eb0c1a [codex] fix follow-up suggestions layout (#2836)
* fix follow-up suggestions layout

* fix agent chat welcome layout transition

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-10 15:10:44 +08:00
DanielWalnut 08ee7adeba fix(lint): remove duplicate is_dynamic_context_reminder definition (#2837)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 23:40:46 +08:00
Eilen Shin 1c96a6afc8 fix: keep new agent bootstrap in user scope (#2784) 2026-05-09 19:43:50 +08:00
YuJitang 417416087b fix: use backend thread token usage for header total (#2800)
* fix: use backend thread token usage for header total

* Refactor thread token usage fetch
2026-05-09 19:40:32 +08:00
DanielWalnut 881ff71252 fix(harness): preserve dynamic context across summarization (#2823) 2026-05-09 19:39:36 +08:00
DanielWalnut f76e4e35c8 fix title generation with dynamic context reminder (#2830) 2026-05-09 18:22:58 +08:00
yangyufan 0d1053ca44 fix(uploads): add Windows support for safe symlink-protected uploads (#2794)
* fix(uploads): add Windows support for safe symlink-protected uploads

* fix(uploads): update tests and translate comments;
2026-05-09 18:21:54 +08:00
He Wang 4063dd7157 feat(debug): print presented file paths with physical resolution (#2825)
Surface artifacts produced via the present_files tool in the CLI debug
REPL so headless clients without a frontend (VS Code launch configs,
etc.) can locate output files. Each turn prints newly added artifacts
plus their resolved host path. Works for any source that goes through
present_files — ACP agents, subagents, or sandbox writes.

Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
2026-05-09 18:21:01 +08:00
ChenglongZ 7a3c58a733 Fix duplicate gateway upload filenames (#2789) 2026-05-09 18:02:40 +08:00
dependabot[bot] 1edc9d9fae chore(deps): bump langchain-core from 1.3.2 to 1.3.3 in /backend (#2807)
Bumps [langchain-core](https://github.com/langchain-ai/langchain) from 1.3.2 to 1.3.3.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/langchain-core==1.3.2...langchain-core==1.3.3)

---
updated-dependencies:
- dependency-name: langchain-core
  dependency-version: 1.3.3
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-09 15:51:18 +08:00
KiteEater 7caf03e97c fix(packaging): add postgres extra for store/checkpointer supportFix postgres extra install guidance (#2584)
* Fix postgres extra install guidance

* Fix postgres install message lint

* Format postgres install messages

* Fix postgres install guidance and config docs
2026-05-09 09:49:08 +08:00
dependabot[bot] 41b04a556f chore(deps): bump uuid from 10.0.0 to 14.0.0 in /frontend (#2802)
Bumps [uuid](https://github.com/uuidjs/uuid) from 10.0.0 to 14.0.0.
- [Release notes](https://github.com/uuidjs/uuid/releases)
- [Changelog](https://github.com/uuidjs/uuid/blob/main/CHANGELOG.md)
- [Commits](https://github.com/uuidjs/uuid/compare/v10.0.0...v14.0.0)

---
updated-dependencies:
- dependency-name: uuid
  dependency-version: 14.0.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-09 09:33:00 +08:00
DanielWalnut c1b7f1d189 feat: static system prompt with DynamicContextMiddleware for prefix-cache optimization (#2801)
* feat(middleware): inject dynamic context via DynamicContextMiddleware

Move memory and current date out of the system prompt and into a
dedicated <system-reminder> HumanMessage injected once per session
(frozen-snapshot pattern) via a new DynamicContextMiddleware.

This keeps the system prompt byte-exact across all users and sessions,
enabling maximum Anthropic/Bedrock prefix-cache reuse.

Key design decisions:
- ID-swap technique: reminder takes the first HumanMessage's ID
  (replacing it in-place via add_messages), original content gets a
  derived `{id}__user` ID (appended after). Preserves correct ordering.
- hide_from_ui: True on reminder messages so frontend filters them out.
- Midnight crossing: date-update reminder injected before the current
  turn's HumanMessage when the conversation spans midnight.
- INFO-level logging for production diagnostics.

Also adds prompt-caching breakpoint budget enforcement tests and
updates ClaudeChatModel docs to reference the new pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(token-usage): log input/output token detail breakdown in middleware

Extend the LLM token usage log line to include input_token_details and
output_token_details (cache_creation, cache_read, reasoning, audio, etc.)
when present. Adds tests covering Anthropic cache detail logging from
both usage_metadata and response_metadata.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: fix nginx

* fix(middleware): always inject date; gate memory on injection_enabled

Date injection is now unconditional — it is part of the static system
prompt replacement and should always be present. Memory injection
remains gated by `memory.injection_enabled` in the app config.

Previously the entire DynamicContextMiddleware was skipped when
injection_enabled was False, which also suppressed the date.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(lint): format files and correct test assertions for token usage middleware

- ruff format dynamic_context_middleware.py and test_claude_provider_prompt_caching.py
- Remove unused pytest import from test_dynamic_context_middleware.py
- Fix two tests that asserted response_metadata fallback logic that
  doesn't exist: replace with tests that match actual middleware behavior

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(middleware): address Copilot review comments on DynamicContextMiddleware

- Use additional_kwargs flag for reminder detection instead of content
  substring matching, so user messages containing '<system-reminder>'
  are not mistakenly treated as injected reminders
- Generate stable UUID when original HumanMessage.id is None to prevent
  ambiguous 'None__user' derived IDs and message collisions
- Downgrade per-turn no-op log to DEBUG; keep actual injection events at INFO
- Add two new tests: missing-id UUID fallback and user-text false-positive

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-09 09:27:02 +08:00
dependabot[bot] 109490da25 chore(deps): bump python-multipart from 0.0.26 to 0.0.27 in /backend (#2799)
Bumps [python-multipart](https://github.com/Kludex/python-multipart) from 0.0.26 to 0.0.27.
- [Release notes](https://github.com/Kludex/python-multipart/releases)
- [Changelog](https://github.com/Kludex/python-multipart/blob/main/CHANGELOG.md)
- [Commits](https://github.com/Kludex/python-multipart/compare/0.0.26...0.0.27)

---
updated-dependencies:
- dependency-name: python-multipart
  dependency-version: 0.0.27
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-08 22:58:15 +08:00
dependabot[bot] 14c0a32ee6 chore(deps): bump mako from 1.3.11 to 1.3.12 in /backend (#2798)
Bumps [mako](https://github.com/sqlalchemy/mako) from 1.3.11 to 1.3.12.
- [Release notes](https://github.com/sqlalchemy/mako/releases)
- [Changelog](https://github.com/sqlalchemy/mako/blob/main/CHANGES)
- [Commits](https://github.com/sqlalchemy/mako/commits)

---
updated-dependencies:
- dependency-name: mako
  dependency-version: 1.3.12
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-08 22:57:48 +08:00
Willem Jiang 70737af7cd fix(nignx):resolve CSRF auth failure on non-standard ports (#2796) 2026-05-08 22:40:38 +08:00
DanielWalnut 2b1fcb3e43 fix(task): remove max_turns parameter from task tool interface (#2783)
* fix(task): remove max_turns parameter from task tool interface

Subagents should always use their configured max_turns value. Exposing
this parameter allowed callers to override the admin-configured limit,
which is undesirable. The value is now exclusively driven by subagent
config (per-agent overrides and global defaults in config.yaml).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-08 15:05:24 +08:00
He Wang 7de9b5828b fix(tools): introduce Runtime type alias to eliminate Pydantic serialization warning (#2774)
* fix(tools): introduce Runtime type alias to eliminate Pydantic serialization warning

Add deerflow/tools/types.py with:

    Runtime = ToolRuntime[dict[str, Any], ThreadState]

Replace every runtime: ToolRuntime[ContextT, ThreadState] and
runtime: ToolRuntime[dict[str, Any], ThreadState] annotation in
sandbox/tools.py, present_file_tool.py, task_tool.py, view_image_tool.py,
and skill_manage_tool.py with the new Runtime alias.

The unbound ContextT TypeVar (default None) caused
PydanticSerializationUnexpectedValue warnings on every tool call because
LangChain's BaseTool._parse_input calls model_dump() on the auto-generated
args_schema while DeerFlow passes a dict as runtime context.
Binding the context to dict[str, Any] aligns Pydantic's serialization
expectations with reality and removes the noise from all run modes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(tools): extend Runtime alias to setup_agent and update_agent tools

Replace bare ToolRuntime annotations in setup_agent_tool.py and
update_agent_tool.py with the shared Runtime alias introduced in the
previous commit, and add both tools to the Pydantic serialization
warning regression test (13 cases total).

Co-authored-by: Cursor <cursoragent@cursor.com>

* test(tools): loosen Pydantic warning filter to avoid version-specific format

Replace the brittle "field_name='context'" substring check with a looser
"context" match so the assertion stays valid if Pydantic changes its
internal warning format across versions.

Co-authored-by: Cursor <cursoragent@cursor.com>

* test(tools): simplify warning filter and clean up docstring

Remove the "context" substring condition from the Pydantic warning
filter — asserting that no PydanticSerializationUnexpectedValue fires
at all is both simpler and more comprehensive, since the test payload
contains only the tool's own args plus runtime.

Also update the module docstring to remove the version-specific warning
format example that was inconsistent with the looser filter.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-08 14:50:33 +08:00
Eilen Shin 37db689349 fix(events): serialize structured db event content (#2762) 2026-05-08 10:17:17 +08:00
Eilen Shin bd45cb2846 fix(sandbox): disable msys path conversion (#2766) 2026-05-08 10:13:11 +08:00
Eilen Shin 5fd0e6ac89 fix(middleware): sync raw tool call metadata (#2757) 2026-05-08 10:08:53 +08:00
YuJitang 530bda7107 fix: dedupe token usage aggregation by message id (#2770) 2026-05-08 09:54:20 +08:00
Willem Jiang 6c220a9aef fix(chat): prevent first user message from being swallowed in new conversations (#2731)
* fix(chat): prevent first user message from being swallowed in new conversations

  The optimistic message clearing effect cleared too eagerly — any stream
  message (including AI messages from messages-tuple events) triggered the
  clear before the server's human message had arrived via values events.
  For new threads this caused the user's first prompt to disappear permanently.

  Only clear optimistic messages once the server's human message has been
  confirmed to arrive in thread.messages, not just when any message arrives.

  Fixes #2730

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-07 17:31:48 +08:00
Tao Liu daa3ffc29b feat(loop-detection): make loop detection configurable with per-tool frequency overrides (#2711)
* Make loop detection configurable

Expose LoopDetectionMiddleware thresholds through config.yaml while preserving existing defaults and allowing the middleware to be disabled.

Refs bytedance/deer-flow#2517

* feat(loop-detection): add per-tool tool_freq_overrides to Phase 1

Adds ToolFreqOverride model and tool_freq_overrides field to
LoopDetectionConfig, wires it through LoopDetectionMiddleware, and
documents the option in config.example.yaml.

Resolves the gap flagged in the #2586 review: without per-tool overrides,
users hit by #2510/#2511 (RNA-seq workflows exceeding the bash hard limit)
had no way to raise thresholds for one tool without loosening the global
limit for every tool.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* docs(loop-detection): document tool_freq_overrides in LoopDetectionMiddleware docstring

Add the missing Args entry for tool_freq_overrides, explaining the
(warn, hard_limit) tuple structure and how per-tool thresholds supersede
the global tool_freq_warn / tool_freq_hard_limit for named tools.
Also run ruff format on the three files flagged by the lint check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(loop-detection): validate LoopDetectionMiddleware __init__ params eagerly

Raise clear ValueError at construction time instead of crashing at
unpack-time inside _track_and_check when bad values are passed:
- tool_freq_overrides: must be 2-tuples of positive ints with hard_limit >= warn
- scalar thresholds: warn_threshold, hard_limit, tool_freq_warn,
  tool_freq_hard_limit must be >= 1 and hard limits must >= their warn pairs
- window_size, max_tracked_threads must be >= 1

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(test): isolate credential loader directory-path test from real ~/.claude

The test didn't monkeypatch HOME, so on any machine with real Claude Code
credentials at ~/.claude/.credentials.json the function fell through to
those credentials and the assertion failed. Adding HOME redirect ensures
the default credential path doesn't exist during the test.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style(test): add blank lines after import pytest in TestInitValidation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor(loop-detection): collapse dual validation to LoopDetectionConfig

Modifications
  - LoopDetectionMiddleware.__init__: stripped of all ValueError raises;
    becomes a plain field-assignment constructor.
  - LoopDetectionMiddleware.from_config: classmethod that builds the
    middleware from a Pydantic-validated LoopDetectionConfig and handles
    the ToolFreqOverride -> tuple[int, int] conversion.
  - agents/factory.py: SDK construction routed through
    LoopDetectionMiddleware.from_config(LoopDetectionConfig()) so the
    defaults path is Pydantic-validated too.
  - agents/lead_agent/agent.py: uses from_config instead of unpacking
    config fields by hand.
  - tests/test_loop_detection_middleware.py: deleted TestInitValidation
    (16 methods exercising the removed __init__ checks); added
    TestFromConfig (4 tests: scalar field mapping, override tuple
    conversion, empty overrides, behavioral smoke test).

Result: one validation layer (Pydantic), zero duplication, no __new__
hacks. Both production construction sites flow through LoopDetectionConfig.

Test results
  make test   -> 2977 passed, 18 skipped, 0 failed (137s)
  make format -> All checks passed; 411 files left unchanged

* feat(agents): make loop_detection configurable in create_deerflow_agent

Adds a `loop_detection: bool | AgentMiddleware = True` field to
RuntimeFeatures, mirroring the existing pattern used by `sandbox`,
`memory`, and `vision`. SDK users can now disable LoopDetectionMiddleware
or replace it with a custom instance built from their own
LoopDetectionConfig — e.g.
`LoopDetectionMiddleware.from_config(my_cfg)` — instead of being stuck
with the hardcoded defaults previously installed by the SDK factory.

The lead-agent path (which already reads AppConfig.loop_detection) is
unchanged, and the default `True` preserves prior always-on behavior for
all existing callers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

---------

Co-authored-by: knight0940 <631532668@qq.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Amorend <142649913+knight0940@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-07 16:15:15 +08:00
Xinmin Zeng 27559f3675 fix(frontend): defer thread id to onStart to avoid 404 on new chat (#2749)
* fix(frontend): defer thread id to onStart to avoid 404 on new chat

The LangGraph SDK's useStream eagerly fetches /threads/{id}/history the
moment it receives a thread id, and the local useThreadRuns issues
GET /threads/{id}/runs for the same reason. The chats page used to flip
isNewThread=false (and forward the client-generated thread id) inside
the synchronous onSend callback, before thread.submit had created the
thread on the backend. The two queries therefore raced ahead of
POST /runs/stream and returned 404 on the very first send.

Drop the onSend handler so isNewThread stays true until onStart fires
from useStream's onCreated — by then the backend has the thread, and
the SDK's submittingRef guard naturally suppresses the redundant
history fetch. The agent chat page already uses this pattern, so this
also unifies the two flows.

Adds an E2E regression that records request ordering and asserts
GET /history and GET /runs are never issued before POST /runs/stream
on the first send from /chats/new.

Closes #2746

* fix(frontend): split welcome layout from backend thread state

Removing onSend kept GET /history and GET /runs from racing ahead of
POST /runs/stream, but it also coupled the welcome layout (centered
input, hero, quick actions) to backend thread creation.  Until onCreated
returned, the user's optimistic message and the welcome hero rendered on
top of each other.

Introduce a dedicated `isWelcomeMode` UI flag, separate from
`isNewThread`:
- `isNewThread` still tracks "backend has no thread yet" and gates the
  thread id forwarded to useStream.
- `isWelcomeMode` drives the visual layout (header background, input
  box position, max width, hero, quick actions, autoFocus) and flips to
  false inside onSend so the layout animates immediately.

`isWelcomeMode` is kept in sync with `isNewThread` via an effect so
sidebar navigation and "new chat" still behave correctly.  All 15 E2E
tests pass, including the ordering regression added in the previous
commit.

* test(e2e): use monotonic sequence for thread-init ordering check

Date.now() is millisecond-resolution, so two requests emitted within
the same tick would share a timestamp and slip past the strict `<`
ordering assertions. Replace the timestamp with a monotonic counter
that increments on every observed request/requestfinished event so the
ordering check is robust regardless of scheduling.

Per PR #2749 review feedback from copilot-pull-request-reviewer.

* refactor(input-box): rename isNewThread prop to isWelcomeMode

Inside InputBox, the prop named `isNewThread` is only ever consulted
for visual layout decisions — gating follow-up suggestions, the bottom
background strip, and the welcome-mode quick-action SuggestionList. It
never reflects "the backend has created the thread", which after #2746
is tracked separately via `isNewThread` in the chat pages themselves.

Rename the prop to `isWelcomeMode` and update both call sites
(workspace chats page and agent chats page) so the prop name matches
its actual semantics. No behavior change.

Per PR #2749 review feedback from @WillemJiang.
2026-05-07 16:11:44 +08:00
AochenShen99 cef4224381 fix(skills): enforce allowed-tools metadata (#2626)
* fix(skills): parse allowed-tools frontmatter

* fix(skills): validate allowed-tools metadata

* fix(skills): add shared allowed-tools policy

* fix(subagents): enforce skill allowed-tools

* fix(agent): enforce skill allowed-tools

* refactor(skills): dedupe TypeVar and reuse cached enabled skills

- Drop redundant module-level TypeVar in tool_policy; rely on PEP 695 syntax.
- Expose get_cached_enabled_skills() and have the lead agent reuse it
  instead of synchronously rescanning skills on every request.

* fix(agent): expose config-scoped skill cache

* fix(subagents): pass filtered tools explicitly

* fix(skills): clean allowed-tools policy feedback
2026-05-07 08:34:43 +08:00
Hinotobi 2b0e62f679 [security] fix(auth): reject cross-site auth POSTs (#2740)
* fix(security): reject cross-site auth posts

* fix(auth): align secure cookie proxy scheme handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-07 07:58:06 +08:00
Eilen Shin 1336872b15 fix(channels): authenticate gateway command requests (#2742) 2026-05-06 15:27:34 +08:00
KiteEater 4ead2c6b19 fix(config): reset config-backed singletons on hot reload (#2588)
* Fix stale config singletons on reload

* fix(config): update checkpointer imports after runtime move

* Fix config reload singleton mutation on validation failure

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-06 10:17:55 +08:00
yangzheli 59c4a3f0a4 feat(agent): add custom-agent self-updates with user isolation (#2713)
* feat(agent): add update_agent tool for in-chat custom-agent self-updates (#2616)

Custom agents had no built-in way to persist updates to their own SOUL.md /
config.yaml from a normal chat — `setup_agent` was only bound during the
bootstrap flow, so when the user asked the agent to refine its description
or personality, the agent would shell out via bash/write_file and the edits
landed in a temporary sandbox/tool workspace instead of
`{base_dir}/agents/{agent_name}/`.

Changes:
- New `update_agent` builtin tool with partial-update semantics (only the
  fields you pass are written) and atomic temp-file + os.replace writes so
  a failed update never corrupts existing SOUL.md / config.yaml.
- Lead agent now binds `update_agent` in the non-bootstrap path whenever
  `agent_name` is set in the runtime context. Default agent (no
  agent_name) and bootstrap flow are unchanged.
- New `<self_update>` system-prompt section is injected for custom agents,
  instructing them to use `update_agent` — and explicitly NOT bash /
  write_file — to persist self-updates.
- Tests: 11 new cases in `tests/test_update_agent_tool.py` covering
  validation (missing/invalid agent_name, unknown agent, no fields),
  partial updates (soul-only, description-only, skills=[] vs omitted),
  no-op detection, atomic-write safety, and AgentConfig round-tripping;
  plus 2 new cases in `tests/test_lead_agent_prompt.py` covering the
  self-update prompt section.
- Docs: updated backend/CLAUDE.md builtin tools list and tools.mdx
  (en/zh) with the new tool description.

* feat(agent): isolate custom agents per user

Store custom agent definitions under the effective user, keep legacy agents readable until migration, and cover API/tool/migration behavior with tests.

Co-authored-by: Cursor <cursoragent@cursor.com>

* feat: consistent write/delete targets & add --user-id to migration

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-05 23:17:42 +08:00
Nan Gao e8675f266d fix(loop-detection): keep tool-call pairing on warn injection (#2724) (#2725)
* fix(loop-detection): keep tool-call pairing on warn injection (#2724)

* make format

* fix(loop-detection): avoid IMMessage leak to downstream consumer

* fix(channels): filter loop warning text from IM replies
2026-05-05 18:53:49 +08:00
Xun 680187ddc2 fix: Supplement list_running in RemoteSandboxBackend (#2716)
* fix: Supplement list_running in RemoteSandboxBackend

* fix

* except requests.RequestException as exc:

* fix
2026-05-05 18:53:10 +08:00
Xinmin Zeng aded753de3 fix(frontend): restore localhost fallback for getGatewayConfig in prod mode (#2705) (#2718)
* fix(frontend): unify gateway-config localhost fallback for prod (#2705)

`getGatewayConfig()` only fell back to localhost defaults when
`NODE_ENV === "development"`, while `next.config.js` always falls back
to `127.0.0.1:8001`. Running `make start` (which sets NODE_ENV=production
via `next start`) without `DEER_FLOW_INTERNAL_GATEWAY_BASE_URL` /
`DEER_FLOW_TRUSTED_ORIGINS` therefore caused zod to throw inside SSR
layouts and surfaced as a 500.

Drop the NODE_ENV gating and use localhost defaults everywhere — the
"force explicit config in prod" intent should be enforced by deployment
templates (docker-compose already sets both vars), not by request-time
crashes. Document the two vars in both .env.example files and add unit
coverage for the dev/prod env-unset paths.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Update internalGatewayUrl in gateway config tests

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-05 16:27:29 +08:00
Willem Jiang 028493bfd8 fix(docker):force ngix to resolve upstream names at request time (#2717)
* fix(docker):force ngix to resolve upstream names at request time

* fix(docker): set resolver valid=0s to eliminate DNS cache window for request-time re-resolution

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/07bdb872-022f-4fd2-9fa8-d800a4ce34a7

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* Update DNS resolver valid time and add upstreams

* fix the unit test error

* Remove upstream server configurations from nginx.conf

Removed upstream server configurations for gateway and frontend.

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-05-05 14:35:55 +08:00
Willem Jiang 8e48b7e85c fix(channels): preserve clarification conversation history across follow-up turns (#2444)
* fix(channels): preserve clarification conversation history across follow-up turns

Pin channel-triggered runs to the root checkpoint namespace and ensure thread_id is always present in configurable run config so follow-up replies resume the same conversation state.

Add regression coverage to channel tests:

assert checkpoint_ns/thread_id are passed in wait and stream paths
add an integration-style clarification flow test that verifies the second user reply continues prior context instead of starting a new session
This addresses history loss after ask_clarification interruptions (issue #2425).

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix(channels): copy configurable dict before injecting run-scoped fields

  When configurable was already a plain dict, _resolve_run_params mutated
  it in place, leaking checkpoint_ns and thread_id back into the shared
  session config. Always copy via dict() before mutating to prevent
  cross-user or cross-channel config pollution.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-05-04 16:14:07 +08:00
131 changed files with 6744 additions and 733 deletions
+11
View File
@@ -48,3 +48,14 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
# GATEWAY_ENABLE_DOCS=false
# ── Frontend SSR → Gateway wiring ─────────────────────────────────────────────
# The Next.js server uses these to reach the Gateway during SSR (auth checks,
# /api/* rewrites). They default to localhost values that match `make dev` and
# `make start`, so most local users do not need to set them.
#
# Override only when the Gateway is not on localhost:8001 (e.g. when the
# frontend and gateway run on different hosts, in containers with a service
# alias, or behind a different port). docker-compose already sets these.
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
+6 -2
View File
@@ -263,8 +263,10 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
- `present_files` - Make output files visible to user (only `/mnt/user-data/outputs`)
- `ask_clarification` - Request clarification (intercepted by ClarificationMiddleware → interrupts)
- `view_image` - Read image as base64 (added only if model supports vision)
- `setup_agent` - Bootstrap-only: persist a brand-new custom agent's `SOUL.md` and `config.yaml`. Bound only when `is_bootstrap=True`.
- `update_agent` - Custom-agent-only: persist self-updates to the current agent's `SOUL.md` / `config.yaml` from inside a normal chat (partial update + atomic write). Bound when `agent_name` is set and `is_bootstrap=False`.
4. **Subagent tool** (if enabled):
- `task` - Delegate to subagent (description, prompt, subagent_type, max_turns)
- `task` - Delegate to subagent (description, prompt, subagent_type)
**Community tools** (`packages/harness/deerflow/community/`):
- `tavily/` - Web search (5 results default) and web fetch (4KB limit)
@@ -354,10 +356,11 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
**Per-User Isolation**:
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
- Custom agent definitions (`SOUL.md` + `config.yaml`) are also per-user at `{base_dir}/users/{user_id}/agents/{agent_name}/`. The legacy shared layout `{base_dir}/agents/{agent_name}/` remains read-only fallback for unmigrated installations
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
- Absolute `storage_path` in config opts out of per-user isolation
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json`, `threads/`, and `agents/` into per-user layout. Supports `--dry-run` (preview changes) and `--user-id USER_ID` (assign unowned legacy data to a user, defaults to `default`).
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
@@ -517,6 +520,7 @@ Multi-file upload with automatic document conversion:
- Rejects directory inputs before copying so uploads stay all-or-nothing
- Reuses one conversion worker per request when called from an active event loop
- Files stored in thread-isolated directories
- Duplicate filenames in a single upload request are auto-renamed with `_N` suffixes so later files do not truncate earlier files
- Agent receives uploaded file list via `UploadsMiddleware`
See [docs/FILE_UPLOAD.md](docs/FILE_UPLOAD.md) for details.
+1 -1
View File
@@ -124,7 +124,7 @@ FastAPI application providing REST endpoints for frontend integration:
| `POST /api/memory/reload` | Force memory reload |
| `GET /api/memory/config` | Memory configuration |
| `GET /api/memory/status` | Combined config + data |
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths) |
| `POST /api/threads/{id}/uploads` | Upload files (auto-converts PDF/PPT/Excel/Word to Markdown, rejects directory paths, auto-renames duplicate filenames in one request) |
| `GET /api/threads/{id}/uploads/list` | List uploaded files |
| `DELETE /api/threads/{id}` | Delete DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
| `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts |
+31 -2
View File
@@ -146,6 +146,13 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
return normalized
def _strip_loop_warning_text(text: str) -> str:
"""Remove middleware-authored loop warning lines from display text."""
if "[LOOP DETECTED]" not in text:
return text
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
def _extract_response_text(result: dict | list) -> str:
"""Extract the last AI message text from a LangGraph runs.wait result.
@@ -155,7 +162,7 @@ def _extract_response_text(result: dict | list) -> str:
Handles special cases:
- Regular AI text responses
- Clarification interrupts (``ask_clarification`` tool messages)
- AI messages with tool_calls but no text content
- Strips loop-detection warnings attached to tool-call AI messages
"""
if isinstance(result, list):
messages = result
@@ -185,7 +192,12 @@ def _extract_response_text(result: dict | list) -> str:
# Regular AI message with text content
if msg_type == "ai":
content = msg.get("content", "")
has_tool_calls = bool(msg.get("tool_calls"))
if isinstance(content, str) and content:
if has_tool_calls:
content = _strip_loop_warning_text(content)
if not content:
continue
return content
# content can be a list of content blocks
if isinstance(content, list):
@@ -196,6 +208,8 @@ def _extract_response_text(result: dict | list) -> str:
elif isinstance(block, str):
parts.append(block)
text = "".join(parts)
if has_tool_calls:
text = _strip_loop_warning_text(text)
if text:
return text
return ""
@@ -589,6 +603,17 @@ class ChannelManager:
user_layer.get("config"),
)
configurable = run_config.get("configurable")
if isinstance(configurable, Mapping):
configurable = dict(configurable)
else:
configurable = {}
run_config["configurable"] = configurable
# Pin channel-triggered runs to the root graph namespace so follow-up
# turns continue from the same conversation checkpoint.
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id
run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT,
self._default_session.get("context"),
@@ -972,7 +997,11 @@ class ChannelManager:
try:
async with httpx.AsyncClient() as http:
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
resp = await http.get(
f"{self._gateway_url}{path}",
timeout=10,
headers=create_internal_auth_headers(),
)
resp.raise_for_status()
data = resp.json()
except Exception:
+112 -1
View File
@@ -4,8 +4,10 @@ Per RFC-001:
State-changing operations require CSRF protection.
"""
import os
import secrets
from collections.abc import Callable
from urllib.parse import urlsplit
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
@@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
return _request_scheme(request) == "https"
def generate_csrf_token() -> str:
@@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
"""Return normalized host[:port], omitting default ports."""
host = hostname.lower()
if ":" in host and not host.startswith("["):
host = f"[{host}]"
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
return host
return f"{host}:{port}"
def _normalize_origin(origin: str) -> str | None:
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
try:
parsed = urlsplit(origin.strip())
port = parsed.port
except ValueError:
return None
scheme = parsed.scheme.lower()
if scheme not in {"http", "https"} or not parsed.hostname:
return None
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
return None
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
def _configured_cors_origins() -> set[str]:
"""Return explicit configured browser origins that may call auth routes."""
origins = set()
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
origin = raw_origin.strip()
if not origin or origin == "*":
continue
normalized = _normalize_origin(origin)
if normalized:
origins.add(normalized)
return origins
def _first_header_value(value: str | None) -> str | None:
"""Return the first value from a comma-separated proxy header."""
if not value:
return None
first = value.split(",", 1)[0].strip()
return first or None
def _forwarded_param(request: Request, name: str) -> str | None:
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
forwarded = _first_header_value(request.headers.get("forwarded"))
if not forwarded:
return None
for part in forwarded.split(";"):
key, sep, value = part.strip().partition("=")
if sep and key.lower() == name:
return value.strip().strip('"') or None
return None
def _request_scheme(request: Request) -> str:
"""Resolve the original request scheme from trusted proxy headers."""
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
return scheme.lower()
def _request_origin(request: Request) -> str | None:
"""Build the origin for the URL the browser is targeting."""
scheme = _request_scheme(request)
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
host = f"{host}:{forwarded_port}"
return _normalize_origin(f"{scheme}://{host}")
def is_allowed_auth_origin(request: Request) -> bool:
"""Allow auth POSTs only from the same origin or explicit configured origins.
Login/register/initialize are exempt from the double-submit token because
first-time browser clients do not have a CSRF token yet. They still create
a session cookie, so browser requests with a hostile Origin header must be
rejected to prevent login CSRF / session fixation. Requests without Origin
are allowed for non-browser clients such as curl and mobile integrations.
"""
origin = request.headers.get("origin")
if not origin:
return True
normalized_origin = _normalize_origin(origin)
if normalized_origin is None:
return False
request_origin = _request_origin(request)
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
@@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
return JSONResponse(
status_code=403,
content={"detail": "Cross-site auth request denied."},
)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
+43 -18
View File
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["agents"])
@@ -86,11 +87,11 @@ def _require_agents_api_enabled() -> None:
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False, *, user_id: str | None = None) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
if include_soul:
soul = load_agent_soul(agent_cfg.name) or ""
soul = load_agent_soul(agent_cfg.name, user_id=user_id) or ""
return AgentResponse(
name=agent_cfg.name,
@@ -116,9 +117,10 @@ async def list_agents() -> AgentsListResponse:
"""
_require_agents_api_enabled()
user_id = get_effective_user_id()
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
agents = list_custom_agents(user_id=user_id)
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True, user_id=user_id) for a in agents])
except Exception as e:
logger.error(f"Failed to list agents: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
@@ -144,7 +146,12 @@ async def check_agent_name(name: str) -> dict:
_require_agents_api_enabled()
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
user_id = get_effective_user_id()
paths = get_paths()
# Treat the name as taken if either the per-user path or the legacy shared
# path holds an agent — picking a name that collides with an unmigrated
# legacy agent would shadow the legacy entry once migration runs.
available = not paths.user_agent_dir(user_id, normalized).exists() and not paths.agent_dir(normalized).exists()
return {"available": available, "name": normalized}
@@ -169,10 +176,11 @@ async def get_agent(name: str) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
try:
agent_cfg = load_agent_config(name)
return _agent_config_to_response(agent_cfg, include_soul=True)
agent_cfg = load_agent_config(name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
except Exception as e:
@@ -202,10 +210,13 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = get_paths().agent_dir(normalized_name)
agent_dir = paths.user_agent_dir(user_id, normalized_name)
legacy_dir = paths.agent_dir(normalized_name)
if agent_dir.exists():
if agent_dir.exists() or legacy_dir.exists():
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
try:
@@ -232,8 +243,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name)
return _agent_config_to_response(agent_cfg, include_soul=True)
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
except HTTPException:
raise
@@ -267,13 +278,20 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
user_id = get_effective_user_id()
try:
agent_cfg = load_agent_config(name)
agent_cfg = load_agent_config(name, user_id=user_id)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
agent_dir = get_paths().agent_dir(name)
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists() and paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating."),
)
try:
# Update config if any config fields changed
@@ -314,8 +332,8 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
logger.info(f"Updated agent '{name}'")
refreshed_cfg = load_agent_config(name)
return _agent_config_to_response(refreshed_cfg, include_soul=True)
refreshed_cfg = load_agent_config(name, user_id=user_id)
return _agent_config_to_response(refreshed_cfg, include_soul=True, user_id=user_id)
except HTTPException:
raise
@@ -402,15 +420,22 @@ async def delete_agent(name: str) -> None:
name: The agent name.
Raises:
HTTPException: 404 if agent not found.
HTTPException: 404 if no per-user copy exists; 409 if only a legacy
shared copy exists (suggesting the migration script).
"""
_require_agents_api_enabled()
_validate_agent_name(name)
name = _normalize_agent_name(name)
agent_dir = get_paths().agent_dir(name)
user_id = get_effective_user_id()
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, name)
if not agent_dir.exists():
if paths.agent_dir(name).exists():
raise HTTPException(
status_code=409,
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
)
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
try:
+24 -3
View File
@@ -68,6 +68,27 @@ class RunResponse(BaseModel):
updated_at: str = ""
class ThreadTokenUsageModelBreakdown(BaseModel):
tokens: int = 0
runs: int = 0
class ThreadTokenUsageCallerBreakdown(BaseModel):
lead_agent: int = 0
subagent: int = 0
middleware: int = 0
class ThreadTokenUsageResponse(BaseModel):
thread_id: str
total_tokens: int = 0
total_input_tokens: int = 0
total_output_tokens: int = 0
total_runs: int = 0
by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -368,10 +389,10 @@ async def list_run_events(
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return {"thread_id": thread_id, **agg}
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
+9 -1
View File
@@ -16,6 +16,7 @@ from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provi
from deerflow.uploads.manager import (
PathTraversalError,
UnsafeUploadPathError,
claim_unique_filename,
delete_file_safe,
enrich_file_listing,
ensure_uploads_dir,
@@ -192,6 +193,10 @@ async def upload_files(
sandbox_sync_targets = []
skipped_files = []
total_size = 0
# Track filenames within this request so duplicate form parts do not
# silently truncate each other. Existing uploads keep the historical
# overwrite behavior for a single replacement upload.
seen_filenames: set[str] = set()
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
@@ -208,7 +213,8 @@ async def upload_files(
continue
try:
safe_filename = normalize_filename(file.filename)
original_filename = normalize_filename(file.filename)
safe_filename = claim_unique_filename(original_filename, seen_filenames)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
@@ -236,6 +242,8 @@ async def upload_files(
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
if safe_filename != original_filename:
file_info["original_filename"] = original_filename
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
+19
View File
@@ -136,6 +136,24 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
runtime_context.setdefault(key, context[key])
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
"""Stamp the authenticated user into the run context for background tools.
Tool execution may happen after the request handler has returned, so tools
that persist user-scoped files should not rely only on ambient ContextVars.
The value comes from server-side auth state, never from client context.
"""
user = getattr(request.state, "user", None)
user_id = getattr(user, "id", None)
if user_id is None:
return
runtime_context = config.setdefault("context", {})
if isinstance(runtime_context, dict):
runtime_context["user_id"] = str(user_id)
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
@@ -288,6 +306,7 @@ async def start_run(
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
stream_modes = normalize_stream_modes(body.stream_mode)
+20
View File
@@ -79,7 +79,9 @@ async def main():
from langgraph.runtime import Runtime
from deerflow.agents import make_lead_agent
from deerflow.config.paths import get_paths
from deerflow.mcp import initialize_mcp_tools
from deerflow.runtime.user_context import get_effective_user_id
# Initialize MCP tools at startup
try:
@@ -113,6 +115,8 @@ async def main():
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
print("=" * 50)
seen_artifacts: set[str] = set()
while True:
try:
if session:
@@ -134,6 +138,22 @@ async def main():
last_message = result["messages"][-1]
print(f"\nAgent: {last_message.content}")
# Show files presented to the user this turn (new artifacts only)
artifacts = result.get("artifacts") or []
new_artifacts = [p for p in artifacts if p not in seen_artifacts]
if new_artifacts:
thread_id = config["configurable"]["thread_id"]
user_id = get_effective_user_id()
paths = get_paths()
print("\n[Presented files]")
for virtual in new_artifacts:
try:
physical = paths.resolve_virtual_path(thread_id, virtual, user_id=user_id)
print(f" - {virtual}\n{physical}")
except ValueError as exc:
print(f" - {virtual} (failed to resolve physical path: {exc})")
seen_artifacts.update(new_artifacts)
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
@@ -173,7 +173,7 @@ def _assemble_from_features(
9. MemoryMiddleware (memory feature)
10. ViewImageMiddleware (vision feature)
11. SubagentLimitMiddleware (subagent feature)
12. LoopDetectionMiddleware (always)
12. LoopDetectionMiddleware (loop_detection feature)
13. ClarificationMiddleware (always last)
Two-phase ordering:
@@ -272,10 +272,15 @@ def _assemble_from_features(
extra_tools.append(task_tool)
# --- [12] LoopDetection (always) ---
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
# --- [12] LoopDetection ---
if feat.loop_detection is not False:
if isinstance(feat.loop_detection, AgentMiddleware):
chain.append(feat.loop_detection)
else:
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.loop_detection_config import LoopDetectionConfig
chain.append(LoopDetectionMiddleware())
chain.append(LoopDetectionMiddleware.from_config(LoopDetectionConfig()))
# --- [13] Clarification (always last among built-ins) ---
chain.append(ClarificationMiddleware())
@@ -31,6 +31,7 @@ class RuntimeFeatures:
vision: bool | AgentMiddleware = False
auto_title: bool | AgentMiddleware = False
guardrail: Literal[False] | AgentMiddleware = False
loop_detection: bool | AgentMiddleware = True
# ---------------------------------------------------------------------------
@@ -20,6 +20,8 @@ from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.models import create_chat_model
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
logger = logging.getLogger(__name__)
@@ -256,6 +258,12 @@ def _build_middlewares(
resolved_app_config = app_config or get_app_config()
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
# Always inject current date (and optionally memory) as <system-reminder> into the
# first HumanMessage to keep the system prompt fully static for prefix-cache reuse.
from deerflow.agents.middlewares.dynamic_context_middleware import DynamicContextMiddleware
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None:
@@ -297,7 +305,9 @@ def _build_middlewares(
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
middlewares.append(LoopDetectionMiddleware())
loop_detection_config = resolved_app_config.loop_detection
if loop_detection_config.enabled:
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))
# Inject custom middlewares before ClarificationMiddleware
if custom_middlewares:
@@ -308,6 +318,28 @@ def _build_middlewares(
return middlewares
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
if is_bootstrap:
return {"bootstrap"}
if agent_config and agent_config.skills is not None:
return set(agent_config.skills)
return None
def _load_enabled_skills_for_tool_policy(available_skills: set[str] | None, *, app_config: AppConfig) -> list[Skill]:
try:
from deerflow.agents.lead_agent.prompt import get_enabled_skills_for_config
skills = get_enabled_skills_for_config(app_config)
except Exception:
logger.exception("Failed to load skills for allowed-tools policy")
raise
if available_skills is None:
return skills
return [skill for skill in skills if skill.name in available_skills]
def make_lead_agent(config: RunnableConfig):
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
runtime_config = _get_runtime_config(config)
@@ -318,7 +350,7 @@ def make_lead_agent(config: RunnableConfig):
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
from deerflow.tools.builtins import setup_agent, update_agent
cfg = _get_runtime_config(config)
resolved_app_config = app_config
@@ -333,6 +365,7 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
agent_name = validate_agent_name(cfg.get("agent_name"))
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
available_skills = _available_skill_names(agent_config, is_bootstrap)
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
agent_model_name = agent_config.model if agent_config and agent_config.model else None
@@ -371,15 +404,18 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
"available_skills": sorted(available_skills) if available_skills is not None else None,
}
)
skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config)
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
@@ -390,15 +426,14 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
state_schema=ThreadState,
)
# Custom agents can update their own SOUL.md / config via update_agent.
# The default agent (no agent_name) does not see this tool.
extra_tools = [update_agent] if agent_name else []
# Default lead agent (unchanged behavior)
tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
tools=get_available_tools(
model_name=model_name,
groups=agent_config.tool_groups if agent_config else None,
subagent_enabled=subagent_enabled,
app_config=resolved_app_config,
),
tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING
@@ -20,6 +19,7 @@ logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
_enabled_skills_lock = threading.Lock()
_enabled_skills_cache: list[Skill] | None = None
_enabled_skills_by_config_cache: dict[int, tuple[object, list[Skill]]] = {}
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event()
@@ -84,6 +84,7 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_by_config_cache.clear()
_enabled_skills_refresh_version += 1
_enabled_skills_refresh_event.clear()
if _enabled_skills_refresh_active:
@@ -107,6 +108,15 @@ def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_W
def _get_enabled_skills():
return get_cached_enabled_skills()
def get_cached_enabled_skills() -> list[Skill]:
"""Return the cached enabled-skills list, kicking off a background refresh on miss.
Safe to call from request paths: never blocks on disk I/O. Returns an empty
list on cache miss; the next call will see the warmed result.
"""
with _enabled_skills_lock:
cached = _enabled_skills_cache
@@ -117,17 +127,29 @@ def _get_enabled_skills():
return []
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
def get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
"""Return enabled skills using the caller's config source.
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
cache so the skill list and skill paths are resolved from the same config
object. This keeps request-scoped config injection consistent even while the
release branch still supports global fallback paths.
When a concrete ``app_config`` is supplied, cache the loaded skills by that
config object's identity so request-scoped config injection still resolves
skill paths from the matching config without rescanning storage on every
agent factory call.
"""
if app_config is None:
return _get_enabled_skills()
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
cache_key = id(app_config)
with _enabled_skills_lock:
cached = _enabled_skills_by_config_cache.get(cache_key)
if cached is not None:
cached_config, cached_skills = cached
if cached_config is app_config:
return list(cached_skills)
skills = list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
with _enabled_skills_lock:
_enabled_skills_by_config_cache[cache_key] = (app_config, skills)
return list(skills)
def _skill_mutability_label(category: SkillCategory | str) -> str:
@@ -344,8 +366,7 @@ You are {agent_name}, an open-source super agent.
</role>
{soul}
{memory_context}
{self_update_section}
<thinking_style>
- Think concisely and strategically about the user's request BEFORE taking action
- Break down the task: What is clear? What is ambiguous? What is missing?
@@ -604,7 +625,7 @@ You have access to skills that provide optimized workflows for specific tasks. E
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills_for_config(app_config)
skills = get_enabled_skills_for_config(app_config)
if app_config is None:
try:
@@ -643,6 +664,26 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def _build_self_update_section(agent_name: str | None) -> str:
"""Prompt block that teaches the custom agent to persist self-updates via update_agent."""
if not agent_name:
return ""
return f"""<self_update>
You are running as the custom agent **{agent_name}** with a persisted SOUL.md and config.yaml.
When the user asks you to update your own description, personality, behaviour, skill set, tool groups, or default model,
you MUST persist the change with the `update_agent` tool. Do NOT use `bash`, `write_file`, or any sandbox tool to edit
SOUL.md or config.yaml — those write into a temporary sandbox/tool workspace and the changes will be lost on the next turn.
Rules:
- Always pass the FULL replacement text for `soul` (no patch semantics). Start from your current SOUL above and apply the user's edits.
- Only pass the fields that should change. Omit the others to preserve them.
- Pass `skills=[]` to disable all skills, or omit `skills` to keep the existing whitelist.
- After `update_agent` returns successfully, tell the user the change is persisted and will take effect on the next turn.
</self_update>
"""
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
@@ -732,9 +773,6 @@ def apply_prompt_template(
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name, app_config=app_config)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
@@ -768,17 +806,18 @@ def apply_prompt_template(
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory
prompt = SYSTEM_PROMPT_TEMPLATE.format(
# Build and return the fully static system prompt.
# Memory and current date are injected per-turn via DynamicContextMiddleware
# as a <system-reminder> in the first HumanMessage, keeping this prompt
# identical across users and sessions for maximum prefix-cache reuse.
return SYSTEM_PROMPT_TEMPLATE.format(
agent_name=agent_name or "DeerFlow 2.0",
soul=get_agent_soul(agent_name),
self_update_section=_build_self_update_section(agent_name),
skills_section=skills_section,
deferred_tools_section=deferred_tools_section,
memory_context=memory_context,
subagent_section=subagent_section,
subagent_reminder=subagent_reminder,
subagent_thinking=subagent_thinking,
acp_section=acp_and_mounts_section,
)
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
@@ -0,0 +1,204 @@
"""Middleware to inject dynamic context (memory, current date) as a system-reminder.
The system prompt is kept fully static for maximum prefix-cache reuse across users
and sessions. The current date is always injected. Per-user memory is also injected
when ``memory.injection_enabled`` is True in the app config. Both are delivered once
per conversation as a dedicated <system-reminder> HumanMessage inserted before the
first user message (frozen-snapshot pattern).
When a conversation spans midnight the middleware detects the date change and injects
a lightweight date-update reminder as a separate HumanMessage before the current turn.
This correction is persisted so subsequent turns on the new day see a consistent history
and do not re-inject.
Reminder format:
<system-reminder>
<memory>...</memory>
<current_date>2026-05-08, Friday</current_date>
</system-reminder>
Date-update format:
<system-reminder>
<current_date>2026-05-09, Saturday</current_date>
</system-reminder>
"""
from __future__ import annotations
import logging
import re
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, override
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
_SUMMARY_MESSAGE_NAME = "summary"
def _extract_date(content: str) -> str | None:
"""Return the first <current_date> value found in *content*, or None."""
m = _DATE_RE.search(content)
return m.group(1) if m else None
def is_dynamic_context_reminder(message: object) -> bool:
"""Return whether *message* is a hidden dynamic-context reminder."""
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY))
def _last_injected_date(messages: list) -> str | None:
"""Scan messages in reverse and return the most recently injected date.
Detection uses the ``dynamic_context_reminder`` additional_kwargs flag rather
than content substring matching, so user messages containing ``<system-reminder>``
are not mistakenly treated as injected reminders.
"""
for msg in reversed(messages):
if is_dynamic_context_reminder(msg):
content_str = msg.content if isinstance(msg.content, str) else str(msg.content)
return _extract_date(content_str)
return None
def _is_user_injection_target(message: object) -> bool:
"""Return whether *message* can receive a dynamic-context reminder."""
return isinstance(message, HumanMessage) and not is_dynamic_context_reminder(message) and message.name != _SUMMARY_MESSAGE_NAME
class DynamicContextMiddleware(AgentMiddleware):
"""Inject memory and current date into HumanMessages as a <system-reminder>.
First turn
----------
Prepends a full system-reminder (memory + date) to the first HumanMessage and
persists it (same message ID). The first message is then frozen for the whole
session — its content never changes again, so the prefix cache can hit on every
subsequent turn.
Midnight crossing
-----------------
If the conversation spans midnight, the current date differs from the date that
was injected earlier. In that case a lightweight date-update reminder is prepended
to the **current** (last) HumanMessage and persisted. Subsequent turns on the new
day see the corrected date in history and skip re-injection.
"""
def __init__(self, agent_name: str | None = None, *, app_config: AppConfig | None = None):
super().__init__()
self._agent_name = agent_name
self._app_config = app_config
def _build_full_reminder(self) -> str:
from deerflow.agents.lead_agent.prompt import _get_memory_context
# Memory injection is gated by injection_enabled; date is always included.
injection_enabled = self._app_config.memory.injection_enabled if self._app_config else True
memory_context = _get_memory_context(self._agent_name, app_config=self._app_config) if injection_enabled else ""
current_date = datetime.now().strftime("%Y-%m-%d, %A")
lines: list[str] = ["<system-reminder>"]
if memory_context:
lines.append(memory_context.strip())
lines.append("") # blank line separating memory from date
lines.append(f"<current_date>{current_date}</current_date>")
lines.append("</system-reminder>")
return "\n".join(lines)
def _build_date_update_reminder(self) -> str:
current_date = datetime.now().strftime("%Y-%m-%d, %A")
return "\n".join(
[
"<system-reminder>",
f"<current_date>{current_date}</current_date>",
"</system-reminder>",
]
)
@staticmethod
def _make_reminder_and_user_messages(original: HumanMessage, reminder_content: str) -> tuple[HumanMessage, HumanMessage]:
"""Return (reminder_msg, user_msg) using the ID-swap technique.
reminder_msg takes the original message's ID so that add_messages replaces it
in-place (preserving position). user_msg carries the original content with a
derived ``{id}__user`` ID and is appended immediately after by add_messages.
If the original message has no ID a stable UUID is generated so the derived
``{id}__user`` ID never collapses to the ambiguous ``None__user`` string.
"""
stable_id = original.id or str(uuid.uuid4())
reminder_msg = HumanMessage(
content=reminder_content,
id=stable_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
user_msg = HumanMessage(
content=original.content,
id=f"{stable_id}__user",
name=original.name,
additional_kwargs=original.additional_kwargs,
)
return reminder_msg, user_msg
def _inject(self, state) -> dict | None:
messages = list(state.get("messages", []))
if not messages:
return None
current_date = datetime.now().strftime("%Y-%m-%d, %A")
last_date = _last_injected_date(messages)
logger.debug(
"DynamicContextMiddleware._inject: msg_count=%d last_date=%r current_date=%r",
len(messages),
last_date,
current_date,
)
if last_date is None:
# ── First turn: inject full reminder as a separate HumanMessage ─────
first_idx = next((i for i, m in enumerate(messages) if _is_user_injection_target(m)), None)
if first_idx is None:
return None
full_reminder = self._build_full_reminder()
logger.info(
"DynamicContextMiddleware: injecting full reminder (len=%d, has_memory=%s) into first HumanMessage id=%r",
len(full_reminder),
"<memory>" in full_reminder,
messages[first_idx].id,
)
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[first_idx], full_reminder)
return {"messages": [reminder_msg, user_msg]}
if last_date == current_date:
# ── Same day: nothing to do ──────────────────────────────────────────
return None
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
last_human_idx = next((i for i in reversed(range(len(messages))) if _is_user_injection_target(messages[i])), None)
if last_human_idx is None:
return None
reminder_msg, user_msg = self._make_reminder_and_user_messages(messages[last_human_idx], self._build_date_update_reminder())
logger.info("DynamicContextMiddleware: midnight crossing detected — injected date update before current turn")
return {"messages": [reminder_msg, user_msg]}
@override
def before_agent(self, state, runtime: Runtime) -> dict | None:
return self._inject(state)
@override
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
return self._inject(state)
@@ -12,19 +12,23 @@ Detection strategy:
response so the agent is forced to produce a final text answer.
"""
from __future__ import annotations
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import override
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
if TYPE_CHECKING:
from deerflow.config.loop_detection_config import LoopDetectionConfig
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
@@ -140,6 +144,9 @@ _TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Threshold parameters are validated upstream by :class:`LoopDetectionConfig`;
construct via :meth:`from_config` to ensure values pass Pydantic validation.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
@@ -155,6 +162,14 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
tool_freq_overrides: Per-tool overrides for frequency thresholds,
keyed by tool name. Each value is a ``(warn, hard_limit)`` tuple
that replaces ``tool_freq_warn`` / ``tool_freq_hard_limit`` for
that specific tool. Tools not listed here fall back to the global
thresholds. Useful for raising limits on intentionally
high-frequency tools (e.g. ``bash`` in batch pipelines) without
weakening protection on all other tools. Default: ``None``
(no overrides).
"""
def __init__(
@@ -165,6 +180,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
tool_freq_overrides: dict[str, tuple[int, int]] | None = None,
):
super().__init__()
self.warn_threshold = warn_threshold
@@ -173,14 +189,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._tool_freq_overrides: dict[str, tuple[int, int]] = tool_freq_overrides or {}
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
@classmethod
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
"""Construct from a Pydantic-validated config, trusting its validation."""
return cls(
warn_threshold=config.warn_threshold,
hard_limit=config.hard_limit,
window_size=config.window_size,
max_tracked_threads=config.max_tracked_threads,
tool_freq_warn=config.tool_freq_warn,
tool_freq_hard_limit=config.tool_freq_hard_limit,
tool_freq_overrides={name: (o.warn, o.hard_limit) for name, o in config.tool_freq_overrides.items()},
)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
@@ -280,7 +308,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
freq[name] += 1
tc_count = freq[name]
if tc_count >= self.tool_freq_hard_limit:
if name in self._tool_freq_overrides:
eff_warn, eff_hard = self._tool_freq_overrides[name]
else:
eff_warn, eff_hard = self.tool_freq_warn, self.tool_freq_hard_limit
if tc_count >= eff_hard:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
@@ -291,7 +324,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= self.tool_freq_warn:
if tc_count >= eff_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)
@@ -356,13 +389,30 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return {"messages": [stripped_msg]}
if warning:
# Inject as HumanMessage instead of SystemMessage to avoid
# Anthropic's "multiple non-consecutive system messages" error.
# Anthropic models require system messages only at the start of
# the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299.
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
# WORKAROUND for v2.0-m1 — see #2724.
#
# Append the warning to the AIMessage content instead of
# injecting a separate HumanMessage. Inserting any non-tool
# message between an AIMessage(tool_calls=...) and its
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
# validation ("tool_call_ids did not have response messages")
# because the tools node has not run yet at after_model time.
# tool_calls are preserved so the tools node still executes.
#
# This is a temporary mitigation: mutating an existing
# AIMessage to carry framework-authored text leaks loop-warning
# text into downstream consumers (MemoryMiddleware fact
# extraction, TitleMiddleware, telemetry, model replay) as if
# the model said it. The proper fix is to defer warning
# injection from after_model to wrap_model_call so every prior
# ToolMessage is already in the request — see RFC #2517 (which
# lists "loop intervention does not leave invalid
# tool-call/tool-message state" as acceptance criteria) and
# the prototype on `fix/loop-detection-tool-call-pairing`.
messages = state.get("messages", [])
last_msg = messages[-1]
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
return {"messages": [patched_msg]}
return None
@@ -7,6 +7,7 @@ from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls)
return {"messages": [updated_msg]}
@override
@@ -14,6 +14,9 @@ from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
logger = logging.getLogger(__name__)
@@ -78,10 +81,7 @@ def _clone_ai_message(
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
return message.model_copy(update=update)
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
@dataclass
@@ -136,6 +136,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@@ -161,6 +162,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@@ -180,6 +182,24 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
def _preserve_dynamic_context_reminders(
self,
messages_to_summarize: list[AnyMessage],
preserved_messages: list[AnyMessage],
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Keep hidden dynamic-context reminders out of summary compression.
These reminders carry the current date and optional memory. If summarization
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
for the first user message and inject the reminder in the wrong place.
"""
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
if not reminders:
return messages_to_summarize, preserved_messages
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
return remaining, reminders + preserved_messages
def _partition_with_skill_rescue(
self,
messages: list[AnyMessage],
@@ -9,6 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
@@ -61,6 +62,10 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return ""
@staticmethod
def _is_user_message_for_title(message: object) -> bool:
return getattr(message, "type", None) == "human" and not is_dynamic_context_reminder(message)
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = self._get_title_config()
@@ -77,7 +82,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return False
# Count user and assistant messages
user_messages = [m for m in messages if m.type == "human"]
user_messages = [m for m in messages if self._is_user_message_for_title(m)]
assistant_messages = [m for m in messages if m.type == "ai"]
# Generate title after first complete exchange
@@ -91,7 +96,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
config = self._get_title_config()
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
user_msg_content = next((m.content for m in messages if self._is_user_message_for_title(m)), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
@@ -267,11 +267,20 @@ class TokenUsageMiddleware(AgentMiddleware):
usage = getattr(last, "usage_metadata", None)
if usage:
input_token_details = usage.get("input_token_details") or {}
output_token_details = usage.get("output_token_details") or {}
detail_parts = []
if input_token_details:
detail_parts.append(f"input_token_details={input_token_details}")
if output_token_details:
detail_parts.append(f"output_token_details={output_token_details}")
detail_suffix = f" {' '.join(detail_parts)}" if detail_parts else ""
logger.info(
"LLM token usage: input=%s output=%s total=%s",
"LLM token usage: input=%s output=%s total=%s%s",
usage.get("input_tokens", "?"),
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
detail_suffix,
)
todos = state.get("todos") or []
@@ -0,0 +1,50 @@
"""Helpers for keeping AIMessage tool-call metadata consistent."""
from __future__ import annotations
from typing import Any
from langchain_core.messages import AIMessage
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
if not isinstance(raw_tool_call, dict):
return None
raw_id = raw_tool_call.get("id")
return raw_id if isinstance(raw_id, str) and raw_id else None
def clone_ai_message_with_tool_calls(
message: AIMessage,
tool_calls: list[dict[str, Any]],
*,
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
raw_tool_calls = additional_kwargs.get("tool_calls")
if isinstance(raw_tool_calls, list):
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
if synced_raw_tool_calls:
additional_kwargs["tool_calls"] = synced_raw_tool_calls
else:
additional_kwargs.pop("tool_calls", None)
if not tool_calls:
additional_kwargs.pop("function_call", None)
update["additional_kwargs"] = additional_kwargs
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
response_metadata["finish_reason"] = "stop"
update["response_metadata"] = response_metadata
return message.model_copy(update=update)
@@ -80,6 +80,7 @@ class AioSandboxProvider(SandboxProvider):
port: 8080 # Base port for local containers
container_prefix: deer-flow-sandbox
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
auto_restart: true # Restart crashed containers automatically
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
mounts: # Volume mounts for local containers
- host_path: /path/on/host
@@ -164,12 +165,14 @@ class AioSandboxProvider(SandboxProvider):
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
auto_restart = getattr(sandbox_config, "auto_restart", True)
return {
"image": sandbox_config.image or DEFAULT_IMAGE,
"port": sandbox_config.port or DEFAULT_PORT,
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
"auto_restart": auto_restart,
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
"mounts": sandbox_config.mounts or [],
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
@@ -608,18 +611,58 @@ class AioSandboxProvider(SandboxProvider):
def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox by ID. Updates last activity timestamp.
When ``auto_restart`` is enabled (the default), the container's liveness
is verified on each lookup. If the underlying container has crashed, the
sandbox is evicted from all caches so that the next ``acquire()`` call will
transparently create a fresh container.
Args:
sandbox_id: The ID of the sandbox.
Returns:
The sandbox instance if found, None otherwise.
The sandbox instance if found and alive, None otherwise.
"""
with self._lock:
sandbox = self._sandboxes.get(sandbox_id)
if sandbox is not None:
self._last_activity[sandbox_id] = time.time()
if sandbox is None:
return None
self._last_activity[sandbox_id] = time.time()
auto_restart = self._config.get("auto_restart", True)
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
if not info:
return sandbox
if self._backend.is_alive(info):
return sandbox
info_to_destroy = None
with self._lock:
current_sandbox = self._sandboxes.get(sandbox_id)
current_info = self._sandbox_infos.get(sandbox_id)
if current_sandbox is None:
return None
if current_info is not info:
self._last_activity[sandbox_id] = time.time()
return current_sandbox
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
self._sandboxes.pop(sandbox_id, None)
self._sandbox_infos.pop(sandbox_id, None)
self._last_activity.pop(sandbox_id, None)
self._warm_pool.pop(sandbox_id, None)
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids:
del self._thread_sandboxes[tid]
info_to_destroy = info
if info_to_destroy:
try:
self._backend.destroy(info_to_destroy)
except Exception as e:
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
return None
def release(self, sandbox_id: str) -> None:
"""Release a sandbox from active use into the warm pool.
@@ -84,8 +84,52 @@ class RemoteSandboxBackend(SandboxBackend):
"""
return self._provisioner_discover(sandbox_id)
def list_running(self) -> list[SandboxInfo]:
"""Return all sandboxes currently managed by the provisioner.
Calls ``GET /api/sandboxes`` so that ``AioSandboxProvider._reconcile_orphans()``
can adopt pods that were created by a previous process and were never
explicitly destroyed.
Without this, a process restart silently orphans all existing k8s Pods —
they stay running forever because the idle checker only
tracks in-process state.
"""
return self._provisioner_list()
# ── Provisioner API calls ─────────────────────────────────────────────
def _provisioner_list(self) -> list[SandboxInfo]:
"""GET /api/sandboxes → list all running sandboxes."""
try:
resp = requests.get(f"{self._provisioner_url}/api/sandboxes", timeout=10)
resp.raise_for_status()
data = resp.json()
if not isinstance(data, dict):
logger.warning("Provisioner list_running returned non-dict payload: %r", type(data))
return []
sandboxes = data.get("sandboxes", [])
if not isinstance(sandboxes, list):
logger.warning("Provisioner list_running returned non-list sandboxes: %r", type(sandboxes))
return []
infos: list[SandboxInfo] = []
for sandbox in sandboxes:
if not isinstance(sandbox, dict):
logger.warning("Provisioner list_running entry is not a dict: %r", type(sandbox))
continue
sandbox_id = sandbox.get("sandbox_id")
sandbox_url = sandbox.get("sandbox_url")
if isinstance(sandbox_id, str) and sandbox_id and isinstance(sandbox_url, str) and sandbox_url:
infos.append(SandboxInfo(sandbox_id=sandbox_id, sandbox_url=sandbox_url))
logger.info("Provisioner list_running: %d sandbox(es) found", len(infos))
return infos
except requests.RequestException as exc:
logger.warning("Provisioner list_running failed: %s", exc)
return []
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""POST /api/sandboxes → create Pod + Service."""
try:
@@ -1,5 +1,6 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .loop_detection_config import LoopDetectionConfig
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
@@ -20,6 +21,7 @@ __all__ = [
"SkillsConfig",
"ExtensionsConfig",
"get_extensions_config",
"LoopDetectionConfig",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
@@ -1,13 +1,22 @@
"""Configuration and loaders for custom agents."""
"""Configuration and loaders for custom agents.
Custom agents are stored per-user under ``{base_dir}/users/{user_id}/agents/{name}/``.
A legacy shared layout at ``{base_dir}/agents/{name}/`` is still readable so that
installations that pre-date user isolation continue to work until they run the
``scripts/migrate_user_isolation.py`` migration. New writes always target the
per-user layout.
"""
import logging
import re
from pathlib import Path
from typing import Any
import yaml
from pydantic import BaseModel
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@@ -40,14 +49,47 @@ class AgentConfig(BaseModel):
skills: list[str] | None = None
def load_agent_config(name: str | None) -> AgentConfig | None:
def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
"""Return the on-disk directory for an agent, preferring the per-user layout.
Resolution order:
1. ``{base_dir}/users/{user_id}/agents/{name}/`` (per-user, current layout).
2. ``{base_dir}/agents/{name}/`` (legacy shared layout — read-only fallback).
If neither exists, the per-user path is returned so callers that intend to
create the agent write into the new layout.
Args:
name: Validated agent name.
user_id: Owner of the agent. Defaults to the effective user from the
request context (or ``"default"`` in no-auth mode).
"""
paths = get_paths()
effective_user = user_id or get_effective_user_id()
user_path = paths.user_agent_dir(effective_user, name)
if user_path.exists():
return user_path
legacy_path = paths.agent_dir(name)
if legacy_path.exists():
return legacy_path
return user_path
def load_agent_config(name: str | None, *, user_id: str | None = None) -> AgentConfig | None:
"""Load the custom or default agent's config from its directory.
Reads from the per-user layout first; falls back to the legacy shared layout
for installations that have not yet been migrated.
Args:
name: The agent name.
user_id: Owner of the agent. Defaults to the effective user from the
current request context.
Returns:
AgentConfig instance.
AgentConfig instance, or ``None`` if ``name`` is ``None``.
Raises:
FileNotFoundError: If the agent directory or config.yaml does not exist.
@@ -58,7 +100,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
return None
name = validate_agent_name(name)
agent_dir = get_paths().agent_dir(name)
agent_dir = resolve_agent_dir(name, user_id=user_id)
config_file = agent_dir / "config.yaml"
if not agent_dir.exists():
@@ -84,7 +126,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
return AgentConfig(**data)
def load_agent_soul(agent_name: str | None) -> str | None:
def load_agent_soul(agent_name: str | None, *, user_id: str | None = None) -> str | None:
"""Read the SOUL.md file for a custom agent, if it exists.
SOUL.md defines the agent's personality, values, and behavioral guardrails.
@@ -92,11 +134,16 @@ def load_agent_soul(agent_name: str | None) -> str | None:
Args:
agent_name: The name of the agent or None for the default agent.
user_id: Owner of the agent. Defaults to the effective user from the
current request context.
Returns:
The SOUL.md content as a string, or None if the file does not exist.
"""
agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
if agent_name:
agent_dir = resolve_agent_dir(agent_name, user_id=user_id)
else:
agent_dir = get_paths().base_dir
soul_path = agent_dir / SOUL_FILENAME
if not soul_path.exists():
return None
@@ -104,32 +151,50 @@ def load_agent_soul(agent_name: str | None) -> str | None:
return content or None
def list_custom_agents() -> list[AgentConfig]:
def list_custom_agents(*, user_id: str | None = None) -> list[AgentConfig]:
"""Scan the agents directory and return all valid custom agents.
Returns the union of agents in the per-user layout and the legacy shared
layout, so that pre-migration installations remain visible until they are
migrated. Per-user entries shadow legacy entries with the same name.
Args:
user_id: Owner whose agents to list. Defaults to the effective user
from the current request context.
Returns:
List of AgentConfig for each valid agent directory found.
"""
agents_dir = get_paths().agents_dir
if not agents_dir.exists():
return []
paths = get_paths()
effective_user = user_id or get_effective_user_id()
seen: set[str] = set()
agents: list[AgentConfig] = []
for entry in sorted(agents_dir.iterdir()):
if not entry.is_dir():
user_root = paths.user_agents_dir(effective_user)
legacy_root = paths.agents_dir
for root in (user_root, legacy_root):
if not root.exists():
continue
for entry in sorted(root.iterdir()):
if not entry.is_dir():
continue
if entry.name in seen:
continue
config_file = entry / "config.yaml"
if not config_file.exists():
logger.debug(f"Skipping {entry.name}: no config.yaml")
continue
config_file = entry / "config.yaml"
if not config_file.exists():
logger.debug(f"Skipping {entry.name}: no config.yaml")
continue
try:
agent_cfg = load_agent_config(entry.name)
agents.append(agent_cfg)
except Exception as e:
logger.warning(f"Skipping agent '{entry.name}': {e}")
try:
agent_cfg = load_agent_config(entry.name, user_id=effective_user)
if agent_cfg is None:
continue
agents.append(agent_cfg)
seen.add(entry.name)
except Exception as e:
logger.warning(f"Skipping agent '{entry.name}': {e}")
agents.sort(key=lambda a: a.name)
return agents
@@ -1,5 +1,6 @@
import logging
import os
from collections.abc import Mapping
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
@@ -14,6 +15,7 @@ from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpo
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
@@ -99,6 +101,7 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -157,56 +160,54 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data)
cls._apply_database_defaults(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump()
result = cls.model_validate(config_data)
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
cls._apply_singleton_configs(result, acp_agents)
return result
@classmethod
def _validate_acp_agents(
cls,
config_data: Mapping[str, Mapping[str, object]] | None,
) -> dict[str, ACPAgentConfig]:
if config_data is None:
config_data = {}
return {name: ACPAgentConfig(**cfg) for name, cfg in config_data.items()}
@classmethod
def _apply_singleton_configs(cls, config: Self, acp_agents: dict[str, ACPAgentConfig]) -> None:
from deerflow.config.checkpointer_config import get_checkpointer_config
previous_checkpointer_config = get_checkpointer_config()
load_title_config_from_dict(config.title.model_dump())
load_summarization_config_from_dict(config.summarization.model_dump())
load_memory_config_from_dict(config.memory.model_dump())
load_agents_api_config_from_dict(config.agents_api.model_dump())
load_subagents_config_from_dict(config.subagents.model_dump())
load_tool_search_config_from_dict(config.tool_search.model_dump())
load_guardrails_config_from_dict(config.guardrails.model_dump())
load_checkpointer_config_from_dict(config.checkpointer.model_dump() if config.checkpointer is not None else None)
load_stream_bridge_config_from_dict(config.stream_bridge.model_dump() if config.stream_bridge is not None else None)
load_acp_config_from_dict({name: agent.model_dump() for name, agent in acp_agents.items()})
if previous_checkpointer_config != config.checkpointer:
# These runtime singletons derive their backend from checkpointer config.
# Keep imports local to avoid cycles: both providers import get_app_config.
from deerflow.runtime.checkpointer import reset_checkpointer
from deerflow.runtime.store import reset_store
reset_checkpointer()
reset_store()
@classmethod
def _apply_database_defaults(cls, config_data: dict[str, Any]) -> None:
"""Apply config.yaml defaults for persistence when the section is absent."""
@@ -14,12 +14,13 @@ class CheckpointerConfig(BaseModel):
description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). "
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
"'postgres' persists to PostgreSQL (install with deerflow-harness[postgres])."
)
connection_string: str | None = Field(
default=None,
description="Connection string for sqlite (file path) or postgres (DSN). "
"Required for sqlite and postgres types. "
"Optional for sqlite and defaults to 'store.db' when omitted. "
"Required for postgres. "
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
@@ -40,7 +41,10 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
if config_dict is None:
_checkpointer_config = None
return
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -0,0 +1,73 @@
"""Configuration for loop detection middleware."""
from pydantic import BaseModel, Field, model_validator
class ToolFreqOverride(BaseModel):
"""Per-tool frequency threshold override.
Can be higher or lower than the global defaults. Commonly used to raise
thresholds for high-frequency tools like bash in batch workflows (e.g.
RNA-seq pipelines) without weakening protection on every other tool.
"""
warn: int = Field(ge=1)
hard_limit: int = Field(ge=1)
@model_validator(mode="after")
def _validate(self) -> "ToolFreqOverride":
if self.hard_limit < self.warn:
raise ValueError("hard_limit must be >= warn")
return self
class LoopDetectionConfig(BaseModel):
"""Configuration for repetitive tool-call loop detection."""
enabled: bool = Field(
default=True,
description="Whether to enable repetitive tool-call loop detection",
)
warn_threshold: int = Field(
default=3,
ge=1,
description="Number of identical tool-call sets before injecting a warning",
)
hard_limit: int = Field(
default=5,
ge=1,
description="Number of identical tool-call sets before forcing a stop",
)
window_size: int = Field(
default=20,
ge=1,
description="Number of recent tool-call sets to track per thread",
)
max_tracked_threads: int = Field(
default=100,
ge=1,
description="Maximum number of thread histories to keep in memory",
)
tool_freq_warn: int = Field(
default=30,
ge=1,
description="Number of calls to the same tool type before injecting a frequency warning",
)
tool_freq_hard_limit: int = Field(
default=50,
ge=1,
description="Number of calls to the same tool type before forcing a stop",
)
tool_freq_overrides: dict[str, ToolFreqOverride] = Field(
default_factory=dict,
description=("Per-tool overrides for tool_freq_warn / tool_freq_hard_limit, keyed by tool name. Values can be higher or lower than the global defaults. Commonly used to raise thresholds for high-frequency tools like bash."),
)
@model_validator(mode="after")
def validate_thresholds(self) -> "LoopDetectionConfig":
"""Ensure hard stop cannot happen before the warning threshold."""
if self.hard_limit < self.warn_threshold:
raise ValueError("hard_limit must be greater than or equal to warn_threshold")
if self.tool_freq_hard_limit < self.tool_freq_warn:
raise ValueError("tool_freq_hard_limit must be greater than or equal to tool_freq_warn")
return self
@@ -132,15 +132,20 @@ class Paths:
@property
def agents_dir(self) -> Path:
"""Root directory for all custom agents: `{base_dir}/agents/`."""
"""Legacy root for shared (pre user-isolation) custom agents: `{base_dir}/agents/`.
New code should use :meth:`user_agents_dir` instead. This property remains
only as a read-side fallback for installations that have not yet run the
``migrate_user_isolation.py`` script.
"""
return self.base_dir / "agents"
def agent_dir(self, name: str) -> Path:
"""Directory for a specific agent: `{base_dir}/agents/{name}/`."""
"""Legacy per-agent directory (no user isolation): `{base_dir}/agents/{name}/`."""
return self.agents_dir / name.lower()
def agent_memory_file(self, name: str) -> Path:
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
"""Legacy per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
return self.agent_dir(name) / "memory.json"
def user_dir(self, user_id: str) -> Path:
@@ -151,9 +156,17 @@ class Paths:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json"
def user_agents_dir(self, user_id: str) -> Path:
"""Per-user root for that user's custom agents: `{base_dir}/users/{user_id}/agents/`."""
return self.user_dir(user_id) / "agents"
def user_agent_dir(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent directory: `{base_dir}/users/{user_id}/agents/{name}/`."""
return self.user_agents_dir(user_id) / agent_name.lower()
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
return self.user_agent_dir(user_id, agent_name) / "memory.json"
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
@@ -23,6 +23,9 @@ class SandboxConfig(BaseModel):
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
container_prefix: Prefix for container names (default: deer-flow-sandbox)
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
on the next acquire. Set to false to disable.
mounts: List of volume mounts to share directories with the container
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
"""
@@ -55,6 +58,10 @@ class SandboxConfig(BaseModel):
default=None,
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
)
auto_restart: bool = Field(
default=True,
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
)
mounts: list[VolumeMountConfig] = Field(
default_factory=list,
description="List of volume mounts to share directories between host and container",
@@ -40,7 +40,10 @@ def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
def load_stream_bridge_config_from_dict(config_dict: dict | None) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
if config_dict is None:
_stream_bridge_config = None
return
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -179,9 +179,3 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
overrides_summary or "none",
custom_agents_names or "none",
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -4,4 +4,4 @@ from pydantic import BaseModel, Field
class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking."""
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
enabled: bool = Field(default=True, description="Enable token usage tracking middleware")
@@ -196,6 +196,10 @@ class ClaudeChatModel(ChatAnthropic):
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
placed on the *last* eligible blocks because later breakpoints cover a
larger prefix and yield better cache hit rates.
The system prompt is expected to be fully static (no per-user memory or
current date). Dynamic context is injected per-turn via
DynamicContextMiddleware as a <system-reminder> in the first HumanMessage.
"""
MAX_CACHE_BREAKPOINTS = 4
@@ -81,7 +81,16 @@ async def init_engine(
try:
import asyncpg # noqa: F401
except ImportError:
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
raise ImportError(
"database.backend is set to 'postgres' but asyncpg is not installed.\n"
"Install it with:\n"
" cd backend && uv sync --all-packages --extra postgres\n"
"On the next `make dev` the postgres extra is auto-detected from\n"
"config.yaml (database.backend: postgres) and reinstalled, so it\n"
"will not be wiped again. Set UV_EXTRAS=postgres in .env to opt in\n"
"explicitly. Or switch to backend: sqlite in config.yaml for\n"
"single-node deployment."
) from None
if backend == "sqlite":
import os
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_INSTALL = (
"langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
)
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
@@ -9,6 +9,7 @@ from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -33,20 +34,21 @@ class DbRunEventStore(RunEventStore):
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
d.pop("id", None)
# Restore dict content that was JSON-serialized on write
# Restore structured content that was JSON-serialized on write.
raw = d.get("content", "")
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
metadata = d.get("metadata", {})
if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")):
try:
d["content"] = json.loads(raw)
except (json.JSONDecodeError, ValueError):
# Content looked like JSON (content_is_dict flag) but failed to parse;
# Content looked like JSON but failed to parse;
# keep the raw string as-is.
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
return d
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]:
if category == "trace":
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False)
encoded = text.encode("utf-8")
if len(encoded) > self._max_trace_content:
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
@@ -54,6 +56,18 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]:
metadata = metadata or {}
if isinstance(content, str):
return content, metadata
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**metadata, "content_is_json": True}
if isinstance(content, dict):
metadata["content_is_dict"] = True
return db_content, metadata
@staticmethod
def _user_id_from_context() -> str | None:
"""Soft read of user_id from contextvar for write paths.
@@ -82,11 +96,7 @@ class DbRunEventStore(RunEventStore):
the initial ``human_message`` event (once per run).
"""
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
db_content, metadata = self._content_to_db(content, metadata)
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
@@ -128,11 +138,7 @@ class DbRunEventStore(RunEventStore):
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
db_content, metadata = self._content_to_db(content, metadata)
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
@@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_STORE_INSTALL = (
"langgraph-checkpoint-postgres is required for the PostgreSQL store. Install the package extra with: pip install 'deerflow-harness[postgres]' (or use: uv sync --all-packages --extra postgres when developing locally)"
)
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
@@ -42,6 +42,13 @@ class LocalSandbox(Sandbox):
"""Return whether the selected shell is cmd.exe."""
return LocalSandbox._shell_name(shell) in {"cmd", "cmd.exe"}
@staticmethod
def _is_msys_shell(shell: str) -> bool:
"""Return whether the selected shell is a Git Bash/MSYS shell."""
normalized = shell.replace("\\", "/").lower()
shell_name = LocalSandbox._shell_name(shell)
return shell_name in {"sh.exe", "bash.exe"} and any(part in normalized for part in ("/git/", "/mingw", "/msys"))
@staticmethod
def _find_first_available_shell(candidates: tuple[str, ...]) -> str | None:
"""Return the first executable shell path or command found from candidates."""
@@ -303,12 +310,19 @@ class LocalSandbox(Sandbox):
shell = self._get_shell()
if os.name == "nt":
env = None
if self._is_powershell(shell):
args = [shell, "-NoProfile", "-Command", resolved_command]
elif self._is_cmd_shell(shell):
args = [shell, "/c", resolved_command]
else:
args = [shell, "-c", resolved_command]
if self._is_msys_shell(shell):
env = {
**os.environ,
"MSYS_NO_PATHCONV": "1",
"MSYS2_ARG_CONV_EXCL": "*",
}
result = subprocess.run(
args,
@@ -316,6 +330,7 @@ class LocalSandbox(Sandbox):
capture_output=True,
text=True,
timeout=600,
env=env,
)
else:
args = [shell, "-c", resolved_command]
@@ -3,10 +3,9 @@ import re
import shlex
from pathlib import Path
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from langchain.tools import tool
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
@@ -19,6 +18,7 @@ from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.tools.types import Runtime
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
@@ -419,7 +419,7 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
"""Sanitize an error message to avoid leaking host filesystem paths.
In local-sandbox mode, resolved host paths in the error string are masked
@@ -994,7 +994,7 @@ def _apply_cwd_prefix(command: str, thread_data: ThreadDataState | None) -> str:
return command
def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None:
def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None:
"""Extract thread_data from runtime state."""
if runtime is None:
return None
@@ -1003,7 +1003,7 @@ def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> Threa
return runtime.state.get("thread_data")
def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool:
def is_local_sandbox(runtime: Runtime | None) -> bool:
"""Check if the current sandbox is a local sandbox.
Path replacement is only needed for local sandbox since aio sandbox
@@ -1019,7 +1019,7 @@ def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool
return sandbox_state.get("sandbox_id") == "local"
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox:
"""Extract sandbox instance from tool runtime.
DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support.
@@ -1048,7 +1048,7 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
return sandbox
def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
"""Ensure sandbox is initialized, acquiring lazily if needed.
On first call, acquires a sandbox from the provider and stores it in runtime state.
@@ -1107,7 +1107,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
return sandbox
def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState] | None) -> None:
def ensure_thread_directories_exist(runtime: Runtime | None) -> None:
"""Ensure thread data directories (workspace, uploads, outputs) exist.
This function is called lazily when any sandbox tool is first used.
@@ -1221,7 +1221,7 @@ def _truncate_ls_output(output: str, max_chars: int) -> str:
@tool("bash", parse_docstring=True)
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
def bash_tool(runtime: Runtime, description: str, command: str) -> str:
"""Execute a bash command in a Linux environment.
@@ -1270,7 +1270,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
@tool("ls", parse_docstring=True)
def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str:
def ls_tool(runtime: Runtime, description: str, path: str) -> str:
"""List the contents of a directory up to 2 levels deep in tree format.
Args:
@@ -1318,7 +1318,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
@tool("glob", parse_docstring=True)
def glob_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
pattern: str,
path: str,
@@ -1368,7 +1368,7 @@ def glob_tool(
@tool("grep", parse_docstring=True)
def grep_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
pattern: str,
path: str,
@@ -1438,7 +1438,7 @@ def grep_tool(
@tool("read_file", parse_docstring=True)
def read_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
path: str,
start_line: int | None = None,
@@ -1493,7 +1493,7 @@ def read_file_tool(
@tool("write_file", parse_docstring=True)
def write_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
path: str,
content: str,
@@ -1533,7 +1533,7 @@ def write_file_tool(
@tool("str_replace", parse_docstring=True)
def str_replace_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
path: str,
old_str: str,
@@ -9,6 +9,29 @@ from .types import SKILL_MD_FILE, Skill, SkillCategory
logger = logging.getLogger(__name__)
def parse_allowed_tools(raw: object, skill_file: Path) -> list[str] | None:
"""Parse the optional allowed-tools frontmatter field.
Returns None when the field is omitted. Returns a list when the field is a
YAML sequence of strings, including an empty list for explicit no-tool
skills. Raises ValueError for malformed values.
"""
if raw is None:
return None
if not isinstance(raw, list):
raise ValueError(f"allowed-tools in {skill_file} must be a list of strings")
allowed_tools: list[str] = []
for item in raw:
if not isinstance(item, str):
raise ValueError(f"allowed-tools in {skill_file} must contain only strings")
tool_name = item.strip()
if not tool_name:
raise ValueError(f"allowed-tools in {skill_file} cannot contain empty tool names")
allowed_tools.append(tool_name)
return allowed_tools
def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: Path | None = None) -> Skill | None:
"""Parse a SKILL.md file and extract metadata.
@@ -64,6 +87,12 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
if license_text is not None:
license_text = str(license_text).strip() or None
try:
allowed_tools = parse_allowed_tools(metadata.get("allowed-tools"), skill_file)
except ValueError as exc:
logger.error("Invalid allowed-tools in %s: %s", skill_file, exc)
return None
return Skill(
name=name,
description=description,
@@ -72,6 +101,7 @@ def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: P
skill_file=skill_file,
relative_path=relative_path or Path(skill_file.parent.name),
category=category,
allowed_tools=allowed_tools,
enabled=True, # Actual state comes from the extensions config file.
)
@@ -0,0 +1,44 @@
import logging
from typing import Protocol
from deerflow.skills.types import Skill
logger = logging.getLogger(__name__)
class NamedTool(Protocol):
name: str
def allowed_tool_names_for_skills(skills: list[Skill]) -> set[str] | None:
"""Return the union of explicit skill allowed-tools declarations.
None means legacy allow-all behavior. It is returned only when no loaded
skill declares allowed-tools. Once any skill declares the field, legacy
skills without the field contribute no tools instead of disabling the
explicit restrictions from other skills.
"""
if not skills:
return None
allowed: set[str] = set()
has_explicit_declaration = False
for skill in skills:
if skill.allowed_tools is None:
continue
has_explicit_declaration = True
if not skill.allowed_tools:
logger.info("Skill %s declared empty allowed-tools", skill.name)
allowed.update(skill.allowed_tools)
if not has_explicit_declaration:
return None
return allowed
def filter_tools_by_skill_allowed_tools[ToolT: NamedTool](tools: list[ToolT], skills: list[Skill]) -> list[ToolT]:
allowed = allowed_tool_names_for_skills(skills)
if allowed is None:
return tools
return [tool for tool in tools if tool.name in allowed]
@@ -27,6 +27,7 @@ class Skill:
skill_file: Path
relative_path: Path # Relative path from category root to skill directory
category: SkillCategory # 'public' or 'custom'
allowed_tools: list[str] | None = None
enabled: bool = False # Whether this skill is enabled
@property
@@ -8,6 +8,7 @@ from pathlib import Path
import yaml
from deerflow.skills.parser import parse_allowed_tools
from deerflow.skills.types import SKILL_MD_FILE
# Allowed properties in SKILL.md frontmatter
@@ -84,4 +85,9 @@ def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]
if len(description) > 1024:
return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters.", None
try:
parse_allowed_tools(frontmatter.get("allowed-tools"), skill_md)
except ValueError as e:
return False, str(e).replace(str(skill_md), SKILL_MD_FILE), None
return True, "Skill is valid!", name
@@ -23,6 +23,8 @@ from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadSt
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
logger = logging.getLogger(__name__)
@@ -260,16 +262,16 @@ class SubagentExecutor:
# Generate trace_id if not provided (for top-level calls)
self.trace_id = trace_id or str(uuid.uuid4())[:8]
# Filter tools based on config
self.tools = _filter_tools(
self._base_tools = _filter_tools(
tools,
config.tools,
config.disallowed_tools,
)
self.tools = self._base_tools
logger.info(f"[trace={self.trace_id}] SubagentExecutor initialized: {config.name} with {len(self.tools)} tools")
def _create_agent(self):
def _create_agent(self, tools: list[BaseTool] | None = None):
"""Create the agent instance."""
app_config = self.app_config or get_app_config()
if self.model_name is None:
@@ -283,26 +285,14 @@ class SubagentExecutor:
return create_agent(
model=model,
tools=self.tools,
tools=tools if tools is not None else self.tools,
middleware=middlewares,
system_prompt=self.config.system_prompt,
state_schema=ThreadState,
)
async def _load_skill_messages(self) -> list[SystemMessage]:
"""Load skill content as conversation items based on config.skills.
Aligned with Codex's pattern: each subagent loads its own skills
per-session and injects them as conversation items (developer messages),
not as system prompt text. The config.skills whitelist controls which
skills are loaded:
- None: load all enabled skills
- []: no skills
- ["skill-a", "skill-b"]: only these skills
Returns:
List of SystemMessages containing skill content.
"""
async def _load_skills(self) -> list[Skill]:
"""Load enabled skill metadata based on config.skills."""
if self.config.skills is not None and len(self.config.skills) == 0:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
return []
@@ -316,8 +306,8 @@ class SubagentExecutor:
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
except Exception:
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
return []
logger.exception(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}")
raise
if not all_skills:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
@@ -326,10 +316,26 @@ class SubagentExecutor:
# Filter by config.skills whitelist
if self.config.skills is not None:
allowed = set(self.config.skills)
skills = [s for s in all_skills if s.name in allowed]
else:
skills = all_skills
return [s for s in all_skills if s.name in allowed]
return all_skills
def _apply_skill_allowed_tools(self, skills: list[Skill]) -> list[BaseTool]:
return filter_tools_by_skill_allowed_tools(self._base_tools, skills)
async def _load_skill_messages(self, skills: list[Skill]) -> list[SystemMessage]:
"""Load skill content as conversation items based on config.skills.
Aligned with Codex's pattern: each subagent loads its own skills
per-session and injects them as conversation items (developer messages),
not as system prompt text. The config.skills whitelist controls which
skills are loaded:
- None: load all enabled skills
- []: no skills
- ["skill-a", "skill-b"]: only these skills
Returns:
List of SystemMessages containing skill content.
"""
if not skills:
return []
@@ -347,19 +353,21 @@ class SubagentExecutor:
return messages
async def _build_initial_state(self, task: str) -> dict[str, Any]:
async def _build_initial_state(self, task: str) -> tuple[dict[str, Any], list[BaseTool]]:
"""Build the initial state for agent execution.
Args:
task: The task description.
Returns:
Initial state dictionary.
Initial state dictionary and tools filtered by loaded skill metadata.
"""
# Load skills as conversation items (Codex pattern)
skill_messages = await self._load_skill_messages()
skills = await self._load_skills()
filtered_tools = self._apply_skill_allowed_tools(skills)
skill_messages = await self._load_skill_messages(skills)
messages: list = []
messages: list[Any] = []
# Skill content injected as developer/system messages before the task
messages.extend(skill_messages)
# Then the actual task
@@ -375,7 +383,7 @@ class SubagentExecutor:
if self.thread_data is not None:
state["thread_data"] = self.thread_data
return state
return state, filtered_tools
async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute a task asynchronously.
@@ -405,8 +413,8 @@ class SubagentExecutor:
result.ai_messages = ai_messages
try:
agent = self._create_agent()
state = await self._build_initial_state(task)
state, filtered_tools = await self._build_initial_state(task)
agent = self._create_agent(filtered_tools)
# Build config with thread_id for sandbox access and recursion limit
run_config: RunnableConfig = {
@@ -2,10 +2,12 @@ from .clarification_tool import ask_clarification_tool
from .present_file_tool import present_file_tool
from .setup_agent_tool import setup_agent
from .task_tool import task_tool
from .update_agent_tool import update_agent
from .view_image_tool import view_image_tool
__all__ = [
"setup_agent",
"update_agent",
"present_file_tool",
"ask_clarification_tool",
"view_image_tool",
@@ -1,20 +1,19 @@
from pathlib import Path
from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain.tools import InjectedToolCallId, tool
from langchain_core.messages import ToolMessage
from langgraph.config import get_config
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
def _get_thread_id(runtime: Runtime) -> str | None:
"""Resolve the current thread id from runtime context or RunnableConfig."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
@@ -32,7 +31,7 @@ def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
def _normalize_presented_filepath(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
filepath: str,
) -> str:
"""Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.
@@ -83,7 +82,7 @@ def _normalize_presented_filepath(
@tool("present_files", parse_docstring=True)
def present_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
filepaths: list[str],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
@@ -3,20 +3,28 @@ import logging
import yaml
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolRuntime
from langgraph.types import Command
from deerflow.config.agents_config import validate_agent_name
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _get_runtime_user_id(runtime: Runtime) -> str:
context_user_id = runtime.context.get("user_id") if runtime.context else None
if context_user_id:
return str(context_user_id)
return get_effective_user_id()
@tool
def setup_agent(
soul: str,
description: str,
runtime: ToolRuntime,
runtime: Runtime,
skills: list[str] | None = None,
) -> Command:
"""Setup the custom DeerFlow agent.
@@ -34,7 +42,14 @@ def setup_agent(
try:
agent_name = validate_agent_name(agent_name)
paths = get_paths()
agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
if agent_name:
# Custom agents are persisted under the current user's bucket so
# different users do not see each other's agents.
user_id = _get_runtime_user_id(runtime)
agent_dir = paths.user_agent_dir(user_id, agent_name)
else:
# Default agent (no agent_name): SOUL.md lives at the global base dir.
agent_dir = paths.base_dir
is_new_dir = not agent_dir.exists()
agent_dir.mkdir(parents=True, exist_ok=True)
@@ -6,11 +6,9 @@ import uuid
from dataclasses import replace
from typing import TYPE_CHECKING, Annotated, Any, cast
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain.tools import InjectedToolCallId, tool
from langgraph.config import get_stream_writer
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config import get_app_config
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
@@ -21,6 +19,7 @@ from deerflow.subagents.executor import (
get_background_task_result,
request_cancel_background_task,
)
from deerflow.tools.types import Runtime
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
@@ -50,12 +49,11 @@ def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -
@tool("task", parse_docstring=True)
async def task_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
description: str,
prompt: str,
subagent_type: str,
tool_call_id: Annotated[str, InjectedToolCallId],
max_turns: int | None = None,
) -> str:
"""Delegate a task to a specialized subagent that runs in its own context.
@@ -91,7 +89,6 @@ async def task_tool(
description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
"""
runtime_app_config = _get_runtime_app_config(runtime)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
@@ -113,9 +110,6 @@ async def task_tool(
# each subagent loads its own skills based on config, injected as conversation items).
# No longer appended to system_prompt here.
if max_turns is not None:
overrides["max_turns"] = max_turns
# Extract parent context from runtime
sandbox_state = None
thread_data = None
@@ -0,0 +1,241 @@
"""update_agent tool — let a custom agent persist updates to its own SOUL.md / config.
Bound to the lead agent only when ``runtime.context['agent_name']`` is set
(i.e. inside an existing custom agent's chat). The default agent does not see
this tool, and the bootstrap flow continues to use ``setup_agent`` for the
initial creation handshake.
The tool writes back to ``{base_dir}/users/{user_id}/agents/{agent_name}/{config.yaml,SOUL.md}``
so an agent created by one user is never visible to (or mutable by) another.
Writes are staged into temp files first; both files are renamed into place only
after both temp files are successfully written, so a partial failure cannot leave
config.yaml updated while SOUL.md still holds stale content.
"""
from __future__ import annotations
import logging
import tempfile
from pathlib import Path
from typing import Any
import yaml
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _stage_temp(path: Path, text: str) -> Path:
"""Write ``text`` into a sibling temp file and return its path.
The caller is responsible for ``Path.replace``-ing the temp into the target
once every staged file is ready, or for unlinking it on failure.
"""
path.parent.mkdir(parents=True, exist_ok=True)
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=path.parent,
suffix=".tmp",
delete=False,
encoding="utf-8",
)
try:
fd.write(text)
fd.flush()
fd.close()
return Path(fd.name)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def _cleanup_temps(temps: list[Path]) -> None:
"""Best-effort removal of staged temp files."""
for tmp in temps:
try:
tmp.unlink(missing_ok=True)
except OSError:
logger.debug("Failed to clean up temp file %s", tmp, exc_info=True)
@tool
def update_agent(
runtime: Runtime,
soul: str | None = None,
description: str | None = None,
skills: list[str] | None = None,
tool_groups: list[str] | None = None,
model: str | None = None,
) -> Command:
"""Persist updates to the current custom agent's SOUL.md and config.yaml.
Use this when the user asks to refine the agent's identity, description,
skill whitelist, tool-group whitelist, or default model. Only the fields
you explicitly pass are updated; omitted fields keep their existing values.
Pass ``soul`` as the FULL replacement SOUL.md content there is no patch
semantics, so always start from the current SOUL and apply your edits.
Pass ``skills=[]`` to disable all skills for this agent. Omit ``skills``
entirely to keep the existing whitelist.
Args:
soul: Optional full replacement SOUL.md content.
description: Optional new one-line description.
skills: Optional skill whitelist. ``[]`` = no skills, omit = unchanged.
tool_groups: Optional tool-group whitelist. ``[]`` = empty, omit = unchanged.
model: Optional model override (must match a configured model name).
Returns:
Command with a ToolMessage describing the result. Changes take effect
on the next user turn (when the lead agent is rebuilt with the fresh
SOUL.md and config.yaml).
"""
tool_call_id = runtime.tool_call_id
agent_name_raw: str | None = runtime.context.get("agent_name") if runtime.context else None
def _err(message: str) -> Command:
return Command(update={"messages": [ToolMessage(content=f"Error: {message}", tool_call_id=tool_call_id)]})
if soul is None and description is None and skills is None and tool_groups is None and model is None:
return _err("No fields provided. Pass at least one of: soul, description, skills, tool_groups, model.")
try:
agent_name = validate_agent_name(agent_name_raw)
except ValueError as e:
return _err(str(e))
if not agent_name:
return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.")
# Resolve the active user so that updates only affect this user's agent.
# ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context
# is set (matching how memory and thread storage behave).
user_id = get_effective_user_id()
# Reject an unknown ``model`` *before* touching the filesystem. Otherwise
# ``_resolve_model_name`` silently falls back to the default at runtime
# and the user sees confusing repeated warnings on every later turn.
if model is not None and get_app_config().get_model_config(model) is None:
return _err(f"Unknown model '{model}'. Pass a model name that exists in config.yaml's models section.")
paths = get_paths()
agent_dir = paths.user_agent_dir(user_id, agent_name)
if not agent_dir.exists() and paths.agent_dir(agent_name).exists():
return _err(f"Agent '{agent_name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before updating.")
try:
existing_cfg = load_agent_config(agent_name, user_id=user_id)
except FileNotFoundError:
return _err(f"Agent '{agent_name}' does not exist for the current user. Use setup_agent to create a new agent first.")
except ValueError as e:
return _err(f"Agent '{agent_name}' has an unreadable config: {e}")
if existing_cfg is None:
return _err(f"Agent '{agent_name}' could not be loaded.")
updated_fields: list[str] = []
# Force the on-disk ``name`` to match the directory we are writing into,
# even if ``existing_cfg.name`` had drifted (e.g. from manual yaml edits).
config_data: dict[str, Any] = {"name": agent_name}
new_description = description if description is not None else existing_cfg.description
config_data["description"] = new_description
if description is not None and description != existing_cfg.description:
updated_fields.append("description")
new_model = model if model is not None else existing_cfg.model
if new_model is not None:
config_data["model"] = new_model
if model is not None and model != existing_cfg.model:
updated_fields.append("model")
new_tool_groups = tool_groups if tool_groups is not None else existing_cfg.tool_groups
if new_tool_groups is not None:
config_data["tool_groups"] = new_tool_groups
if tool_groups is not None and tool_groups != existing_cfg.tool_groups:
updated_fields.append("tool_groups")
new_skills = skills if skills is not None else existing_cfg.skills
if new_skills is not None:
config_data["skills"] = new_skills
if skills is not None and skills != existing_cfg.skills:
updated_fields.append("skills")
config_changed = bool({"description", "model", "tool_groups", "skills"} & set(updated_fields))
# Stage every file we intend to rewrite into a temp sibling. Only after
# *all* temp files exist do we rename them into place — so a failure on
# SOUL.md cannot leave config.yaml already replaced.
pending: list[tuple[Path, Path]] = []
staged_temps: list[Path] = []
try:
agent_dir.mkdir(parents=True, exist_ok=True)
if config_changed:
yaml_text = yaml.dump(config_data, default_flow_style=False, allow_unicode=True, sort_keys=False)
config_target = agent_dir / "config.yaml"
config_tmp = _stage_temp(config_target, yaml_text)
staged_temps.append(config_tmp)
pending.append((config_tmp, config_target))
if soul is not None:
soul_target = agent_dir / "SOUL.md"
soul_tmp = _stage_temp(soul_target, soul)
staged_temps.append(soul_tmp)
pending.append((soul_tmp, soul_target))
updated_fields.append("soul")
# Commit phase. ``Path.replace`` is atomic per file on POSIX/NTFS and
# the staging step above means any earlier failure has already been
# reported. The remaining failure mode is a crash *between* two
# ``replace`` calls, which is reported via the partial-write error
# branch below so the caller knows which files are now on disk.
committed: list[Path] = []
try:
for tmp, target in pending:
tmp.replace(target)
committed.append(target)
except Exception as e:
_cleanup_temps([t for t, _ in pending if t not in committed])
if committed:
logger.error(
"[update_agent] Partial write for agent '%s' (user=%s): committed=%s, failed during rename: %s",
agent_name,
user_id,
[p.name for p in committed],
e,
exc_info=True,
)
return _err(f"Partial update for agent '{agent_name}': {[p.name for p in committed]} were updated, but the rest failed ({e}). Re-run update_agent to retry the remaining fields.")
raise
except Exception as e:
_cleanup_temps(staged_temps)
logger.error("[update_agent] Failed to update agent '%s' (user=%s): %s", agent_name, user_id, e, exc_info=True)
return _err(f"Failed to update agent '{agent_name}': {e}")
if not updated_fields:
return Command(update={"messages": [ToolMessage(content=f"No changes applied to agent '{agent_name}'. The provided values matched the existing config.", tool_call_id=tool_call_id)]})
logger.info("[update_agent] Updated agent '%s' (user=%s) fields: %s", agent_name, user_id, updated_fields)
return Command(
update={
"messages": [
ToolMessage(
content=(f"Agent '{agent_name}' updated successfully. Changed: {', '.join(updated_fields)}. The new configuration takes effect on the next user turn."),
tool_call_id=tool_call_id,
)
]
}
)
@@ -3,13 +3,13 @@ import mimetypes
from pathlib import Path
from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain.tools import InjectedToolCallId, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.tools.types import Runtime
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
f"{VIRTUAL_PATH_PREFIX}/workspace",
@@ -48,7 +48,7 @@ def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None)
@tool("view_image", parse_docstring=True)
def view_image_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
image_path: str,
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
@@ -7,16 +7,15 @@ import logging
from typing import Any
from weakref import WeakValueDictionary
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from langchain.tools import tool
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.agents.thread_state import ThreadState
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.skills.types import SKILL_MD_FILE
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
@@ -31,7 +30,7 @@ def _get_lock(name: str) -> asyncio.Lock:
return lock
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
def _get_thread_id(runtime: Runtime | None) -> str | None:
if runtime is None:
return None
if runtime.context and runtime.context.get("thread_id"):
@@ -65,7 +64,7 @@ async def _to_thread(func, /, *args, **kwargs):
async def _skill_manage_impl(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
action: str,
name: str,
content: str | None = None,
@@ -204,7 +203,7 @@ async def _skill_manage_impl(
@tool("skill_manage", parse_docstring=True)
async def skill_manage_tool(
runtime: ToolRuntime[ContextT, ThreadState],
runtime: Runtime,
action: str,
name: str,
content: str | None = None,
@@ -0,0 +1,11 @@
from typing import Any
from langchain.tools import ToolRuntime
from deerflow.agents.thread_state import ThreadState
# Concrete runtime type used by all DeerFlow tools.
# Using dict[str, Any] for the context parameter instead of the unbound ContextT
# TypeVar prevents PydanticSerializationUnexpectedValue warnings when LangChain
# calls model_dump() on a tool's auto-generated args_schema.
Runtime = ToolRuntime[dict[str, Any], ThreadState]
@@ -121,9 +121,11 @@ def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, ob
Upload directories may be mounted into local sandboxes. A sandbox process can
therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
follows that link and can overwrite files outside the uploads directory with
gateway privileges. This helper rejects symlink destinations and uses
``O_NOFOLLOW`` where available so the final path component cannot be raced into
a symlink between validation and open.
gateway privileges. This helper rejects symlink destinations using ``O_NOFOLLOW``
on POSIX. On Windows (which lacks ``O_NOFOLLOW``), it uses dual ``lstat`` checks
and ``fstat`` validation after ``open()`` to reduce the TOCTOU window; this does
not eliminate all races but makes exploitation significantly harder. Path-traversal
validation prevents escapes from *base_dir* in both cases.
"""
safe_name = normalize_filename(filename)
dest = base_dir / safe_name
@@ -138,23 +140,65 @@ def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, ob
validate_path_traversal(dest, base_dir)
if not hasattr(os, "O_NOFOLLOW"):
raise UnsafeUploadPathError("Upload writes require O_NOFOLLOW support")
has_nofollow = hasattr(os, "O_NOFOLLOW")
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
if hasattr(os, "O_NONBLOCK"):
flags |= os.O_NONBLOCK
if has_nofollow:
# POSIX: O_NOFOLLOW makes open() fail with ELOOP if dest is a symlink.
flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
if hasattr(os, "O_NONBLOCK"):
flags |= os.O_NONBLOCK
try:
fd = os.open(dest, flags, 0o600)
except OSError as exc:
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
raise
try:
opened_stat = os.fstat(fd)
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
os.ftruncate(fd, 0)
fh = os.fdopen(fd, "wb")
fd = -1
finally:
if fd >= 0:
os.close(fd)
return dest, fh
# Windows: no O_NOFOLLOW available. Uses a second lstat immediately before open()
# to narrow the TOCTOU window, then fstat after open() as a further defence.
# Note: a narrow race window remains between the pre-open lstat and open(); the
# path-traversal check mitigates escapes from base_dir but cannot prevent an
# attacker who can atomically replace dest with a symlink after the check.
if st is not None and st.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
flags = os.O_WRONLY | os.O_CREAT
if hasattr(os, "O_BINARY"):
flags |= os.O_BINARY
try:
pre_open_st = os.lstat(dest)
except FileNotFoundError:
pre_open_st = None
if pre_open_st is not None and not stat.S_ISREG(pre_open_st.st_mode):
raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")
if pre_open_st is not None and pre_open_st.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination has multiple links: {safe_name}")
try:
fd = os.open(dest, flags, 0o600)
except OSError as exc:
if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
if exc.errno in {errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
raise
try:
opened_stat = os.fstat(fd)
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink > 1:
raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
os.ftruncate(fd, 0)
fh = os.fdopen(fd, "wb")
+1 -1
View File
@@ -8,7 +8,7 @@ dependencies = [
"deerflow-harness",
"fastapi>=0.115.0",
"httpx>=0.28.0",
"python-multipart>=0.0.26",
"python-multipart>=0.0.27",
"sse-starlette>=2.1.0",
"uvicorn[standard]>=0.34.0",
"lark-oapi>=1.4.0",
+86 -3
View File
@@ -1,7 +1,7 @@
"""One-time migration: move legacy thread dirs and memory into per-user layout.
Usage:
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run] [--user-id USER_ID]
The script is idempotent re-running it after a successful migration is a no-op.
"""
@@ -69,6 +69,67 @@ def migrate_thread_dirs(
return report
def migrate_agents(
paths: Paths,
user_id: str = "default",
*,
dry_run: bool = False,
) -> list[dict]:
"""Move legacy custom-agent directories into per-user layout.
Legacy layout: ``{base_dir}/agents/{name}/``
Per-user layout: ``{base_dir}/users/{user_id}/agents/{name}/``
Pre-existing per-user agents take precedence: if a destination already
exists for an agent name, the legacy copy is moved to
``{base_dir}/migration-conflicts/agents/{name}/`` for manual review.
Args:
paths: Paths instance.
user_id: Target user to receive the legacy agents (defaults to
``"default"``, matching ``DEFAULT_USER_ID`` for no-auth setups).
dry_run: If True, only log what would happen.
Returns:
List of migration report entries, one per legacy agent directory found.
"""
report: list[dict] = []
legacy_agents = paths.agents_dir
if not legacy_agents.exists():
logger.info("No legacy agents directory found — nothing to migrate.")
return report
for agent_dir in sorted(legacy_agents.iterdir()):
if not agent_dir.is_dir():
continue
agent_name = agent_dir.name
dest = paths.user_agent_dir(user_id, agent_name)
entry = {"agent": agent_name, "user_id": user_id, "action": ""}
if dest.exists():
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / agent_name
entry["action"] = f"conflict -> {conflicts_dir}"
if not dry_run:
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(agent_dir), str(conflicts_dir))
logger.warning("Conflict for agent %s: moved legacy copy to %s", agent_name, conflicts_dir)
else:
entry["action"] = f"moved -> {dest}"
if not dry_run:
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(agent_dir), str(dest))
logger.info("Migrated agent %s -> user %s", agent_name, user_id)
report.append(entry)
# Clean up empty legacy agents dir
if not dry_run and legacy_agents.exists() and not any(legacy_agents.iterdir()):
legacy_agents.rmdir()
return report
def migrate_memory(
paths: Paths,
user_id: str = "default",
@@ -127,6 +188,12 @@ def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
def main() -> None:
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
parser.add_argument(
"--user-id",
default="default",
metavar="USER_ID",
help=("User ID to claim un-owned legacy data (global memory.json and legacy custom agents). Defaults to 'default'. In multi-user installs, set this to the operator account that should inherit those legacy artifacts."),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -134,26 +201,42 @@ def main() -> None:
paths = get_paths()
logger.info("Base directory: %s", paths.base_dir)
logger.info("Dry run: %s", args.dry_run)
logger.info("Claiming un-owned legacy data for user_id=%s", args.user_id)
owner_map = _build_owner_map_from_db(paths)
logger.info("Found %d thread ownership records in DB", len(owner_map))
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
migrate_memory(paths, user_id=args.user_id, dry_run=args.dry_run)
agent_report = migrate_agents(paths, user_id=args.user_id, dry_run=args.dry_run)
if report:
logger.info("Migration report:")
logger.info("Thread migration report:")
for entry in report:
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
else:
logger.info("No threads to migrate.")
if agent_report:
logger.info("Agent migration report:")
for entry in agent_report:
logger.info(" agent=%s user=%s action=%s", entry["agent"], entry["user_id"], entry["action"])
else:
logger.info("No agents to migrate.")
unowned = [e for e in report if e["user_id"] == "default"]
if unowned:
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
for e in unowned:
logger.warning(" %s", e["thread_id"])
if agent_report:
logger.warning(
"%d legacy agent(s) were assigned to '%s'. If those agents belonged to other users, move them manually under {base_dir}/users/<user_id>/agents/.",
len(agent_report),
args.user_id,
)
if __name__ == "__main__":
main()
@@ -0,0 +1,210 @@
"""Tests for AioSandboxProvider auto-restart of crashed containers."""
import importlib
import threading
from unittest.mock import MagicMock, patch
def _import_provider():
return importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
def _make_provider(*, auto_restart=True, alive=True):
"""Build a minimal AioSandboxProvider with a mock backend.
Args:
auto_restart: Value for the auto_restart config key.
alive: Whether the mock backend reports containers as alive.
"""
mod = _import_provider()
with patch.object(mod.AioSandboxProvider, "_start_idle_checker"):
provider = mod.AioSandboxProvider.__new__(mod.AioSandboxProvider)
provider._config = {"auto_restart": auto_restart}
provider._lock = threading.Lock()
provider._sandboxes = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
provider._warm_pool = {}
provider._shutdown_called = False
provider._idle_checker_stop = threading.Event()
backend = MagicMock()
backend.is_alive.return_value = alive
provider._backend = backend
return provider, backend
def _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1"):
"""Insert a sandbox into the provider's caches as if it were acquired."""
sandbox = MagicMock()
info = MagicMock()
provider._sandboxes[sandbox_id] = sandbox
provider._sandbox_infos[sandbox_id] = info
provider._last_activity[sandbox_id] = 0.0
if thread_id:
provider._thread_sandboxes[thread_id] = sandbox_id
return sandbox, info
# ── get() returns sandbox when container is alive ──────────────────────────
def test_get_returns_sandbox_when_container_alive():
"""When auto_restart is on and the container is alive, get() returns the sandbox."""
provider, backend = _make_provider(auto_restart=True, alive=True)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_called_once()
def test_get_returns_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() skips the health check entirely."""
provider, backend = _make_provider(auto_restart=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_not_called()
# ── get() evicts dead sandbox when auto_restart is on ──────────────────────
def test_get_evicts_dead_sandbox_when_auto_restart_enabled():
"""When the container is dead and auto_restart is on, get() returns None and cleans caches."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_, info = _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1")
result = provider.get("dead-beef")
assert result is None
assert "dead-beef" not in provider._sandboxes
assert "dead-beef" not in provider._sandbox_infos
assert "dead-beef" not in provider._last_activity
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once_with(info)
def test_get_returns_dead_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() returns the cached sandbox even if the container is dead."""
provider, backend = _make_provider(auto_restart=False, alive=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
# Caches are untouched
assert "dead-beef" in provider._sandboxes
def test_get_eviction_cleans_multiple_thread_mappings():
"""A sandbox mapped to multiple thread IDs has all mappings cleaned on eviction."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="t-a")
# Manually add a second thread mapping to the same sandbox
provider._thread_sandboxes["t-b"] = "sid-1"
result = provider.get("sid-1")
assert result is None
assert "t-a" not in provider._thread_sandboxes
assert "t-b" not in provider._thread_sandboxes
# ── get() does not check health for unknown sandbox IDs ────────────────────
def test_get_returns_none_for_unknown_id():
"""If the sandbox_id is not in cache, get() returns None without checking health."""
provider, backend = _make_provider(auto_restart=True, alive=True)
result = provider.get("nonexistent")
assert result is None
backend.is_alive.assert_not_called()
# ── get() handles missing sandbox_info gracefully ──────────────────────────
def test_get_handles_missing_info_gracefully():
"""If sandbox is cached but info is missing, get() skips the health check."""
provider, backend = _make_provider(auto_restart=True, alive=False)
sandbox = MagicMock()
provider._sandboxes["sid-x"] = sandbox
provider._sandbox_infos.pop("sid-x", None) # Ensure no info
provider._last_activity["sid-x"] = 0.0
result = provider.get("sid-x")
# No info → cannot call is_alive → sandbox returned as-is
assert result is sandbox
backend.is_alive.assert_not_called()
def test_get_liveness_check_runs_outside_provider_lock():
"""get() should not hold the provider lock while checking backend liveness."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-locked", thread_id="thread-1")
def _assert_lock_not_held(_):
assert not provider._lock.locked()
return False
backend.is_alive.side_effect = _assert_lock_not_held
assert provider.get("sid-locked") is None
def test_get_still_evicts_when_backend_destroy_fails():
"""Cleanup errors should not keep stale sandbox state in memory."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-fail", thread_id="thread-1")
backend.destroy.side_effect = RuntimeError("boom")
assert provider.get("sid-fail") is None
assert "sid-fail" not in provider._sandboxes
assert "sid-fail" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once()
# ── Integration: eviction clears caches for recreation ─────────────────────
def test_eviction_clears_all_caches_for_recreation():
"""After eviction, all caches are clean so _acquire_internal can recreate.
This verifies the preconditions for transparent restart: when get() evicts
a dead sandbox, the next _acquire_internal call will find no cached entry,
no warm-pool entry, and fall through to _create_sandbox.
"""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="thread-1")
# Before eviction: caches populated
assert "sid-1" in provider._sandboxes
assert "sid-1" in provider._sandbox_infos
assert "thread-1" in provider._thread_sandboxes
# get() detects the dead container and evicts
assert provider.get("sid-1") is None
# After eviction: all caches clean
assert "sid-1" not in provider._sandboxes
assert "sid-1" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
assert "sid-1" not in provider._warm_pool
# _acquire_internal for the same thread would find nothing cached
# and generate the deterministic ID, then discover fails (container
# is gone), falling through to _create_sandbox — a fresh start.
+213 -1
View File
@@ -4,10 +4,40 @@ import json
import os
from pathlib import Path
import pytest
import yaml
from pydantic import ValidationError
from deerflow.config.agents_api_config import get_agents_api_config
import deerflow.config.app_config as app_config_module
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import get_agents_api_config, load_agents_api_config_from_dict
from deerflow.config.app_config import AppConfig, get_app_config, reset_app_config
from deerflow.config.checkpointer_config import get_checkpointer_config, load_checkpointer_config_from_dict
from deerflow.config.guardrails_config import get_guardrails_config, load_guardrails_config_from_dict
from deerflow.config.memory_config import get_memory_config, load_memory_config_from_dict
from deerflow.config.stream_bridge_config import get_stream_bridge_config, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import get_subagents_app_config, load_subagents_config_from_dict
from deerflow.config.summarization_config import get_summarization_config, load_summarization_config_from_dict
from deerflow.config.title_config import get_title_config, load_title_config_from_dict
from deerflow.config.tool_search_config import get_tool_search_config, load_tool_search_config_from_dict
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.store import get_store, reset_store
def _reset_config_singletons() -> None:
load_title_config_from_dict({})
load_summarization_config_from_dict({})
load_memory_config_from_dict({})
load_agents_api_config_from_dict({})
load_subagents_config_from_dict({})
load_tool_search_config_from_dict({})
load_guardrails_config_from_dict({})
load_checkpointer_config_from_dict(None)
load_stream_bridge_config_from_dict(None)
load_acp_config_from_dict({})
reset_checkpointer()
reset_store()
reset_app_config()
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
@@ -53,6 +83,23 @@ def _write_config_with_agents_api(
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_config_with_sections(path: Path, sections: dict | None = None) -> None:
config = {
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [
{
"name": "first-model",
"use": "langchain_openai:ChatOpenAI",
"model": "gpt-test",
}
],
}
if sections:
config.update(sections)
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@@ -175,3 +222,168 @@ def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path,
assert get_agents_api_config().enabled is False
finally:
reset_app_config()
def test_get_app_config_resets_singleton_configs_when_sections_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False, "max_words": 3},
"summarization": {"enabled": True},
"memory": {"enabled": False, "max_facts": 50},
"subagents": {"timeout_seconds": 42, "agents": {"reviewer": {"max_turns": 2}}},
"tool_search": {"enabled": True},
"guardrails": {"enabled": True, "fail_closed": False},
"checkpointer": {"type": "memory"},
"stream_bridge": {"type": "memory", "queue_maxsize": 12},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
get_app_config()
assert get_title_config().enabled is False
assert get_summarization_config().enabled is True
assert get_memory_config().enabled is False
assert get_subagents_app_config().timeout_seconds == 42
assert get_tool_search_config().enabled is True
assert get_guardrails_config().enabled is True
assert get_checkpointer_config() is not None
assert get_stream_bridge_config() is not None
_write_config_with_sections(config_path)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_title_config().enabled is True
assert get_summarization_config().enabled is False
assert get_memory_config().enabled is True
assert get_subagents_app_config().timeout_seconds == 900
assert get_tool_search_config().enabled is False
assert get_guardrails_config().enabled is False
assert get_checkpointer_config() is None
assert get_stream_bridge_config() is None
finally:
_reset_config_singletons()
def test_get_app_config_resets_persistence_runtime_singletons_when_checkpointer_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(config_path, {"checkpointer": {"type": "memory"}})
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_checkpointer()
reset_store()
reset_app_config()
try:
get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(config_path)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_checkpointer_config() is None
assert get_checkpointer() is not initial_checkpointer
assert get_store() is not initial_store
finally:
_reset_config_singletons()
def test_get_app_config_keeps_persistence_runtime_singletons_when_checkpointer_unchanged(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False},
"checkpointer": {"type": "memory"},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
_reset_config_singletons()
try:
get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(
config_path,
{
"title": {"enabled": True},
"checkpointer": {"type": "memory"},
},
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
get_app_config()
assert get_checkpointer() is initial_checkpointer
assert get_store() is initial_store
finally:
_reset_config_singletons()
def test_get_app_config_does_not_mutate_singletons_when_reload_validation_fails(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_sections(
config_path,
{
"title": {"enabled": False},
"tool_search": {"enabled": True},
"checkpointer": {"type": "memory"},
},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
_reset_config_singletons()
try:
previous_app_config = get_app_config()
initial_checkpointer = get_checkpointer()
initial_store = get_store()
_write_config_with_sections(
config_path,
{
"title": False,
"tool_search": False,
"checkpointer": {"type": "memory"},
},
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
with pytest.raises(ValidationError):
get_app_config()
assert app_config_module._app_config is previous_app_config
assert get_title_config().enabled is False
assert get_tool_search_config().enabled is True
assert get_checkpointer_config() is not None
assert get_checkpointer() is initial_checkpointer
assert get_store() is initial_store
finally:
_reset_config_singletons()
+199
View File
@@ -372,6 +372,37 @@ class TestExtractResponseText:
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
assert _extract_response_text(result) == ""
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
"""Loop-detection warning text on a tool-calling AI message is middleware-authored."""
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "search the repo"},
{
"type": "ai",
"content": "[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == ""
def test_preserves_visible_text_when_stripping_loop_warning(self):
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "prepare the report"},
{
"type": "ai",
"content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == "Here is the report."
# ---------------------------------------------------------------------------
# ChannelManager tests
@@ -435,6 +466,47 @@ class TestChannelManager:
assert headers["Cookie"] == f"csrf_token={csrf_token}"
assert headers["X-DeerFlow-Internal-Token"]
def test_fetch_gateway_includes_internal_auth_headers(self, monkeypatch):
from app.channels.manager import ChannelManager
class MockResponse:
def raise_for_status(self):
return None
def json(self):
return {"models": [{"name": "default"}]}
class MockAsyncClient:
def __init__(self, *args, **kwargs):
return None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, **kwargs):
calls.append({"url": url, **kwargs})
return MockResponse()
calls = []
monkeypatch.setattr("app.channels.manager.httpx.AsyncClient", MockAsyncClient)
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store, gateway_url="http://gateway:8001")
reply = await manager._fetch_gateway("/api/models", "models")
assert reply == "Available models:\n• default"
assert calls[0]["url"] == "http://gateway:8001/api/models"
assert calls[0]["timeout"] == 10
assert calls[0]["headers"]["X-DeerFlow-Internal-Token"]
_run(go())
def test_handle_chat_calls_channel_receive_file_for_inbound_files(self, monkeypatch):
from app.channels.manager import ChannelManager
@@ -530,6 +602,8 @@ class TestChannelManager:
assert call_args[0][0] == "test-thread-123" # thread_id
assert call_args[0][1] == "lead_agent" # assistant_id
assert call_args[1]["input"]["messages"][0]["content"] == "hi"
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert len(outbound_received) == 1
assert outbound_received[0].text == "Hello from agent!"
@@ -661,12 +735,135 @@ class TestChannelManager:
call_args = mock_client.runs.wait.call_args
assert call_args[0][1] == "lead_agent"
assert call_args[1]["config"]["recursion_limit"] == 55
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert call_args[1]["context"]["thinking_enabled"] is False
assert call_args[1]["context"]["subagent_enabled"] is True
assert call_args[1]["context"]["agent_name"] == "mobile-agent"
_run(go())
def test_clarification_follow_up_preserves_history(self):
"""Conversation should continue after ask_clarification instead of resetting history."""
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received = []
async def capture_outbound(msg):
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
del assistant_id, context # unused in this test, kept for signature parity
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
key = (thread_id, str(checkpoint_ns))
history = history_by_checkpoint.setdefault(key, [])
human_text = input["messages"][0]["content"]
history.append(human_text)
if len(history) == 1:
return {
"messages": [
{"type": "human", "content": history[0]},
{
"type": "ai",
"content": "",
"tool_calls": [
{
"name": "ask_clarification",
"args": {"question": "Which environment should I use?"},
}
],
},
{
"type": "tool",
"name": "ask_clarification",
"content": "Which environment should I use?",
},
]
}
if len(history) == 2 and history[0] == "Deploy my app" and history[1] == "prod":
return {
"messages": [
{"type": "human", "content": history[0]},
{
"type": "ai",
"content": "",
"tool_calls": [
{
"name": "ask_clarification",
"args": {"question": "Which environment should I use?"},
}
],
},
{
"type": "tool",
"name": "ask_clarification",
"content": "Which environment should I use?",
},
{"type": "human", "content": history[1]},
{"type": "ai", "content": "Got it. I will deploy to prod."},
]
}
return {
"messages": [
{"type": "human", "content": history[-1]},
{"type": "ai", "content": "History missing; clarification repeated."},
]
}
mock_client = MagicMock()
mock_client.threads.create = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
mock_client.threads.get = AsyncMock(return_value={"thread_id": "clarify-thread-1"})
mock_client.runs.wait = AsyncMock(side_effect=_runs_wait)
manager._client = mock_client
await manager.start()
await bus.publish_inbound(
InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="Deploy my app",
)
)
await _wait_for(lambda: len(outbound_received) >= 1)
await bus.publish_inbound(
InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="prod",
)
)
await _wait_for(lambda: len(outbound_received) >= 2)
await manager.stop()
assert outbound_received[0].text == "Which environment should I use?"
assert outbound_received[1].text == "Got it. I will deploy to prod."
assert mock_client.runs.wait.call_count == 2
first_call = mock_client.runs.wait.call_args_list[0]
second_call = mock_client.runs.wait.call_args_list[1]
assert first_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
assert second_call.kwargs["config"]["configurable"]["checkpoint_ns"] == ""
_run(go())
def test_handle_chat_uses_user_session_overrides(self):
from app.channels.manager import ChannelManager
@@ -1343,6 +1540,8 @@ class TestChannelManager:
call_args = mock_client.runs.stream.call_args
assert call_args[1]["input"]["messages"][0]["content"] == "hello"
assert call_args[1]["config"]["configurable"]["checkpoint_ns"] == ""
assert call_args[1]["config"]["configurable"]["thread_id"] == "test-thread-123"
assert call_args[1]["context"]["is_bootstrap"] is True
# Final message should be published
+41 -1
View File
@@ -1,6 +1,8 @@
"""Unit tests for checkpointer config and singleton factory."""
"""Unit tests for checkpointer config, packaging metadata, and factories."""
import sys
import tomllib
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -13,6 +15,8 @@ from deerflow.config.checkpointer_config import (
set_checkpointer_config,
)
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
@pytest.fixture(autouse=True)
@@ -67,6 +71,42 @@ class TestCheckpointerConfig:
with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"})
def test_connection_string_description_matches_runtime_defaults(self):
description = CheckpointerConfig.model_fields["connection_string"].description
assert description is not None
assert "Optional for sqlite" in description
assert "defaults to 'store.db'" in description
assert "Required for postgres" in description
class TestHarnessPackaging:
def test_pyproject_declares_postgres_extra(self):
pyproject_path = Path(__file__).resolve().parents[1] / "packages" / "harness" / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert "postgres" in optional_dependencies
assert optional_dependencies["postgres"] == [
"asyncpg>=0.29",
"langgraph-checkpoint-postgres>=3.0.5",
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
assert "deerflow-harness[postgres]" in POSTGRES_STORE_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_INSTALL
assert "uv sync --all-packages --extra postgres" in POSTGRES_STORE_INSTALL
# ---------------------------------------------------------------------------
# Factory tests
@@ -192,6 +192,7 @@ def test_agent_features_defaults():
assert f.vision is False
assert f.auto_title is False
assert f.guardrail is False
assert f.loop_detection is True
# ---------------------------------------------------------------------------
@@ -630,6 +631,51 @@ def test_loop_detection_before_clarification(mock_create_agent):
assert loop_idx == clar_idx - 1
# ---------------------------------------------------------------------------
# 30b. loop_detection=False skips LoopDetectionMiddleware
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_disabled(mock_create_agent):
mock_create_agent.return_value = MagicMock()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=False),
)
call_kwargs = mock_create_agent.call_args[1]
mw_types = [type(m).__name__ for m in call_kwargs["middleware"]]
assert "LoopDetectionMiddleware" not in mw_types
# ---------------------------------------------------------------------------
# 30c. loop_detection=<custom AgentMiddleware> replaces the default
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_loop_detection_custom_middleware(mock_create_agent):
from langchain.agents.middleware import AgentMiddleware as AM
mock_create_agent.return_value = MagicMock()
class MyLoopDetection(AM):
pass
custom = MyLoopDetection()
create_deerflow_agent(
_make_mock_model(),
features=RuntimeFeatures(sandbox=False, loop_detection=custom),
)
call_kwargs = mock_create_agent.call_args[1]
middleware = call_kwargs["middleware"]
assert custom in middleware
mw_types = [type(m).__name__ for m in middleware]
# Default LoopDetectionMiddleware must not also appear.
assert "LoopDetectionMiddleware" not in mw_types
# Custom replacement still sits immediately before ClarificationMiddleware.
assert mw_types[-1] == "ClarificationMiddleware"
assert mw_types[-2] == "MyLoopDetection"
# ---------------------------------------------------------------------------
# 31. plan_mode=True adds TodoMiddleware
# ---------------------------------------------------------------------------
+2
View File
@@ -85,6 +85,8 @@ def test_load_claude_code_credential_from_override_path(tmp_path, monkeypatch):
def test_load_claude_code_credential_ignores_directory_path(tmp_path, monkeypatch):
_clear_claude_code_env(monkeypatch)
# Redirect HOME so the default ~/.claude/.credentials.json doesn't exist
monkeypatch.setenv("HOME", str(tmp_path))
cred_dir = tmp_path / "claude-creds-dir"
cred_dir.mkdir()
monkeypatch.setenv("CLAUDE_CODE_CREDENTIALS_PATH", str(cred_dir))
+235
View File
@@ -0,0 +1,235 @@
"""Tests for CSRF middleware."""
from fastapi import FastAPI
from starlette.testclient import TestClient
from app.gateway.csrf_middleware import CSRFMiddleware
def _make_app() -> FastAPI:
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/auth/login/local")
async def login_local():
return {"ok": True}
@app.post("/api/v1/auth/register")
async def register():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def protected_mutation():
return {"ok": True}
return app
def test_auth_post_rejects_cross_origin_browser_request():
"""CSRF-exempt auth routes must not accept hostile browser origins.
Login/register endpoints intentionally skip the double-submit token because
first-time callers do not have a token yet. They still set an auth session,
so a hostile cross-site form POST must be rejected to avoid login CSRF /
session fixation.
"""
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_allows_same_origin_browser_request():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_rejects_malformed_origin_with_path():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example/path"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_rejects_malformed_origin_with_invalid_port():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:bad"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_allows_same_origin_default_port_equivalence():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:443"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "deerflow.example, internal:8000",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin_with_non_default_port():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "http://localhost:2026",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "localhost:2026",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_rfc_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"Forwarded": "proto=https;host=deerflow.example",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
assert "secure" in response.headers["set-cookie"].lower()
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/register",
headers={"Origin": "https://app.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_sets_strict_samesite_csrf_cookie():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
set_cookie = response.headers["set-cookie"].lower()
assert "csrf_token=" in set_cookie
assert "samesite=strict" in set_cookie
assert "secure" in set_cookie
def test_auth_post_without_origin_still_allows_non_browser_clients():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post("/api/v1/auth/login/local")
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_non_auth_mutation_still_requires_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/threads/abc/runs/stream",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
def test_non_auth_mutation_allows_valid_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "known-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "known-token",
},
)
assert response.status_code == 200
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "cookie-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "header-token",
},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch."
+16 -2
View File
@@ -537,7 +537,10 @@ class TestAgentsAPI:
def test_create_persists_files_on_disk(self, agent_client, tmp_path):
agent_client.post("/api/agents", json={"name": "disk-check", "soul": "disk soul"})
agent_dir = tmp_path / "agents" / "disk-check"
# tests/conftest.py installs an autouse fixture that sets the
# contextvar to "test-user-autouse", so the agent is persisted under
# users/test-user-autouse/agents/ rather than the legacy shared dir.
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "disk-check"
assert agent_dir.exists()
assert (agent_dir / "config.yaml").exists()
assert (agent_dir / "SOUL.md").exists()
@@ -545,12 +548,23 @@ class TestAgentsAPI:
def test_delete_removes_files_from_disk(self, agent_client, tmp_path):
agent_client.post("/api/agents", json={"name": "remove-me", "soul": "bye"})
agent_dir = tmp_path / "agents" / "remove-me"
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "remove-me"
assert agent_dir.exists()
agent_client.delete("/api/agents/remove-me")
assert not agent_dir.exists()
def test_create_rejects_legacy_name_collision(self, agent_client, tmp_path):
"""An unmigrated legacy agent must still block name collision so that
running the migration script later won't shadow the legacy entry."""
legacy_dir = tmp_path / "agents" / "legacy-agent"
legacy_dir.mkdir(parents=True)
(legacy_dir / "config.yaml").write_text("name: legacy-agent\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
response = agent_client.post("/api/agents", json={"name": "legacy-agent", "soul": "x"})
assert response.status_code == 409
# ===========================================================================
# 9. Gateway API User Profile endpoints
+201
View File
@@ -0,0 +1,201 @@
"""Unit tests for scripts/detect_uv_extras.py.
The detector resolves uv extras for `make dev` so that postgres (and any
future opt-in extras) are not wiped on every restart see Issue #2754.
"""
from __future__ import annotations
import importlib.util
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
DETECT_SCRIPT_PATH = REPO_ROOT / "scripts" / "detect_uv_extras.py"
spec = importlib.util.spec_from_file_location("deerflow_detect_uv_extras", DETECT_SCRIPT_PATH)
assert spec is not None and spec.loader is not None
detect = importlib.util.module_from_spec(spec)
spec.loader.exec_module(detect)
@pytest.fixture
def isolated_cwd(tmp_path, monkeypatch):
"""Isolate `find_config_file()` from the real repo by chdir + clearing env."""
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("UV_EXTRAS", raising=False)
monkeypatch.delenv("DEER_FLOW_CONFIG_PATH", raising=False)
return tmp_path
def test_parse_env_extras_supports_comma_and_whitespace():
assert detect.parse_env_extras("postgres") == ["postgres"]
assert detect.parse_env_extras("postgres,ollama") == ["postgres", "ollama"]
assert detect.parse_env_extras("postgres ollama") == ["postgres", "ollama"]
assert detect.parse_env_extras(" postgres , ollama ,") == ["postgres", "ollama"]
assert detect.parse_env_extras("") == []
def test_parse_env_extras_drops_shell_metacharacters(capsys):
"""A `.env` value containing shell injection bait must not pass through.
The whitelist guarantees the *bytes* that reach `uv sync` cannot include
shell metacharacters. Any name that looks identifier-like still survives
(uv itself will reject unknown extras with its own error), but `;`, `&`,
backticks, parentheses, slashes, etc. are stripped.
"""
# Pure-metacharacter inputs collapse to empty.
assert detect.parse_env_extras(";") == []
assert detect.parse_env_extras("$(whoami)") == []
assert detect.parse_env_extras("`echo bad`") == []
assert detect.parse_env_extras("postgres;evil") == [] # single token, contains `;`
# Splitting on whitespace yields ['rm'] which is identifier-shaped, but the
# destructive bits (`;`, `-rf`, `/`) are dropped.
assert detect.parse_env_extras("; rm -rf /") == ["rm"]
err = capsys.readouterr().err
assert "ignoring invalid UV_EXTRAS entry" in err
assert "';'" in err # confirms the dangerous token was reported and dropped
def test_parse_env_extras_rejects_leading_digits_and_punctuation():
"""Names must start with a letter — pyproject extras follow this shape."""
assert detect.parse_env_extras("1postgres") == []
assert detect.parse_env_extras("-postgres") == []
# Hyphens and underscores inside the name are fine.
assert detect.parse_env_extras("post_gres") == ["post_gres"]
assert detect.parse_env_extras("post-gres") == ["post-gres"]
def test_format_flags_emits_one_flag_per_extra():
assert detect.format_flags([]) == ""
assert detect.format_flags(["postgres"]) == "--extra postgres"
assert detect.format_flags(["postgres", "ollama"]) == "--extra postgres --extra ollama"
def test_strip_comment_preserves_quoted_hash():
assert detect._strip_comment("backend: postgres # trailing") == "backend: postgres"
assert detect._strip_comment('name: "value#with-hash"') == 'name: "value#with-hash"'
assert detect._strip_comment("# whole line comment") == ""
def test_section_value_finds_nested_key():
yaml_lines = [
"database:",
" backend: postgres",
" postgres_url: $DATABASE_URL",
"",
"checkpointer:",
" type: sqlite",
]
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
assert detect.section_value(yaml_lines, "checkpointer", "type") == "sqlite"
assert detect.section_value(yaml_lines, "database", "missing") is None
assert detect.section_value(yaml_lines, "absent_section", "anything") is None
def test_section_value_ignores_commented_lines():
yaml_lines = [
"# database:",
"# backend: postgres",
"database:",
" backend: sqlite",
]
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
def test_section_value_strips_quotes():
yaml_lines = [
"database:",
' backend: "postgres"',
]
assert detect.section_value(yaml_lines, "database", "backend") == "postgres"
def test_section_value_does_not_descend_into_grandchildren():
yaml_lines = [
"database:",
" backend: sqlite",
" nested:",
" backend: postgres",
]
# Only the immediate child level counts — keeps the parser predictable.
assert detect.section_value(yaml_lines, "database", "backend") == "sqlite"
def test_detect_from_config_postgres_via_database(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("database:\n backend: postgres\n postgres_url: $DATABASE_URL\n")
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_postgres_via_checkpointer(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("checkpointer:\n type: postgres\n connection_string: postgresql://localhost/db\n")
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_sqlite_returns_no_extras(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("database:\n backend: sqlite\n sqlite_dir: .deer-flow/data\n")
assert detect.detect_from_config(cfg) == []
def test_detect_from_config_dedupes_when_both_present(tmp_path):
cfg = tmp_path / "config.yaml"
cfg.write_text("checkpointer:\n type: postgres\ndatabase:\n backend: postgres\n")
# Sorted unique extras, no double-counting.
assert detect.detect_from_config(cfg) == ["postgres"]
def test_detect_from_config_missing_file_returns_empty(tmp_path):
assert detect.detect_from_config(tmp_path / "does-not-exist.yaml") == []
def test_resolve_extras_env_overrides_config(isolated_cwd, monkeypatch):
cfg = isolated_cwd / "config.yaml"
cfg.write_text("database:\n backend: sqlite\n")
monkeypatch.setenv("UV_EXTRAS", "postgres")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_env_supports_multiple(isolated_cwd, monkeypatch):
monkeypatch.setenv("UV_EXTRAS", "postgres,ollama")
assert detect.resolve_extras() == ["postgres", "ollama"]
def test_resolve_extras_falls_back_to_config(isolated_cwd):
(isolated_cwd / "config.yaml").write_text("database:\n backend: postgres\n")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_respects_explicit_config_path(tmp_path, monkeypatch):
monkeypatch.delenv("UV_EXTRAS", raising=False)
elsewhere = tmp_path / "elsewhere.yaml"
elsewhere.write_text("database:\n backend: postgres\n")
monkeypatch.chdir(tmp_path)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(elsewhere))
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_no_config_no_env(isolated_cwd):
assert detect.resolve_extras() == []
def test_resolve_extras_finds_backend_subdir_config(isolated_cwd):
sub = isolated_cwd / "backend"
sub.mkdir()
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
assert detect.resolve_extras() == ["postgres"]
def test_resolve_extras_root_config_takes_precedence(isolated_cwd):
(isolated_cwd / "config.yaml").write_text("database:\n backend: sqlite\n")
sub = isolated_cwd / "backend"
sub.mkdir()
(sub / "config.yaml").write_text("database:\n backend: postgres\n")
# Root config.yaml is checked first, matching the precedence in serve.sh.
assert detect.resolve_extras() == []
+102
View File
@@ -0,0 +1,102 @@
"""Unit tests for docker/dev-entrypoint.sh (UV_EXTRAS validation + parsing).
Exercises the script via its `--print-extras` dry-run hook so we don't actually
launch uvicorn or hit /app/logs. Together with test_detect_uv_extras.py these
cover both the local make-dev path and the docker-compose-dev path with the
same shape see PR #2767 / Issue #2754.
"""
from __future__ import annotations
import os
import subprocess
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
ENTRYPOINT = REPO_ROOT / "docker" / "dev-entrypoint.sh"
def _run(uv_extras: str | None) -> subprocess.CompletedProcess[str]:
"""Invoke `dev-entrypoint.sh --print-extras` with UV_EXTRAS set."""
env = os.environ.copy()
env.pop("UV_EXTRAS", None)
if uv_extras is not None:
env["UV_EXTRAS"] = uv_extras
return subprocess.run(
["sh", str(ENTRYPOINT), "--print-extras"],
env=env,
capture_output=True,
text=True,
check=False,
)
def test_entrypoint_script_exists_and_is_posix_sh():
assert ENTRYPOINT.is_file()
# Catch syntax errors before runtime — `sh -n` is a parse-only check.
proc = subprocess.run(["sh", "-n", str(ENTRYPOINT)], capture_output=True, text=True, check=False)
assert proc.returncode == 0, proc.stderr
def test_no_uv_extras_yields_empty_flags():
proc = _run(None)
assert proc.returncode == 0
assert proc.stdout.strip() == ""
def test_single_extra():
proc = _run("postgres")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres"
def test_multi_extra_comma_separated():
proc = _run("postgres,ollama")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_multi_extra_whitespace_separated():
proc = _run("postgres ollama")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_multi_extra_mixed_separators():
proc = _run(" postgres , ollama ,")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra postgres --extra ollama"
def test_empty_string_yields_empty_flags():
proc = _run("")
assert proc.returncode == 0
assert proc.stdout.strip() == ""
@pytest.mark.parametrize(
"bad_value",
[
"; rm -rf /", # the canonical injection attempt
"$(whoami)", # command substitution
"`echo bad`", # backticks
"postgres;evil", # mixed legal+illegal in a single token
"1postgres", # leading digit
"-postgres", # leading hyphen
"post gres extra/path", # contains slash
],
)
def test_metacharacters_abort_with_nonzero_exit(bad_value):
proc = _run(bad_value)
assert proc.returncode != 0, f"expected abort for {bad_value!r}, got 0"
assert "is invalid" in proc.stderr
assert proc.stdout.strip() == ""
def test_underscores_and_hyphens_in_name_are_allowed():
"""Mirrors uv's accepted shape for `[project.optional-dependencies]` keys."""
proc = _run("post_gres,post-gres")
assert proc.returncode == 0
assert proc.stdout.strip() == "--extra post_gres --extra post-gres"
@@ -0,0 +1,336 @@
"""Tests for DynamicContextMiddleware.
Verifies that memory and current date are injected as a <system-reminder> into
the first HumanMessage exactly once per session (frozen-snapshot pattern).
"""
from types import SimpleNamespace
from unittest import mock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares.dynamic_context_middleware import (
_DYNAMIC_CONTEXT_REMINDER_KEY,
DynamicContextMiddleware,
)
_SYSTEM_REMINDER_TAG = "<system-reminder>"
def _make_middleware(**kwargs) -> DynamicContextMiddleware:
return DynamicContextMiddleware(**kwargs)
def _fake_runtime():
return SimpleNamespace(context={})
def _reminder_msg(content: str, msg_id: str) -> HumanMessage:
"""Build a reminder HumanMessage the way the middleware would produce it."""
return HumanMessage(
content=content,
id=msg_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
# ---------------------------------------------------------------------------
# Basic injection
# ---------------------------------------------------------------------------
def test_injects_system_reminder_into_first_human_message():
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hello", id="msg-1")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
updated_msgs = result["messages"]
assert len(updated_msgs) == 2
reminder_msg = updated_msgs[0]
assert isinstance(reminder_msg, HumanMessage)
assert reminder_msg.id == "msg-1" # takes the original ID (position swap)
assert reminder_msg.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in reminder_msg.content
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_msg.content
assert "Hello" not in reminder_msg.content # reminder only — no user text
user_msg = updated_msgs[1]
assert isinstance(user_msg, HumanMessage)
assert user_msg.id == "msg-1__user" # derived ID
assert user_msg.content == "Hello"
def test_memory_included_when_present():
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hi", id="msg-1")]}
with (
mock.patch(
"deerflow.agents.lead_agent.prompt._get_memory_context",
return_value="<memory>\nUser prefers Python.\n</memory>",
),
mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt,
):
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
# Reminder is the first returned message; user query is the second
reminder_content = result["messages"][0].content
assert "User prefers Python." in reminder_content
assert "<current_date>2026-05-08, Friday</current_date>" in reminder_content
assert result["messages"][1].content == "Hi"
# ---------------------------------------------------------------------------
# Frozen-snapshot: no re-injection within a session
# ---------------------------------------------------------------------------
def test_skips_injection_if_already_present():
"""Second turn: separate reminder message already present → no update."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Hi there"),
HumanMessage(content="Follow-up", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is None # no update needed
def test_injects_only_into_first_human_message_not_later_ones():
"""Reminder targets the first HumanMessage; subsequent messages are not touched."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="First", id="msg-1"),
AIMessage(content="Reply"),
HumanMessage(content="Second", id="msg-2"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
# Only the two injected messages are returned (reminder + original first query)
assert len(msgs) == 2
assert msgs[0].id == "msg-1" # reminder takes first message's ID
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
assert msgs[1].id == "msg-1__user" # original content with derived ID
assert msgs[1].content == "First"
# "Second" (msg-2) is not in the returned update — it is left unchanged
assert all(m.id != "msg-2" for m in msgs)
def test_summary_human_message_is_not_used_as_injection_target():
"""After summarization, the synthetic summary HumanMessage is not a user turn."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="Here is a summary of the conversation to date:\n\n...", id="summary-1", name="summary"),
AIMessage(content="Earlier reply"),
HumanMessage(content="Follow-up", id="msg-2"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
assert msgs[0].id == "msg-2"
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert msgs[1].id == "msg-2__user"
assert msgs[1].content == "Follow-up"
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
def test_no_messages_returns_none():
mw = _make_middleware()
result = mw.before_agent({"messages": []}, _fake_runtime())
assert result is None
def test_no_human_message_returns_none():
mw = _make_middleware()
state = {"messages": [AIMessage(content="assistant only")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""):
result = mw.before_agent(state, _fake_runtime())
assert result is None
def test_list_content_message_handled_as_separate_reminder():
"""List-content (e.g. multi-modal) messages remain intact; reminder is a separate message."""
mw = _make_middleware()
original_content = [{"type": "text", "text": "Hello"}]
state = {"messages": [HumanMessage(content=original_content, id="msg-1")]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
# Reminder is a plain string message with the flag set
assert isinstance(msgs[0].content, str)
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
# Original list-content message is untouched
assert msgs[1].content == original_content
def test_reminder_uses_original_id_user_message_uses_derived_id():
"""Reminder takes original ID (position swap); user message gets {id}__user."""
mw = _make_middleware()
original_id = "original-id-abc"
state = {"messages": [HumanMessage(content="Hello", id=original_id)]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result["messages"][0].id == original_id
assert result["messages"][1].id == f"{original_id}__user"
def test_message_without_id_gets_stable_uuid():
"""If the original HumanMessage has no ID, a UUID is generated and used consistently."""
mw = _make_middleware()
state = {"messages": [HumanMessage(content="Hello", id=None)]}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
reminder_id = result["messages"][0].id
user_id = result["messages"][1].id
assert reminder_id is not None
assert reminder_id != "None"
assert user_id == f"{reminder_id}__user"
def test_user_message_containing_system_reminder_tag_does_not_prevent_injection():
"""A user message containing '<system-reminder>' must not be mistaken for a reminder."""
mw = _make_middleware()
state = {
"messages": [
HumanMessage(content="What is <system-reminder>?", id="msg-1"),
]
}
with mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value=""), mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
result = mw.before_agent(state, _fake_runtime())
# Injection must happen — the user message does NOT carry the reminder flag
assert result is not None
assert result["messages"][0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
# ---------------------------------------------------------------------------
# Midnight crossing
# ---------------------------------------------------------------------------
def test_midnight_crossing_injects_date_update_as_separate_message():
"""When the date has changed, a separate date-update reminder is injected before
the current turn's HumanMessage using the ID-swap technique."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Response"),
HumanMessage(content="Good morning", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result is not None
msgs = result["messages"]
assert len(msgs) == 2
# Date-update reminder takes the current message's ID
assert msgs[0].id == "msg-2"
assert msgs[0].additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
assert _SYSTEM_REMINDER_TAG in msgs[0].content
assert "<current_date>2026-05-09, Saturday</current_date>" in msgs[0].content
assert "Good morning" not in msgs[0].content # reminder only
# Original user text appended with derived ID
assert msgs[1].id == "msg-2__user"
assert msgs[1].content == "Good morning"
def test_midnight_crossing_id_swap():
"""Date-update reminder uses original ID; user message uses {id}__user."""
mw = _make_middleware()
reminder_content = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(reminder_content, "msg-1"),
HumanMessage(content="Next day message", id="msg-2"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result["messages"][0].id == "msg-2"
assert result["messages"][1].id == "msg-2__user"
def test_no_second_midnight_injection_once_date_updated():
"""After a midnight update is persisted, the same-day path skips re-injection."""
mw = _make_middleware()
date_update_content = "<system-reminder>\n<current_date>2026-05-09, Saturday</current_date>\n</system-reminder>"
state = {
"messages": [
_reminder_msg(
"<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
"msg-1",
),
HumanMessage(content="Hello", id="msg-1__user"),
AIMessage(content="Response"),
_reminder_msg(date_update_content, "msg-2"),
HumanMessage(content="Good morning", id="msg-2__user"),
AIMessage(content="Good morning!"),
HumanMessage(content="Third turn", id="msg-3"),
]
}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-09, Saturday"
result = mw.before_agent(state, _fake_runtime())
assert result is None # same day as last injected date → no update
@@ -50,7 +50,7 @@ def test_nginx_routes_official_langgraph_prefix_to_gateway_api():
assert "/api/langgraph-compat" not in content
assert "proxy_pass http://langgraph" not in content
assert "rewrite ^/api/langgraph/(.*) /api/$1 break;" in content
assert "proxy_pass http://gateway" in content
assert "proxy_pass http://gateway" in content or "proxy_pass http://$gateway_upstream" in content
def test_frontend_rewrites_langgraph_prefix_to_gateway():
+15
View File
@@ -324,6 +324,21 @@ def test_context_does_not_override_existing_configurable():
assert config["configurable"]["subagent_enabled"] is True
def test_inject_authenticated_user_context_overrides_client_user_id():
"""Run context should carry the authenticated user, not client-supplied user_id."""
from types import SimpleNamespace
from app.gateway.services import build_run_config, inject_authenticated_user_context
config = build_run_config("thread-1", None, None)
config["context"] = {"user_id": "spoofed-client"}
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="auth-user-42")))
inject_authenticated_user_context(config, request)
assert config["context"]["user_id"] == "auth-user-42"
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
@@ -8,17 +8,20 @@ from unittest.mock import MagicMock
import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.config.app_config import AppConfig
from deerflow.config.loop_detection_config import LoopDetectionConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig
def _make_app_config(models: list[ModelConfig]) -> AppConfig:
def _make_app_config(models: list[ModelConfig], loop_detection: LoopDetectionConfig | None = None) -> AppConfig:
return AppConfig(
models=models,
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider"),
loop_detection=loop_detection or LoopDetectionConfig(),
)
@@ -340,6 +343,59 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa
assert middlewares[0] == "base-middleware"
def test_build_middlewares_uses_loop_detection_config(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(
warn_threshold=7,
hard_limit=9,
window_size=30,
max_tracked_threads=40,
tool_freq_warn=50,
tool_freq_hard_limit=60,
),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
loop_detection = next(m for m in middlewares if isinstance(m, LoopDetectionMiddleware))
assert loop_detection.warn_threshold == 7
assert loop_detection.hard_limit == 9
assert loop_detection.window_size == 30
assert loop_detection.max_tracked_threads == 40
assert loop_detection.tool_freq_warn == 50
assert loop_detection.tool_freq_hard_limit == 60
def test_build_middlewares_omits_loop_detection_when_disabled(monkeypatch):
app_config = _make_app_config(
[_make_model("safe-model", supports_thinking=False)],
loop_detection=LoopDetectionConfig(enabled=False),
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "build_lead_runtime_middlewares", lambda *, app_config, lazy_init=True: [])
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda *, app_config=None: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
assert not any(isinstance(m, LoopDetectionMiddleware) for m in middlewares)
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
+70 -3
View File
@@ -1,22 +1,37 @@
import threading
from types import SimpleNamespace
from typing import cast
import anyio
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.app_config import AppConfig
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
from deerflow.skills.types import Skill
from deerflow.skills.types import Skill, SkillCategory
def _set_skills_cache_state(*, skills=None, active=False, version=0):
prompt_module._get_cached_skills_prompt_section.cache_clear()
with prompt_module._enabled_skills_lock:
prompt_module._enabled_skills_cache = skills
prompt_module._enabled_skills_by_config_cache.clear()
prompt_module._enabled_skills_refresh_active = active
prompt_module._enabled_skills_refresh_version = version
prompt_module._enabled_skills_refresh_event.clear()
def test_build_self_update_section_empty_for_default_agent():
assert prompt_module._build_self_update_section(None) == ""
def test_build_self_update_section_present_for_custom_agent():
section = prompt_module._build_self_update_section("my-agent")
assert "<self_update>" in section
assert "my-agent" in section
assert "update_agent" in section
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
@@ -220,7 +235,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category="custom",
category=SkillCategory.CUSTOM,
enabled=True,
)
@@ -240,6 +255,58 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
_set_skills_cache_state()
def test_explicit_config_enabled_skills_are_cached_by_config_identity(monkeypatch, tmp_path):
def make_skill(name: str) -> Skill:
skill_dir = tmp_path / name
return Skill(
name=name,
description=f"Description for {name}",
license="MIT",
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category=SkillCategory.CUSTOM,
enabled=True,
)
config = cast(
AppConfig,
cast(
object,
SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
),
),
)
load_count = 0
def fake_get_or_new_skill_storage(**kwargs):
nonlocal load_count
assert kwargs == {"app_config": config}
def load_skills(*, enabled_only):
nonlocal load_count
load_count += 1
assert enabled_only is True
return [make_skill("cached-skill")]
return SimpleNamespace(load_skills=load_skills)
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fake_get_or_new_skill_storage)
_set_skills_cache_state()
try:
first = prompt_module.get_skills_prompt_section(app_config=config)
second = prompt_module.get_skills_prompt_section(app_config=config)
assert "cached-skill" in first
assert "cached-skill" in second
assert load_count == 1
finally:
_set_skills_cache_state()
def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_path):
started = threading.Event()
release = threading.Event()
@@ -257,7 +324,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=skill_dir.relative_to(tmp_path),
category="custom",
category=SkillCategory.CUSTOM,
enabled=True,
)
+111 -1
View File
@@ -6,7 +6,12 @@ from deerflow.config.agents_config import AgentConfig
from deerflow.skills.types import Skill
def _make_skill(name: str) -> Skill:
class NamedTool:
def __init__(self, name: str):
self.name = name
def _make_skill(name: str, allowed_tools: list[str] | None = None) -> Skill:
return Skill(
name=name,
description=f"Description for {name}",
@@ -15,6 +20,7 @@ def _make_skill(name: str) -> Skill:
skill_file=Path(f"/tmp/{name}/SKILL.md"),
relative_path=Path(name),
category="public",
allowed_tools=allowed_tools,
enabled=True,
)
@@ -132,6 +138,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
@@ -164,3 +171,106 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["skill1"]))
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert captured_skills[-1] == {"skill1"}
def test_make_lead_agent_filters_tools_from_available_skills(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted", "legacy"]))
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("restricted", ["read_file"]), _make_skill("legacy", None)])
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
def test_make_lead_agent_all_legacy_skills_preserve_all_tools(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=None))
monkeypatch.setattr(lead_agent_module, "_load_enabled_skills_for_tool_policy", lambda available_skills, *, app_config: [_make_skill("legacy", None)])
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["bash", "read_file", "update_agent"]
def test_make_lead_agent_enforces_allowed_tools_when_skill_cache_is_cold(monkeypatch):
from unittest.mock import MagicMock
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.lead_agent import prompt as prompt_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
monkeypatch.setattr(lead_agent_module, "apply_prompt_template", lambda **kwargs: "mock_prompt")
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")])
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
mock_storage = SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("restricted", ["read_file"])])
with prompt_module._enabled_skills_lock:
prompt_module._enabled_skills_cache = None
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None, **kwargs: mock_storage)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
agent_kwargs = lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
assert [tool.name for tool in agent_kwargs["tools"]] == ["read_file"]
def test_make_lead_agent_fails_closed_when_skill_policy_load_fails(monkeypatch):
from unittest.mock import MagicMock
import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.agents.lead_agent import prompt as prompt_module
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
create_agent_mock = MagicMock()
monkeypatch.setattr(lead_agent_module, "create_agent", create_agent_mock)
monkeypatch.setattr(lead_agent_module, "load_agent_config", lambda x: AgentConfig(name="test", skills=["restricted"]))
mock_app_config = MagicMock()
mock_app_config.get_model_config.return_value = SimpleNamespace(supports_thinking=False, supports_vision=False)
def fail_storage(*args, **kwargs):
raise RuntimeError("skill storage unavailable")
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", fail_storage)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: mock_app_config)
with pytest.raises(RuntimeError, match="skill storage unavailable"):
lead_agent_module.make_lead_agent({"configurable": {"agent_name": "test"}})
create_agent_mock.assert_not_called()
@@ -105,6 +105,7 @@ def test_execute_command_uses_powershell_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": None,
},
)
]
@@ -118,6 +119,7 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
monkeypatch.setattr(local_sandbox.os, "name", "nt")
monkeypatch.setattr(local_sandbox.os, "environ", {"PATH": r"C:\Program Files\Git\bin"})
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\Program Files\Git\bin\sh.exe"))
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
@@ -132,11 +134,33 @@ def test_execute_command_uses_posix_shell_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": {
"PATH": r"C:\Program Files\Git\bin",
"MSYS_NO_PATHCONV": "1",
"MSYS2_ARG_CONV_EXCL": "*",
},
},
)
]
def test_execute_command_does_not_set_msys_env_for_non_msys_posix_shell_on_windows(monkeypatch):
calls: list[tuple[object, dict]] = []
def fake_run(*args, **kwargs):
calls.append((args[0], kwargs))
return SimpleNamespace(stdout="ok", stderr="", returncode=0)
monkeypatch.setattr(local_sandbox.os, "name", "nt")
monkeypatch.setattr(LocalSandbox, "_get_shell", staticmethod(lambda: r"C:\tools\busybox\sh.exe"))
monkeypatch.setattr(local_sandbox.subprocess, "run", fake_run)
output = LocalSandbox("t").execute_command("echo /mnt/skills/demo")
assert output == "ok"
assert calls[0][1]["env"] is None
def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
calls: list[tuple[object, dict]] = []
@@ -159,6 +183,7 @@ def test_execute_command_uses_cmd_command_mode_on_windows(monkeypatch):
"capture_output": True,
"text": True,
"timeout": 600,
"env": None,
},
)
]
@@ -0,0 +1,72 @@
"""Tests for loop detection configuration."""
import pytest
from deerflow.config.loop_detection_config import LoopDetectionConfig
class TestLoopDetectionConfig:
def test_defaults_match_middleware_defaults(self):
config = LoopDetectionConfig()
assert config.enabled is True
assert config.warn_threshold == 3
assert config.hard_limit == 5
assert config.window_size == 20
assert config.max_tracked_threads == 100
assert config.tool_freq_warn == 30
assert config.tool_freq_hard_limit == 50
def test_accepts_custom_values(self):
config = LoopDetectionConfig(
enabled=False,
warn_threshold=10,
hard_limit=20,
window_size=50,
max_tracked_threads=200,
tool_freq_warn=60,
tool_freq_hard_limit=80,
)
assert config.enabled is False
assert config.warn_threshold == 10
assert config.hard_limit == 20
assert config.window_size == 50
assert config.max_tracked_threads == 200
assert config.tool_freq_warn == 60
assert config.tool_freq_hard_limit == 80
def test_rejects_zero_thresholds(self):
with pytest.raises(ValueError):
LoopDetectionConfig(warn_threshold=0)
with pytest.raises(ValueError):
LoopDetectionConfig(hard_limit=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_warn=0)
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_hard_limit=0)
def test_rejects_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(warn_threshold=5, hard_limit=4)
def test_rejects_tool_freq_hard_limit_below_warn_threshold(self):
with pytest.raises(ValueError, match="tool_freq_hard_limit"):
LoopDetectionConfig(tool_freq_warn=5, tool_freq_hard_limit=4)
def test_tool_freq_override_valid(self):
config = LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 150, "hard_limit": 300}})
override = config.tool_freq_overrides["bash"]
assert override.warn == 150
assert override.hard_limit == 300
def test_tool_freq_override_rejects_zero_warn(self):
with pytest.raises(ValueError):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 0, "hard_limit": 10}})
def test_tool_freq_override_rejects_hard_limit_below_warn(self):
with pytest.raises(ValueError, match="hard_limit"):
LoopDetectionConfig(tool_freq_overrides={"bash": {"warn": 100, "hard_limit": 50}})
+112 -4
View File
@@ -3,7 +3,7 @@
import copy
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, SystemMessage
from deerflow.agents.middlewares.loop_detection_middleware import (
_HARD_STOP_MSG,
@@ -146,14 +146,42 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third identical call triggers warning
# Third identical call triggers warning. The warning is appended to
# the AIMessage content (tool_calls preserved) — never inserted as a
# separate HumanMessage between the AIMessage(tool_calls) and its
# ToolMessage responses, which would break OpenAI/Moonshot strict
# tool-call pairing validation.
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], HumanMessage)
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
assert "LOOP DETECTED" in msgs[0].content
def test_warn_does_not_break_tool_call_pairing(self):
"""Regression: the warn branch must NOT inject a non-tool message
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
request with 'tool_call_ids did not have response messages' if any
non-tool message is wedged between the AIMessage and its ToolMessage
responses. See #2029.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
def test_warn_only_injected_once(self):
"""Warning for the same hash should only be injected once per thread."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
@@ -483,7 +511,11 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg, HumanMessage)
# Warning is appended to the AIMessage content; tool_calls preserved
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
# validation does not break.
assert isinstance(msg, AIMessage)
assert msg.tool_calls
assert "read_file" in msg.content
assert "LOOP DETECTED" in msg.content
@@ -616,6 +648,37 @@ class TestToolFrequencyDetection:
assert result is not None
assert "read_file" in result["messages"][0].content
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
mw = LoopDetectionMiddleware(
tool_freq_warn=5,
tool_freq_hard_limit=10,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
# 10 bash calls — would hit global hard_limit=10, but bash override is 100
for i in range(10):
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None, f"unexpected trigger on call {i + 1}"
def test_non_override_tool_falls_back_to_global(self):
"""A tool NOT in tool_freq_overrides uses the global warn/hard_limit."""
mw = LoopDetectionMiddleware(
tool_freq_warn=3,
tool_freq_hard_limit=6,
tool_freq_overrides={"bash": (50, 100)},
)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls."""
mw = LoopDetectionMiddleware(
@@ -636,3 +699,48 @@ class TestToolFrequencyDetection:
msg = result["messages"][0]
assert isinstance(msg, AIMessage)
assert _HARD_STOP_MSG in msg.content
class TestFromConfig:
"""Tests for LoopDetectionMiddleware.from_config — the sole validated construction path."""
@staticmethod
def _config(**kwargs):
from deerflow.config.loop_detection_config import LoopDetectionConfig
return LoopDetectionConfig(**kwargs)
def test_scalar_fields_mapped(self):
config = self._config(
warn_threshold=4,
hard_limit=8,
window_size=15,
max_tracked_threads=50,
tool_freq_warn=20,
tool_freq_hard_limit=40,
)
mw = LoopDetectionMiddleware.from_config(config)
assert mw.warn_threshold == 4
assert mw.hard_limit == 8
assert mw.window_size == 15
assert mw.max_tracked_threads == 50
assert mw.tool_freq_warn == 20
assert mw.tool_freq_hard_limit == 40
def test_overrides_converted_to_tuples(self):
config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}})
mw = LoopDetectionMiddleware.from_config(config)
assert mw._tool_freq_overrides == {"bash": (50, 100)}
def test_empty_overrides(self):
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
@@ -125,3 +125,68 @@ class TestMigrateMemory:
from scripts.migrate_user_isolation import migrate_memory
migrate_memory(paths, user_id="default") # should not raise
class TestMigrateAgents:
@staticmethod
def _seed_legacy_agent(paths: Paths, name: str, *, soul: str = "soul", description: str = "d") -> Path:
legacy_dir = paths.agents_dir / name
legacy_dir.mkdir(parents=True, exist_ok=True)
(legacy_dir / "config.yaml").write_text(f"name: {name}\ndescription: {description}\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text(soul, encoding="utf-8")
return legacy_dir
def test_moves_legacy_into_user_layout(self, base_dir: Path, paths: Paths):
self._seed_legacy_agent(paths, "agent-a", soul="soul-a")
self._seed_legacy_agent(paths, "agent-b", soul="soul-b")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert {entry["agent"] for entry in report} == {"agent-a", "agent-b"}
for entry in report:
assert entry["user_id"] == "default"
assert "moved -> " in entry["action"]
for name, soul in [("agent-a", "soul-a"), ("agent-b", "soul-b")]:
dest = paths.user_agent_dir("default", name)
assert dest.exists(), f"{name} should have moved into the per-user layout"
assert (dest / "SOUL.md").read_text() == soul
# Legacy agents/ root is cleaned up once empty.
assert not paths.agents_dir.exists()
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
legacy_dir = self._seed_legacy_agent(paths, "agent-a")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default", dry_run=True)
assert len(report) == 1
assert legacy_dir.exists(), "dry-run must not touch the filesystem"
assert not paths.user_agent_dir("default", "agent-a").exists()
def test_existing_destination_is_treated_as_conflict(self, base_dir: Path, paths: Paths):
self._seed_legacy_agent(paths, "agent-a", soul="legacy soul")
dest = paths.user_agent_dir("default", "agent-a")
dest.mkdir(parents=True)
(dest / "SOUL.md").write_text("preexisting", encoding="utf-8")
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert report[0]["action"].startswith("conflict -> ")
# Per-user destination must be left untouched.
assert (dest / "SOUL.md").read_text() == "preexisting"
# Legacy copy lands under migration-conflicts/agents/.
conflicts_dir = paths.base_dir / "migration-conflicts" / "agents" / "agent-a"
assert (conflicts_dir / "SOUL.md").read_text() == "legacy soul"
def test_no_legacy_dir_is_noop(self, base_dir: Path, paths: Paths):
from scripts.migrate_user_isolation import migrate_agents
report = migrate_agents(paths, user_id="default")
assert report == []
@@ -50,6 +50,21 @@ class TestUserAgentMemoryFile:
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
class TestUserAgentDir:
def test_user_agents_dir(self, paths: Paths):
assert paths.user_agents_dir("alice") == paths.base_dir / "users" / "alice" / "agents"
def test_user_agent_dir(self, paths: Paths):
assert paths.user_agent_dir("alice", "code-reviewer") == paths.base_dir / "users" / "alice" / "agents" / "code-reviewer"
def test_user_agent_dir_lowercases_name(self, paths: Paths):
assert paths.user_agent_dir("alice", "CodeReviewer") == paths.base_dir / "users" / "alice" / "agents" / "codereviewer"
def test_user_agent_dir_validates_user_id(self, paths: Paths):
with pytest.raises(ValueError, match="Invalid user_id"):
paths.user_agent_dir("../escape", "myagent")
class TestUserThreadDir:
def test_user_thread_dir(self, paths: Paths):
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
+6 -9
View File
@@ -8,7 +8,9 @@ Tests:
5. Postgres missing-dep error message
"""
import sys
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
@@ -221,13 +223,8 @@ class TestEngineLifecycle:
"""If asyncpg is not installed, error message tells user what to do."""
from deerflow.persistence.engine import init_engine
try:
import asyncpg # noqa: F401
pytest.skip("asyncpg is installed -- cannot test missing-dep path")
except ImportError:
# asyncpg is not installed — this is the expected state for this test.
# We proceed to verify that init_engine raises an actionable ImportError.
pass # noqa: S110 — intentionally ignored
with pytest.raises(ImportError, match="uv sync --extra postgres"):
with (
patch.dict(sys.modules, {"asyncpg": None}),
pytest.raises(ImportError, match="uv sync --all-packages --extra postgres"),
):
await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
@@ -0,0 +1,293 @@
from __future__ import annotations
import pytest
import requests
from deerflow.community.aio_sandbox.remote_backend import RemoteSandboxBackend
from deerflow.community.aio_sandbox.sandbox_info import SandboxInfo
class _StubResponse:
def __init__(
self,
*,
status_code: int = 200,
payload: object | None = None,
json_exc: Exception | None = None,
):
self.status_code = status_code
self._payload = {} if payload is None else payload
self._json_exc = json_exc
self.ok = 200 <= status_code < 400
self.text = ""
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise requests.HTTPError(f"HTTP {self.status_code}")
def json(self) -> object:
if self._json_exc is not None:
raise self._json_exc
return self._payload
def test_list_running_delegates_to_provisioner_list(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
sandbox_info = SandboxInfo(sandbox_id="test-id", sandbox_url="http://localhost:8080")
def mock_list():
return [sandbox_info]
monkeypatch.setattr(backend, "_provisioner_list", mock_list)
assert backend.list_running() == [sandbox_info]
def test_provisioner_list_returns_sandbox_infos_and_filters_invalid_entries(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert timeout == 10
return _StubResponse(
payload={
"sandboxes": [
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
{"sandbox_id": "missing-url"},
{"sandbox_url": "http://k3s:31002"},
]
}
)
monkeypatch.setattr(requests, "get", mock_get)
infos = backend._provisioner_list()
assert len(infos) == 1
assert infos[0].sandbox_id == "abc123"
assert infos[0].sandbox_url == "http://k3s:31001"
def test_provisioner_list_returns_empty_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("network down")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_returns_empty_when_payload_is_not_dict(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload=[{"sandbox_id": "abc", "sandbox_url": "http://k3s:31001"}])
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_returns_empty_when_sandboxes_is_not_list(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload={"sandboxes": {"sandbox_id": "abc"}})
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_list() == []
def test_provisioner_list_skips_non_dict_sandbox_entries(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(
payload={
"sandboxes": [
{"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"},
"bad-entry",
123,
None,
]
}
)
monkeypatch.setattr(requests, "get", mock_get)
infos = backend._provisioner_list()
assert len(infos) == 1
assert infos[0].sandbox_id == "abc123"
assert infos[0].sandbox_url == "http://k3s:31001"
def test_create_delegates_to_provisioner_create(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
def mock_create(thread_id: str, sandbox_id: str, extra_mounts=None):
assert thread_id == "thread-1"
assert sandbox_id == "abc123"
assert extra_mounts == [("/host", "/container", False)]
return expected
monkeypatch.setattr(backend, "_provisioner_create", mock_create)
result = backend.create("thread-1", "abc123", extra_mounts=[("/host", "/container", False)])
assert result == expected
def test_provisioner_create_returns_sandbox_info(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert json == {"sandbox_id": "abc123", "thread_id": "thread-1"}
assert timeout == 30
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
monkeypatch.setattr(requests, "post", mock_post)
info = backend._provisioner_create("thread-1", "abc123")
assert info.sandbox_id == "abc123"
assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "post", mock_post)
with pytest.raises(RuntimeError, match="Provisioner create failed"):
backend._provisioner_create("thread-1", "abc123")
def test_destroy_delegates_to_provisioner_destroy(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
called: list[str] = []
def mock_destroy(sandbox_id: str):
called.append(sandbox_id)
monkeypatch.setattr(backend, "_provisioner_destroy", mock_destroy)
backend.destroy(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
assert called == ["abc123"]
def test_provisioner_destroy_calls_delete(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_delete(url: str, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes/abc123"
assert timeout == 15
return _StubResponse(status_code=200)
monkeypatch.setattr(requests, "delete", mock_delete)
backend._provisioner_destroy("abc123")
def test_provisioner_destroy_swallows_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_delete(url: str, timeout: int):
raise requests.RequestException("network down")
monkeypatch.setattr(requests, "delete", mock_delete)
backend._provisioner_destroy("abc123")
def test_is_alive_delegates_to_provisioner_is_alive(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_is_alive(sandbox_id: str):
assert sandbox_id == "abc123"
return True
monkeypatch.setattr(backend, "_provisioner_is_alive", mock_is_alive)
alive = backend.is_alive(SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001"))
assert alive is True
def test_provisioner_is_alive_true_only_when_status_running(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get_running(url: str, timeout: int):
return _StubResponse(payload={"status": "Running"})
monkeypatch.setattr(requests, "get", mock_get_running)
assert backend._provisioner_is_alive("abc123") is True
def mock_get_pending(url: str, timeout: int):
return _StubResponse(payload={"status": "Pending"})
monkeypatch.setattr(requests, "get", mock_get_pending)
assert backend._provisioner_is_alive("abc123") is False
def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_is_alive("abc123") is False
def test_discover_delegates_to_provisioner_discover(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
expected = SandboxInfo(sandbox_id="abc123", sandbox_url="http://k3s:31001")
def mock_discover(sandbox_id: str):
assert sandbox_id == "abc123"
return expected
monkeypatch.setattr(backend, "_provisioner_discover", mock_discover)
result = backend.discover("abc123")
assert result == expected
def test_provisioner_discover_returns_none_on_404(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(status_code=404)
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_discover("abc123") is None
def test_provisioner_discover_returns_info_on_success(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"})
monkeypatch.setattr(requests, "get", mock_get)
info = backend._provisioner_discover("abc123")
assert info is not None
assert info.sandbox_id == "abc123"
assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_discover_returns_none_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_discover("abc123") is None
+71
View File
@@ -310,6 +310,28 @@ class TestDbRunEventStore:
await close_engine()
@pytest.mark.anyio
async def test_structured_content_round_trips(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.test/a.png"}}]
record = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=content)
assert record["content"] == content
assert record["metadata"]["content_is_json"] is True
assert "content_is_dict" not in record["metadata"]
messages = await s.list_messages("t1")
assert messages[0]["content"] == content
assert messages[0]["metadata"]["content_is_json"] is True
await close_engine()
@pytest.mark.anyio
async def test_pagination(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
@@ -373,6 +395,55 @@ class TestDbRunEventStore:
assert seqs == list(range(1, 51))
await close_engine()
@pytest.mark.anyio
async def test_put_batch_accepts_structured_content(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = [{"messages": [{"type": "ai", "content": ""}]}]
results = await s.put_batch(
[
{
"thread_id": "t1",
"run_id": "r1",
"event_type": "run.end",
"category": "outputs",
"content": content,
}
]
)
assert results[0]["content"] == content
assert results[0]["metadata"]["content_is_json"] is True
events = await s.list_events("t1", "r1")
assert events[0]["content"] == content
assert events[0]["metadata"]["content_is_json"] is True
await close_engine()
@pytest.mark.anyio
async def test_dict_content_keeps_legacy_metadata_flag(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
content = {"status": "success"}
record = await s.put(thread_id="t1", run_id="r1", event_type="run.end", category="outputs", content=content)
assert record["content"] == content
assert record["metadata"]["content_is_json"] is True
assert record["metadata"]["content_is_dict"] is True
await close_engine()
# -- Factory tests --
+55
View File
@@ -166,6 +166,61 @@ class TestRunRepository:
assert row["total_tokens"] == 100
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion(
"success-run",
status="success",
total_input_tokens=70,
total_output_tokens=30,
total_tokens=100,
lead_agent_tokens=80,
subagent_tokens=15,
middleware_tokens=5,
)
await repo.put("error-run", thread_id="t1", status="running")
await repo.update_run_completion(
"error-run",
status="error",
total_input_tokens=20,
total_output_tokens=30,
total_tokens=50,
lead_agent_tokens=40,
subagent_tokens=10,
)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_completion(
"running-run",
status="running",
total_input_tokens=900,
total_output_tokens=99,
total_tokens=999,
lead_agent_tokens=999,
)
await repo.put("other-thread-run", thread_id="t2", status="running")
await repo.update_run_completion(
"other-thread-run",
status="success",
total_tokens=888,
lead_agent_tokens=888,
)
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg["total_tokens"] == 150
assert agg["total_input_tokens"] == 90
assert agg["total_output_tokens"] == 60
assert agg["total_runs"] == 2
assert agg["by_model"] == {"unknown": {"tokens": 150, "runs": 2}}
assert agg["by_caller"] == {
"lead_agent": 120,
"subagent": 25,
"middleware": 5,
}
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first."""
+29 -6
View File
@@ -6,6 +6,8 @@ from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from deerflow.tools.builtins.setup_agent_tool import setup_agent
# --- Helpers ---
@@ -27,6 +29,7 @@ def _make_paths_mock(tmp_path: Path):
paths = MagicMock()
paths.base_dir = tmp_path
paths.agent_dir = lambda name: tmp_path / "agents" / name
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
return paths
@@ -54,7 +57,7 @@ def test_setup_agent_rejects_invalid_agent_name_before_writing(tmp_path, monkeyp
messages = result.update["messages"]
assert len(messages) == 1
assert "Invalid agent name" in messages[0].content
assert not (tmp_path / "agents").exists()
assert not (tmp_path / "users" / "test-user-autouse" / "agents").exists()
assert not (outside_dir / "evil" / "SOUL.md").exists()
@@ -68,7 +71,7 @@ def test_setup_agent_rejects_absolute_agent_name_before_writing(tmp_path, monkey
messages = result.update["messages"]
assert len(messages) == 1
assert "Invalid agent name" in messages[0].content
assert not (tmp_path / "agents").exists()
assert not (tmp_path / "users" / "test-user-autouse" / "agents").exists()
assert not (Path(absolute_agent) / "SOUL.md").exists()
@@ -81,10 +84,10 @@ class TestSetupAgentNoDataLoss:
def test_existing_agent_dir_preserved_on_failure(self, tmp_path: Path):
"""If the agent directory already exists and setup fails,
the directory and its contents must NOT be deleted."""
agent_dir = tmp_path / "agents" / "test-agent"
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "test-agent"
agent_dir.mkdir(parents=True)
old_soul = agent_dir / "SOUL.md"
old_soul.write_text("original soul content")
old_soul.write_text("original soul content", encoding="utf-8")
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
# Force soul_file.write_text to raise after directory already exists
@@ -103,7 +106,7 @@ class TestSetupAgentNoDataLoss:
def test_new_agent_dir_cleaned_up_on_failure(self, tmp_path: Path):
"""If the agent directory is newly created and setup fails,
the directory should be cleaned up."""
agent_dir = tmp_path / "agents" / "test-agent"
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "test-agent"
assert not agent_dir.exists()
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
@@ -121,7 +124,27 @@ class TestSetupAgentNoDataLoss:
"""Happy path: setup_agent creates config.yaml and SOUL.md."""
_call_setup_agent(tmp_path, soul="# My Agent", description="A test agent")
agent_dir = tmp_path / "agents" / "test-agent"
agent_dir = tmp_path / "users" / "test-user-autouse" / "agents" / "test-agent"
assert agent_dir.exists()
assert (agent_dir / "SOUL.md").read_text() == "# My Agent"
assert (agent_dir / "config.yaml").exists()
@pytest.mark.no_auto_user
def test_runtime_user_id_used_when_contextvar_missing(self, tmp_path: Path):
"""setup_agent should not fall back to default when runtime carries user_id."""
runtime = _DummyRuntime(
context={"agent_name": "test-agent", "user_id": "auth-user-42"},
tool_call_id="tool-3",
)
with patch("deerflow.tools.builtins.setup_agent_tool.get_paths", return_value=_make_paths_mock(tmp_path)):
setup_agent.func(
soul="# My Agent",
description="A test agent",
runtime=runtime,
)
expected_dir = tmp_path / "users" / "auth-user-42" / "agents" / "test-agent"
default_dir = tmp_path / "users" / "default" / "agents" / "test-agent"
assert (expected_dir / "SOUL.md").read_text() == "# My Agent"
assert not default_dir.exists()
+2 -2
View File
@@ -313,7 +313,7 @@ class TestWriteConfigYaml:
{
"config_version": 5,
"log_level": "info",
"token_usage": {"enabled": False},
"token_usage": {"enabled": True},
"tool_groups": [{"name": "web"}, {"name": "file:read"}, {"name": "file:write"}, {"name": "bash"}],
"tools": [
{
@@ -361,7 +361,7 @@ class TestWriteConfigYaml:
data = yaml.safe_load(f)
assert data["log_level"] == "info"
assert data["token_usage"]["enabled"] is False
assert data["token_usage"]["enabled"] is True
assert data["tool_groups"][0]["name"] == "web"
assert data["summarization"]["max_tokens"] == 2048
assert any(tool["name"] == "image_search" and tool["max_results"] == 5 for tool in data["tools"])
+27
View File
@@ -86,6 +86,33 @@ def test_parse_license_field(tmp_path):
assert skill.license == "MIT"
def test_parse_missing_allowed_tools_returns_none(tmp_path):
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test")
skill = parse_skill_file(skill_file, category="custom")
assert skill is not None
assert skill.allowed_tools is None
def test_parse_allowed_tools_list(tmp_path):
skill_file = _write_skill(tmp_path, 'name: my-skill\ndescription: Test\nallowed-tools: ["bash", "read_file"]')
skill = parse_skill_file(skill_file, category="custom")
assert skill is not None
assert skill.allowed_tools == ["bash", "read_file"]
def test_parse_empty_allowed_tools_list(tmp_path):
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: []")
skill = parse_skill_file(skill_file, category="custom")
assert skill is not None
assert skill.allowed_tools == []
def test_parse_invalid_allowed_tools_returns_none(tmp_path):
skill_file = _write_skill(tmp_path, "name: my-skill\ndescription: Test\nallowed-tools: bash")
skill = parse_skill_file(skill_file, category="custom")
assert skill is None
def test_parse_missing_name_returns_none(tmp_path):
"""Skills missing a name field are rejected."""
skill_file = _write_skill(tmp_path, "description: A test skill")
+35 -1
View File
@@ -30,13 +30,47 @@ class TestValidateSkillFrontmatter:
def test_valid_with_all_allowed_fields(self, tmp_path):
skill_dir = _write_skill(
tmp_path,
"---\nname: my-skill\ndescription: A skill\nlicense: MIT\nversion: '1.0'\nauthor: test\n---\n\nBody\n",
"---\nname: my-skill\ndescription: A skill\nlicense: MIT\nversion: '1.0'\nauthor: test\nallowed-tools: [bash, read_file]\n---\n\nBody\n",
)
valid, msg, name = _validate_skill_frontmatter(skill_dir)
assert valid is True
assert msg == "Skill is valid!"
assert name == "my-skill"
def test_allows_empty_allowed_tools(self, tmp_path):
skill_dir = _write_skill(
tmp_path,
"---\nname: my-skill\ndescription: A skill\nallowed-tools: []\n---\n\nBody\n",
)
valid, msg, name = _validate_skill_frontmatter(skill_dir)
assert valid is True
assert msg == "Skill is valid!"
assert name == "my-skill"
def test_rejects_allowed_tools_string(self, tmp_path):
skill_dir = _write_skill(
tmp_path,
"---\nname: my-skill\ndescription: A skill\nallowed-tools: bash\n---\n\nBody\n",
)
valid, msg, name = _validate_skill_frontmatter(skill_dir)
assert valid is False
assert "allowed-tools" in msg
assert str(tmp_path) not in msg
assert "SKILL.md" in msg
assert name is None
def test_rejects_allowed_tools_non_string_entry(self, tmp_path):
skill_dir = _write_skill(
tmp_path,
"---\nname: my-skill\ndescription: A skill\nallowed-tools: [bash, 1]\n---\n\nBody\n",
)
valid, msg, name = _validate_skill_frontmatter(skill_dir)
assert valid is False
assert "allowed-tools" in msg
assert str(tmp_path) not in msg
assert "SKILL.md" in msg
assert name is None
def test_missing_skill_md(self, tmp_path):
valid, msg, name = _validate_skill_frontmatter(tmp_path)
assert valid is False
+141 -4
View File
@@ -17,11 +17,14 @@ import asyncio
import sys
import threading
from datetime import datetime
from pathlib import Path
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from deerflow.skills.types import Skill
# Module names that need to be mocked to break circular imports
_MOCKED_MODULE_NAMES = [
"deerflow.agents",
@@ -32,14 +35,15 @@ _MOCKED_MODULE_NAMES = [
"deerflow.sandbox.middleware",
"deerflow.sandbox.security",
"deerflow.models",
"deerflow.skills.storage",
]
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def _setup_executor_classes():
"""Set up mocked modules and import real executor classes.
This fixture runs once per session and yields the executor classes.
This fixture runs once per test and yields the executor classes.
It handles module cleanup to avoid affecting other test files.
"""
# Save original modules
@@ -53,6 +57,9 @@ def _setup_executor_classes():
# Set up mocks
for name in _MOCKED_MODULE_NAMES:
sys.modules[name] = MagicMock()
storage_module = ModuleType("deerflow.skills.storage")
storage_module.get_or_new_skill_storage = lambda **kwargs: SimpleNamespace(load_skills=lambda *, enabled_only: [])
sys.modules["deerflow.skills.storage"] = storage_module
# Import real classes inside fixture
from langchain_core.messages import AIMessage, HumanMessage
@@ -117,6 +124,26 @@ class MockAIMessage:
return msg
class NamedTool:
def __init__(self, name: str):
self.name = name
def _skill(name: str, allowed_tools: list[str] | None) -> Skill:
skill_dir = Path(f"/tmp/{name}")
return Skill(
name=name,
description=f"{name} skill",
license=None,
skill_dir=skill_dir,
skill_file=skill_dir / "SKILL.md",
relative_path=Path(name),
category="custom",
allowed_tools=allowed_tools,
enabled=True,
)
async def async_iterator(items):
"""Helper to create an async iterator from a list."""
for item in items:
@@ -288,7 +315,7 @@ class TestAgentConstruction:
captured["app_config"] = app_config
return SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="demo-skill", skill_file=skill_file)])
monkeypatch.setattr("deerflow.skills.storage.get_or_new_skill_storage", fake_get_or_new_skill_storage)
monkeypatch.setattr(sys.modules["deerflow.skills.storage"], "get_or_new_skill_storage", fake_get_or_new_skill_storage)
executor = SubagentExecutor(
config=base_config,
@@ -297,7 +324,8 @@ class TestAgentConstruction:
thread_id="test-thread",
)
messages = await executor._load_skill_messages()
skills = await executor._load_skills()
messages = await executor._load_skill_messages(skills)
assert captured["app_config"] is app_config
assert len(messages) == 1
@@ -487,6 +515,115 @@ class TestAsyncExecutionPath:
assert "Task" in result.result
class TestSkillAllowedTools:
@pytest.mark.anyio
async def test_skill_allowed_tools_union_filters_agent_tools(self, classes, base_config, mock_agent, msg):
SubagentExecutor = classes["SubagentExecutor"]
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
async def load_skills():
return [_skill("a", ["bash"]), _skill("b", ["read_file"])]
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
await executor._aexecute("Task")
create_agent_mock.assert_called_once()
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file"]
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
@pytest.mark.anyio
async def test_all_missing_allowed_tools_preserves_legacy_allow_all(self, classes, base_config, mock_agent, msg):
SubagentExecutor = classes["SubagentExecutor"]
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
async def load_skills():
return [_skill("legacy-a", None), _skill("legacy-b", None)]
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
await executor._aexecute("Task")
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash", "read_file", "web_search"]
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
@pytest.mark.anyio
async def test_mixed_missing_allowed_tools_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
SubagentExecutor = classes["SubagentExecutor"]
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
async def load_skills():
return [_skill("legacy", None), _skill("restricted", ["bash"])]
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
await executor._aexecute("Task")
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
@pytest.mark.anyio
async def test_mixed_missing_allowed_tools_order_does_not_disable_explicit_restrictions(self, classes, base_config, mock_agent, msg):
SubagentExecutor = classes["SubagentExecutor"]
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
async def load_skills():
return [_skill("restricted", ["bash"]), _skill("legacy", None)]
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
await executor._aexecute("Task")
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["bash"]
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
@pytest.mark.anyio
async def test_empty_allowed_tools_contributes_no_tools(self, classes, base_config, mock_agent, msg, caplog):
SubagentExecutor = classes["SubagentExecutor"]
final_state = {"messages": [msg.human("Task"), msg.ai("Done", "msg-1")]}
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
tools = [NamedTool("bash"), NamedTool("read_file"), NamedTool("web_search")]
executor = SubagentExecutor(config=base_config, tools=tools, thread_id="test-thread")
async def load_skills():
return [_skill("empty", []), _skill("reader", ["read_file"])]
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock, caplog.at_level("INFO"):
await executor._aexecute("Task")
assert [tool.name for tool in create_agent_mock.call_args.args[0]] == ["read_file"]
assert [tool.name for tool in executor.tools] == ["bash", "read_file", "web_search"]
assert "declared empty allowed-tools" in caplog.text
@pytest.mark.anyio
async def test_skill_load_failure_fails_without_creating_agent(self, classes, base_config, mock_agent):
SubagentExecutor = classes["SubagentExecutor"]
executor = SubagentExecutor(config=base_config, tools=[NamedTool("bash")], thread_id="test-thread")
async def load_skills():
raise RuntimeError("skill storage unavailable")
with patch.object(executor, "_load_skills", load_skills), patch.object(executor, "_create_agent", return_value=mock_agent) as create_agent_mock:
result = await executor._aexecute("Task")
assert result.status == classes["SubagentStatus"].FAILED
assert result.error == "skill storage unavailable"
create_agent_mock.assert_not_called()
# -----------------------------------------------------------------------------
# Sync Execution Path Tests
# -----------------------------------------------------------------------------
@@ -27,6 +27,14 @@ def _other_call(name="bash", call_id="call_other"):
return {"name": name, "id": call_id, "args": {}}
def _raw_tool_call(call_id: str, name: str = "task") -> dict:
return {
"id": call_id,
"type": "function",
"function": {"name": name, "arguments": "{}"},
}
class TestClampSubagentLimit:
def test_below_min_clamped_to_min(self):
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
@@ -117,6 +125,23 @@ class TestTruncateTaskCalls:
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
assert len(task_calls) == 2
def test_truncation_syncs_raw_provider_tool_calls(self):
mw = SubagentLimitMiddleware(max_concurrent=2)
msg = AIMessage(
content="",
tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")],
additional_kwargs={"tool_calls": [_raw_tool_call("t1"), _raw_tool_call("t2"), _raw_tool_call("t3"), _raw_tool_call("t4")]},
response_metadata={"finish_reason": "tool_calls"},
)
result = mw._truncate_task_calls({"messages": [msg]})
assert result is not None
updated_msg = result["messages"][0]
assert [tc["id"] for tc in updated_msg.tool_calls] == ["t1", "t2"]
assert [tc["id"] for tc in updated_msg.additional_kwargs["tool_calls"]] == ["t1", "t2"]
assert updated_msg.response_metadata["finish_reason"] == "tool_calls"
def test_only_non_task_calls_returns_none(self):
mw = SubagentLimitMiddleware()
msg = AIMessage(
@@ -1,12 +1,14 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
from deerflow.config.memory_config import MemoryConfig
@@ -20,6 +22,14 @@ def _messages() -> list:
]
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
return HumanMessage(
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
id=msg_id,
additional_kwargs={"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True},
)
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
context = {}
if thread_id is not None:
@@ -75,6 +85,14 @@ def _skill_conversation() -> list:
]
def _raw_tool_call(tool_id: str, name: str = "read_file") -> dict:
return {
"id": tool_id,
"type": "function",
"function": {"name": name, "arguments": "{}"},
}
def test_before_summarization_hook_receives_messages_before_compression() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
@@ -90,6 +108,38 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
assert result["messages"][1].content.startswith("Here is a summary")
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
reminder = _dynamic_context_reminder()
result = middleware.before_model(
{
"messages": [
reminder,
HumanMessage(content="user-1"),
AIMessage(content="assistant-1"),
HumanMessage(content="user-2"),
]
},
_runtime(),
)
assert len(captured) == 1
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1"]
assert captured[0].preserved_messages[0] is reminder
emitted = result["messages"]
assert isinstance(emitted[0], RemoveMessage)
assert emitted[1].name == "summary"
assert emitted[2] is reminder
followup_state = {"messages": [*emitted[1:], HumanMessage(content="Follow-up", id="msg-2")]}
with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
assert DynamicContextMiddleware().before_agent(followup_state, _runtime()) is None
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
@@ -413,6 +463,47 @@ def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls(
assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)
def test_skill_rescue_syncs_raw_provider_tool_calls_on_split_ai_messages() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="reading skill and notes",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
{"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
],
additional_kwargs={"tool_calls": [_raw_tool_call("skill-1"), _raw_tool_call("file-1")]},
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
ToolMessage(content="user notes", tool_call_id="file-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
assert [tc["id"] for tc in preserved_ai.additional_kwargs["tool_calls"]] == ["skill-1"]
assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"]
assert [tc["id"] for tc in summarized_ai.additional_kwargs["tool_calls"]] == ["file-1"]
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
@@ -451,6 +542,42 @@ def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
assert summarized_ai.content == "reading skill and notes"
def test_skill_rescue_removes_raw_provider_tool_calls_from_content_only_summary_clone() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="reading skill",
tool_calls=[_skill_read_call("skill-1", "alpha")],
additional_kwargs={"tool_calls": [_raw_tool_call("skill-1")], "function_call": {"name": "read_file"}},
response_metadata={"finish_reason": "tool_calls"},
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
summarized = captured[0].messages_to_summarize
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage))
assert summarized_ai.content == "reading skill"
assert summarized_ai.tool_calls == []
assert "tool_calls" not in summarized_ai.additional_kwargs
assert "function_call" not in summarized_ai.additional_kwargs
assert summarized_ai.response_metadata["finish_reason"] == "stop"
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
+1 -2
View File
@@ -221,7 +221,6 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
prompt="collect diagnostics",
subagent_type="general-purpose",
tool_call_id="tc-123",
max_turns=7,
)
assert output == "Task Succeeded. Result: all done"
@@ -229,7 +228,7 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
assert captured["task_id"] == "tc-123"
assert captured["executor_kwargs"]["thread_id"] == "thread-1"
assert captured["executor_kwargs"]["parent_model"] == "ark-model"
assert captured["executor_kwargs"]["config"].max_turns == 7
assert captured["executor_kwargs"]["config"].max_turns == config.max_turns
# Skills are no longer appended to system_prompt; they are loaded per-session
# by SubagentExecutor and injected as conversation items (Codex pattern).
assert captured["executor_kwargs"]["config"].system_prompt == "Base system prompt"
+55
View File
@@ -0,0 +1,55 @@
"""Tests for thread-level token usage aggregation API."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.gateway.routers import thread_runs
def _make_app(run_store: MagicMock):
app = make_authed_test_app()
app.include_router(thread_runs.router)
app.state.run_store = run_store
return app
def test_thread_token_usage_returns_stable_shape():
run_store = MagicMock()
run_store.aggregate_tokens_by_thread = AsyncMock(
return_value={
"total_tokens": 150,
"total_input_tokens": 90,
"total_output_tokens": 60,
"total_runs": 2,
"by_model": {"unknown": {"tokens": 150, "runs": 2}},
"by_caller": {
"lead_agent": 120,
"subagent": 25,
"middleware": 5,
},
},
)
app = _make_app(run_store)
with TestClient(app) as client:
response = client.get("/api/threads/thread-1/token-usage")
assert response.status_code == 200
assert response.json() == {
"thread_id": "thread-1",
"total_tokens": 150,
"total_input_tokens": 90,
"total_output_tokens": 60,
"total_runs": 2,
"by_model": {"unknown": {"tokens": 150, "runs": 2}},
"by_caller": {
"lead_agent": 120,
"subagent": 25,
"middleware": 5,
},
}
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares import title_middleware as title_middleware_module
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
@@ -44,6 +45,22 @@ class TestTitleMiddlewareCoreLogic:
assert middleware._should_generate_title(state) is True
def test_should_generate_title_with_dynamic_context_reminder(self):
_set_test_title_config(enabled=True)
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(
content="<system-reminder>\n<memory>User prefers Python.</memory>\n</system-reminder>",
additional_kwargs={_DYNAMIC_CONTEXT_REMINDER_KEY: True},
),
HumanMessage(content="帮我总结这段代码"),
AIMessage(content="好的,我先看结构"),
]
}
assert middleware._should_generate_title(state) is True
def test_should_not_generate_title_when_disabled_or_already_set(self):
middleware = TitleMiddleware()
@@ -243,6 +260,25 @@ class TestTitleMiddlewareCoreLogic:
prompt, _ = middleware._build_title_prompt(state)
assert "<think>" not in prompt
def test_build_title_prompt_uses_real_user_message_with_dynamic_context_reminder(self):
_set_test_title_config(enabled=True)
middleware = TitleMiddleware()
state = {
"messages": [
HumanMessage(
content="<system-reminder>\n<memory>User prefers Python.</memory>\n</system-reminder>",
additional_kwargs={_DYNAMIC_CONTEXT_REMINDER_KEY: True},
),
HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"),
]
}
prompt, user_msg = middleware._build_title_prompt(state)
assert user_msg == "请帮我写测试"
assert "<system-reminder>" not in prompt
assert "User prefers Python" not in prompt
def test_generate_title_async_strips_think_tags_in_response(self, monkeypatch):
"""Async title generation strips <think> blocks from the model response."""
_set_test_title_config(max_chars=50)
+5
View File
@@ -0,0 +1,5 @@
from deerflow.config.token_usage_config import TokenUsageConfig
def test_token_usage_enabled_by_default():
assert TokenUsageConfig().enabled is True
@@ -1,5 +1,6 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
import logging
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
@@ -17,6 +18,82 @@ def _make_runtime():
class TestTokenUsageMiddleware:
def test_logs_cache_token_details(self, caplog):
middleware = TokenUsageMiddleware()
message = AIMessage(
content="Here is the final answer.",
usage_metadata={
"input_tokens": 350,
"output_tokens": 240,
"total_tokens": 590,
"input_token_details": {
"audio": 10,
"cache_creation": 200,
"cache_read": 100,
},
"output_token_details": {
"audio": 10,
"reasoning": 200,
},
},
)
with caplog.at_level(
logging.INFO,
logger="deerflow.agents.middlewares.token_usage_middleware",
):
result = middleware.after_model({"messages": [message]}, _make_runtime())
assert result is not None
assert "LLM token usage: input=350 output=240 total=590" in caplog.text
assert "input_token_details={'audio': 10, 'cache_creation': 200, 'cache_read': 100}" in caplog.text
assert "output_token_details={'audio': 10, 'reasoning': 200}" in caplog.text
def test_logs_basic_tokens_when_no_detail_fields_in_usage_metadata(self, caplog):
"""When usage_metadata has only totals (no input_token_details), log just the counts."""
middleware = TokenUsageMiddleware()
message = AIMessage(
content="Here is the final answer.",
usage_metadata={
"input_tokens": 350,
"output_tokens": 240,
"total_tokens": 590,
},
)
with caplog.at_level(
logging.INFO,
logger="deerflow.agents.middlewares.token_usage_middleware",
):
result = middleware.after_model({"messages": [message]}, _make_runtime())
assert result is not None
assert "LLM token usage: input=350 output=240 total=590" in caplog.text
assert "input_token_details" not in caplog.text
def test_no_log_when_usage_metadata_is_missing(self, caplog):
"""When usage_metadata is absent, no token usage line is logged."""
middleware = TokenUsageMiddleware()
message = AIMessage(
content="Here is the final answer.",
response_metadata={
"usage": {
"input_tokens": 350,
"output_tokens": 240,
"total_tokens": 590,
}
},
)
with caplog.at_level(
logging.INFO,
logger="deerflow.agents.middlewares.token_usage_middleware",
):
result = middleware.after_model({"messages": [message]}, _make_runtime())
assert result is not None
assert "LLM token usage" not in caplog.text
def test_annotates_todo_updates_with_structured_actions(self):
middleware = TokenUsageMiddleware()
message = AIMessage(
@@ -0,0 +1,91 @@
"""Regression test: tool args schemas must not emit Pydantic serialization warnings.
DeerFlow tools annotate their runtime parameter as ``Runtime``
(``deerflow.tools.types.Runtime`` = ``ToolRuntime[dict[str, Any], ThreadState]``)
so the LangChain tool framework injects the runtime automatically.
When the inner ``Runtime.context`` field is left as the unbound ``ContextT``
TypeVar (default ``None``), Pydantic's ``model_dump()`` on the auto-generated
args schema emits a ``PydanticSerializationUnexpectedValue`` warning on every
tool call because the actual context DeerFlow installs is a dict. Using the
``Runtime`` alias (which binds the context to ``dict[str, Any]``) keeps
Pydantic's serialization expectations aligned with reality.
"""
from __future__ import annotations
import warnings
import pytest
from langchain.tools import ToolRuntime
from deerflow.sandbox.tools import (
bash_tool,
glob_tool,
grep_tool,
ls_tool,
read_file_tool,
str_replace_tool,
write_file_tool,
)
from deerflow.tools.builtins.present_file_tool import present_file_tool
from deerflow.tools.builtins.setup_agent_tool import setup_agent
from deerflow.tools.builtins.task_tool import task_tool
from deerflow.tools.builtins.update_agent_tool import update_agent
from deerflow.tools.builtins.view_image_tool import view_image_tool
from deerflow.tools.skill_manage_tool import skill_manage_tool
def _make_runtime(context: dict) -> ToolRuntime:
return ToolRuntime(
state={"sandbox": {"sandbox_id": "local"}, "thread_data": {}},
context=context,
config={"configurable": {"thread_id": context.get("thread_id", "thread-1")}},
stream_writer=lambda _: None,
tools=[],
tool_call_id="call-1",
store=None,
)
_TOOL_CASES = [
(bash_tool, {"description": "list", "command": "ls"}),
(ls_tool, {"description": "list", "path": "/tmp"}),
(glob_tool, {"description": "find", "pattern": "*.py", "path": "/tmp"}),
(grep_tool, {"description": "search", "pattern": "x", "path": "/tmp"}),
(read_file_tool, {"description": "read", "path": "/tmp/x"}),
(write_file_tool, {"description": "write", "path": "/tmp/x", "content": "hi"}),
(str_replace_tool, {"description": "replace", "path": "/tmp/x", "old_str": "a", "new_str": "b"}),
(present_file_tool, {"filepaths": ["/tmp/x"], "tool_call_id": "call-1"}),
(view_image_tool, {"image_path": "/tmp/img.png", "tool_call_id": "call-1"}),
(task_tool, {"description": "do", "prompt": "go", "subagent_type": "general-purpose", "tool_call_id": "call-1"}),
(skill_manage_tool, {"action": "list", "name": "demo"}),
(setup_agent, {"soul": "s", "description": "d"}),
(update_agent, {}),
]
@pytest.mark.parametrize(
("tool_obj", "extra_args"),
_TOOL_CASES,
ids=[case[0].name for case in _TOOL_CASES],
)
def test_tool_args_schema_does_not_emit_pydantic_context_warning(tool_obj, extra_args) -> None:
"""``model_dump()`` of the auto-generated args_schema must not warn about ``context``.
The model_dump path is hit by LangChain's ``BaseTool._parse_input`` on every tool
invocation (see langchain_core/tools/base.py:712), so any warning here would fire
once per tool call and pollute production logs.
"""
schema = tool_obj.args_schema
assert schema is not None, f"{tool_obj.name} has no args_schema"
runtime_obj = _make_runtime({"thread_id": "thread-1", "sandbox_id": "local"})
payload = {**extra_args, "runtime": runtime_obj}
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
validated = schema.model_validate(payload)
validated.model_dump()
pydantic_warnings = [w for w in caught if "PydanticSerializationUnexpectedValue" in str(w.message)]
assert not pydantic_warnings, f"{tool_obj.name} args_schema.model_dump() emitted Pydantic context serialization warnings: {[str(w.message) for w in pydantic_warnings]}"
+310
View File
@@ -0,0 +1,310 @@
"""Tests for update_agent tool — partial updates, atomic writes, and validation.
Resolves issue #2616: a custom agent must be able to persist updates to its
own SOUL.md / config.yaml from inside a normal chat (not only from bootstrap).
The tool writes per-user (``{base_dir}/users/{user_id}/agents/{name}/``) so
that one user's update cannot mutate another user's agent.
"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import yaml
from deerflow.config.agents_config import AgentConfig
from deerflow.tools.builtins.update_agent_tool import update_agent
DEFAULT_USER = "test-user-autouse" # matches the autouse fixture in tests/conftest.py
class _DummyRuntime(SimpleNamespace):
context: dict
tool_call_id: str
def _runtime(agent_name: str | None = "test-agent", tool_call_id: str = "call_1") -> _DummyRuntime:
return _DummyRuntime(context={"agent_name": agent_name} if agent_name is not None else {}, tool_call_id=tool_call_id)
def _make_paths_mock(tmp_path: Path) -> MagicMock:
paths = MagicMock()
paths.base_dir = tmp_path
paths.agent_dir = lambda name: tmp_path / "agents" / name
paths.agents_dir = tmp_path / "agents"
paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name
paths.user_agents_dir = lambda user_id: tmp_path / "users" / user_id / "agents"
return paths
def _user_agent_dir(tmp_path: Path, name: str = "test-agent", user_id: str = DEFAULT_USER) -> Path:
return tmp_path / "users" / user_id / "agents" / name
def _seed_agent(
tmp_path: Path,
name: str = "test-agent",
*,
description: str = "old desc",
soul: str = "old soul",
skills: list[str] | None = None,
user_id: str = DEFAULT_USER,
) -> Path:
"""Create a baseline agent dir with config.yaml and SOUL.md for tests to mutate."""
agent_dir = _user_agent_dir(tmp_path, name, user_id=user_id)
agent_dir.mkdir(parents=True, exist_ok=True)
cfg: dict = {"name": name, "description": description}
if skills is not None:
cfg["skills"] = skills
(agent_dir / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
(agent_dir / "SOUL.md").write_text(soul, encoding="utf-8")
return agent_dir
@pytest.fixture()
def patched_paths(tmp_path: Path):
paths_mock = _make_paths_mock(tmp_path)
with patch("deerflow.tools.builtins.update_agent_tool.get_paths", return_value=paths_mock):
# load_agent_config also calls get_paths(); patch the same target it uses.
with patch("deerflow.config.agents_config.get_paths", return_value=paths_mock):
yield paths_mock
@pytest.fixture()
def stub_app_config():
"""Stub get_app_config so model validation accepts only known names."""
fake = MagicMock()
fake.get_model_config.side_effect = lambda name: object() if name in {"gpt-known", "m1"} else None
with patch("deerflow.tools.builtins.update_agent_tool.get_app_config", return_value=fake):
yield fake
# --- Validation tests ---
def test_update_agent_rejects_missing_agent_name(patched_paths):
result = update_agent.func(runtime=_runtime(agent_name=None), soul="new soul")
msg = result.update["messages"][0]
assert "only available inside a custom agent's chat" in msg.content
def test_update_agent_rejects_invalid_agent_name(patched_paths):
result = update_agent.func(runtime=_runtime(agent_name="../../etc/passwd"), soul="x")
msg = result.update["messages"][0]
assert "Invalid agent name" in msg.content
def test_update_agent_rejects_unknown_agent(tmp_path, patched_paths):
result = update_agent.func(runtime=_runtime(agent_name="ghost"), soul="x")
msg = result.update["messages"][0]
assert "does not exist" in msg.content
assert not _user_agent_dir(tmp_path, "ghost").exists()
def test_update_agent_requires_at_least_one_field(tmp_path, patched_paths):
_seed_agent(tmp_path)
result = update_agent.func(runtime=_runtime())
msg = result.update["messages"][0]
assert "No fields provided" in msg.content
def test_update_agent_rejects_unknown_model(tmp_path, patched_paths, stub_app_config):
"""Copilot review: model must be validated against configured models before
being persisted; otherwise _resolve_model_name silently falls back to the
default and the user gets repeated warnings on every later turn."""
_seed_agent(tmp_path)
result = update_agent.func(runtime=_runtime(), model="not-in-config")
msg = result.update["messages"][0]
assert "Unknown model" in msg.content
cfg = yaml.safe_load((_user_agent_dir(tmp_path) / "config.yaml").read_text())
assert "model" not in cfg, "Invalid model must not have been written to config.yaml"
def test_update_agent_accepts_known_model(tmp_path, patched_paths, stub_app_config):
_seed_agent(tmp_path)
result = update_agent.func(runtime=_runtime(), model="gpt-known")
cfg = yaml.safe_load((_user_agent_dir(tmp_path) / "config.yaml").read_text())
assert cfg["model"] == "gpt-known"
assert "model" in result.update["messages"][0].content
# --- Partial update tests ---
def test_update_agent_updates_soul_only(tmp_path, patched_paths):
agent_dir = _seed_agent(tmp_path, description="keep me", soul="old soul")
result = update_agent.func(runtime=_runtime(), soul="brand new soul")
assert (agent_dir / "SOUL.md").read_text() == "brand new soul"
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
assert cfg["description"] == "keep me", "description must be preserved"
assert "soul" in result.update["messages"][0].content
def test_update_agent_updates_description_only(tmp_path, patched_paths):
agent_dir = _seed_agent(tmp_path, description="old desc", soul="keep this soul")
result = update_agent.func(runtime=_runtime(), description="new desc")
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
assert cfg["description"] == "new desc"
assert (agent_dir / "SOUL.md").read_text() == "keep this soul", "SOUL.md must be preserved"
assert "description" in result.update["messages"][0].content
def test_update_agent_skills_empty_list_disables_all(tmp_path, patched_paths):
agent_dir = _seed_agent(tmp_path, skills=["a", "b"])
result = update_agent.func(runtime=_runtime(), skills=[])
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
assert cfg["skills"] == [], "empty list must persist as empty list (not be omitted)"
assert "skills" in result.update["messages"][0].content
def test_update_agent_skills_omitted_keeps_existing(tmp_path, patched_paths):
agent_dir = _seed_agent(tmp_path, skills=["alpha", "beta"])
update_agent.func(runtime=_runtime(), description="bumped")
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
assert cfg["skills"] == ["alpha", "beta"], "omitting skills must preserve the existing whitelist"
def test_update_agent_no_op_when_values_match_existing(tmp_path, patched_paths):
_seed_agent(tmp_path, description="same")
result = update_agent.func(runtime=_runtime(), description="same")
assert "No changes applied" in result.update["messages"][0].content
def test_update_agent_forces_name_to_directory(tmp_path, patched_paths):
"""Copilot review: if the existing config.yaml has a drifted ``name`` field,
update_agent must rewrite it to match the directory name so on-disk state
stays consistent with the runtime context."""
agent_dir = _user_agent_dir(tmp_path)
agent_dir.mkdir(parents=True)
(agent_dir / "config.yaml").write_text(yaml.safe_dump({"name": "drifted-name", "description": "old"}, sort_keys=False), encoding="utf-8")
(agent_dir / "SOUL.md").write_text("soul", encoding="utf-8")
update_agent.func(runtime=_runtime(), description="bumped")
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
assert cfg["name"] == "test-agent", "config.yaml name must follow the directory name, not legacy yaml content"
# --- Atomicity tests ---
def test_update_agent_failure_preserves_existing_files(tmp_path, patched_paths):
agent_dir = _seed_agent(tmp_path, soul="original soul")
real_replace = Path.replace
def _explode(self, target):
if str(target).endswith("SOUL.md"):
raise OSError("disk full")
return real_replace(self, target)
with patch.object(Path, "replace", _explode):
result = update_agent.func(runtime=_runtime(), soul="poisoned content")
assert (agent_dir / "SOUL.md").read_text() == "original soul", "atomic write must not corrupt existing SOUL.md"
assert "Error" in result.update["messages"][0].content
leftover_tmps = list(agent_dir.glob("*.tmp"))
assert leftover_tmps == [], "temp files must be cleaned up on failure"
def test_update_agent_soul_failure_does_not_replace_config(tmp_path, patched_paths):
"""Copilot review: if both config.yaml and SOUL.md are scheduled to be
written and SOUL.md staging fails *before* any rename, config.yaml must
NOT be replaced. The fix stages every temp file first and only renames
after all temps exist on disk."""
agent_dir = _seed_agent(tmp_path, description="original-desc", soul="original soul")
real_named_temp_file = __import__("tempfile").NamedTemporaryFile
call_count = {"n": 0}
def _explode_on_soul(*args, **kwargs):
# Inspect target dir + suffix; the SOUL temp file is the second one we stage.
call_count["n"] += 1
if call_count["n"] >= 2:
raise OSError("disk full while staging SOUL.md")
return real_named_temp_file(*args, **kwargs)
with patch("deerflow.tools.builtins.update_agent_tool.tempfile.NamedTemporaryFile", side_effect=_explode_on_soul):
result = update_agent.func(runtime=_runtime(), description="new-desc", soul="new soul")
cfg = yaml.safe_load((agent_dir / "config.yaml").read_text())
assert cfg["description"] == "original-desc", "config.yaml must not be replaced when SOUL.md staging fails"
assert (agent_dir / "SOUL.md").read_text() == "original soul"
assert "Error" in result.update["messages"][0].content
assert list(agent_dir.glob("*.tmp")) == [], "staged config.yaml temp must be cleaned up on SOUL.md failure"
# --- Per-user isolation ---
def test_update_agent_only_writes_under_current_user(tmp_path, patched_paths):
"""An update from user 'alice' must never touch user 'bob's agent files."""
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Seed an agent for both users with the same name.
alice_dir = _seed_agent(tmp_path, name="shared", description="alice-desc", soul="alice soul", user_id="alice")
bob_dir = _seed_agent(tmp_path, name="shared", description="bob-desc", soul="bob soul", user_id="bob")
# Override the autouse contextvar so update_agent runs as Alice.
token = set_current_user(SimpleNamespace(id="alice"))
try:
update_agent.func(runtime=_runtime(agent_name="shared"), description="alice-bumped")
finally:
reset_current_user(token)
alice_cfg = yaml.safe_load((alice_dir / "config.yaml").read_text())
bob_cfg = yaml.safe_load((bob_dir / "config.yaml").read_text())
assert alice_cfg["description"] == "alice-bumped"
assert bob_cfg["description"] == "bob-desc", "bob's config.yaml must not have been touched"
assert (bob_dir / "SOUL.md").read_text() == "bob soul"
# --- Loader passthrough sanity check ---
def test_update_agent_round_trips_known_fields(tmp_path, patched_paths):
"""update_agent reads through load_agent_config so all fields the loader
knows about (name, description, model, tool_groups, skills) round-trip
on a partial update.
Note: ``load_agent_config`` strips unknown fields before constructing
AgentConfig, so legacy/extra YAML keys are NOT preserved across
updates by design.
"""
_seed_agent(tmp_path, description="legacy")
fake_cfg = AgentConfig(name="test-agent", description="legacy", skills=["s1"], tool_groups=["g1"], model="m1")
fake_app_config = MagicMock()
fake_app_config.get_model_config.return_value = object()
with patch("deerflow.tools.builtins.update_agent_tool.load_agent_config", return_value=fake_cfg):
with patch("deerflow.tools.builtins.update_agent_tool.get_app_config", return_value=fake_app_config):
update_agent.func(runtime=_runtime(), description="bumped")
cfg = yaml.safe_load((_user_agent_dir(tmp_path) / "config.yaml").read_text())
assert cfg["description"] == "bumped"
assert cfg["skills"] == ["s1"]
assert cfg["tool_groups"] == ["g1"]
assert cfg["model"] == "m1"
+10 -5
View File
@@ -126,15 +126,18 @@ class TestWriteUploadFileNoSymlink:
assert dest.read_bytes() == b"new contents"
assert os.stat(dest).st_nlink == 1
def test_fails_closed_without_no_follow_support(self, tmp_path, monkeypatch):
def test_fallback_without_no_follow_support_succeeds(self, tmp_path, monkeypatch):
monkeypatch.delattr(os, "O_NOFOLLOW", raising=False)
with pytest.raises(UnsafeUploadPathError, match="O_NOFOLLOW"):
write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")
assert not (tmp_path / "notes.txt").exists()
# When O_NOFOLLOW is absent (Windows), the function falls back to
# a dual-lstat + fstat approach and succeeds.
result = write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")
assert result == tmp_path / "notes.txt"
assert (tmp_path / "notes.txt").read_bytes() == b"hello"
def test_open_uses_nonblocking_flag_when_available(self, tmp_path):
if not hasattr(os, "O_NONBLOCK"):
pytest.skip("O_NONBLOCK not available on this platform")
with patch("deerflow.uploads.manager.os.open", side_effect=OSError(errno.ENXIO, "no reader")) as open_mock:
with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")
@@ -144,6 +147,8 @@ class TestWriteUploadFileNoSymlink:
@pytest.mark.parametrize("open_errno", [errno.ENXIO, errno.EAGAIN])
def test_nonblocking_special_file_open_errors_are_unsafe(self, tmp_path, open_errno):
if not hasattr(os, "O_NONBLOCK"):
pytest.skip("O_NONBLOCK not available on this platform")
with patch("deerflow.uploads.manager.os.open", side_effect=OSError(open_errno, "would block")):
with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")
+33
View File
@@ -61,6 +61,39 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
sandbox.update_file.assert_not_called()
def test_upload_files_auto_renames_duplicate_form_filenames(tmp_path):
thread_uploads_dir = tmp_path / "uploads"
thread_uploads_dir.mkdir(parents=True)
provider = MagicMock()
provider.uses_thread_data_mounts = True
with (
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
patch.object(uploads, "get_sandbox_provider", return_value=provider),
):
result = asyncio.run(
call_unwrapped(
uploads.upload_files,
"thread-local",
request=MagicMock(),
files=[
UploadFile(filename="data.txt", file=BytesIO(b"first")),
UploadFile(filename="data.txt", file=BytesIO(b"second")),
],
config=SimpleNamespace(),
)
)
assert result.success is True
assert [file_info["filename"] for file_info in result.files] == ["data.txt", "data_1.txt"]
assert "original_filename" not in result.files[0]
assert result.files[1]["original_filename"] == "data.txt"
assert (thread_uploads_dir / "data.txt").read_bytes() == b"first"
assert (thread_uploads_dir / "data_1.txt").read_bytes() == b"second"
def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
thread_uploads_dir = tmp_path / "uploads"
thread_uploads_dir.mkdir(parents=True)
+10 -10
View File
@@ -788,7 +788,7 @@ requires-dist = [
{ name = "lark-oapi", specifier = ">=1.4.0" },
{ name = "markdown-to-mrkdwn", specifier = ">=0.3.1" },
{ name = "pyjwt", specifier = ">=2.9.0" },
{ name = "python-multipart", specifier = ">=0.0.26" },
{ name = "python-multipart", specifier = ">=0.0.27" },
{ name = "python-telegram-bot", specifier = ">=21.0" },
{ name = "slack-sdk", specifier = ">=3.33.0" },
{ name = "sse-starlette", specifier = ">=2.1.0" },
@@ -1725,7 +1725,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.3.2"
version = "1.3.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jsonpatch" },
@@ -1738,9 +1738,9 @@ dependencies = [
{ name = "typing-extensions" },
{ name = "uuid-utils" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/03/7219502e8ca728d65eb44d7a3eb60239230742a70dbfc9241b9bfd61c4ab/langchain_core-1.3.2.tar.gz", hash = "sha256:fd7a50b2f28ba561fd9d7f5d2760bc9e06cf00cdf820a3ccafe88a94ffa8d5b7", size = 911813, upload-time = "2026-04-24T15:49:23.699Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d3/ae/8b74458fc3850ec3d150eb9f45e857db129dafa801fb5cf173dfc9f8bbf3/langchain_core-1.3.3.tar.gz", hash = "sha256:fa510a5db8efdc0c6ff41c0939fb5c00a0183c11f6b84233e892e3227ff69182", size = 915041, upload-time = "2026-05-05T19:02:36.612Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7d/d5/8fa4431007cbb7cfed7590f4d6a5dea3ad724f4174d248f6642ef5ce7d05/langchain_core-1.3.2-py3-none-any.whl", hash = "sha256:d44a66127f9f8db735bdfd0ab9661bccb47a97113cfd3f2d89c74864422b7274", size = 542390, upload-time = "2026-04-24T15:49:21.991Z" },
{ url = "https://files.pythonhosted.org/packages/1f/01/4771b7ab2af1d1aba5b710bd8f13d9225c609425214b357590a17b01be77/langchain_core-1.3.3-py3-none-any.whl", hash = "sha256:18aae8506f37da7f74398492279a7d6efcee4f8e23c4c41c7af080eeb7ef7bd1", size = 543857, upload-time = "2026-05-05T19:02:34.52Z" },
]
[[package]]
@@ -2145,14 +2145,14 @@ wheels = [
[[package]]
name = "mako"
version = "1.3.11"
version = "1.3.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markupsafe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" }
sdist = { url = "https://files.pythonhosted.org/packages/00/62/791b31e69ae182791ec67f04850f2f062716bbd205483d63a215f3e062d3/mako-1.3.12.tar.gz", hash = "sha256:9f778e93289bd410bb35daadeb4fc66d95a746f0b75777b942088b7fd7af550a", size = 400219, upload-time = "2026-04-28T19:01:08.512Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" },
{ url = "https://files.pythonhosted.org/packages/bc/b1/a0ec7a5a9db730a08daef1fdfb8090435b82465abbf758a596f0ea88727e/mako-1.3.12-py3-none-any.whl", hash = "sha256:8f61569480282dbf557145ce441e4ba888be453c30989f879f0d652e39f53ea9", size = 78521, upload-time = "2026-04-28T19:01:10.393Z" },
]
[[package]]
@@ -3532,11 +3532,11 @@ wheels = [
[[package]]
name = "python-multipart"
version = "0.0.26"
version = "0.0.27"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/88/71/b145a380824a960ebd60e1014256dbb7d2253f2316ff2d73dfd8928ec2c3/python_multipart-0.0.26.tar.gz", hash = "sha256:08fadc45918cd615e26846437f50c5d6d23304da32c341f289a617127b081f17", size = 43501, upload-time = "2026-04-10T14:09:59.473Z" }
sdist = { url = "https://files.pythonhosted.org/packages/69/9b/f23807317a113dc36e74e75eb265a02dd1a4d9082abc3c1064acd22997c4/python_multipart-0.0.27.tar.gz", hash = "sha256:9870a6a8c5a20a5bf4f07c017bd1489006ff8836cff097b6933355ee2b49b602", size = 44043, upload-time = "2026-04-27T10:51:26.649Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9a/22/f1925cdda983ab66fc8ec6ec8014b959262747e58bdca26a4e3d1da29d56/python_multipart-0.0.26-py3-none-any.whl", hash = "sha256:c0b169f8c4484c13b0dcf2ef0ec3a4adb255c4b7d18d8e420477d2b1dd03f185", size = 28847, upload-time = "2026-04-10T14:09:58.131Z" },
{ url = "https://files.pythonhosted.org/packages/99/78/4126abcbdbd3c559d43e0db7f7b9173fc6befe45d39a2856cc0b8ec2a5a6/python_multipart-0.0.27-py3-none-any.whl", hash = "sha256:6fccfad17a27334bd0193681b369f476eda3409f17381a2d65aa7df3f7275645", size = 29254, upload-time = "2026-04-27T10:51:24.997Z" },
]
[[package]]
+49 -5
View File
@@ -15,7 +15,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 8
config_version: 9
# ============================================================================
# Logging
@@ -30,7 +30,7 @@ log_level: info
# When enabled, DeerFlow records input/output/total tokens per model call
# and shows usage metadata in the workspace UI when providers return it.
token_usage:
enabled: false
enabled: true
# ============================================================================
# Models Configuration
@@ -506,6 +506,29 @@ tools:
tool_search:
enabled: false
# ============================================================================
# Loop Detection Configuration
# ============================================================================
# Detect and interrupt repeated identical tool-call loops.
# Frequency thresholds are safety limits for repeated use of the same tool type.
loop_detection:
enabled: true
warn_threshold: 3
hard_limit: 5
window_size: 20
max_tracked_threads: 100
tool_freq_warn: 30
tool_freq_hard_limit: 50
# Per-tool overrides for tool_freq_warn / tool_freq_hard_limit. Values can be
# higher or lower than the global defaults. Commonly used to raise thresholds
# for high-frequency tools like bash in batch workflows (e.g. RNA-seq pipelines)
# without weakening protection on every other tool.
# tool_freq_overrides:
# bash:
# warn: 150
# hard_limit: 300
# ============================================================================
# Sandbox Configuration
# ============================================================================
@@ -578,6 +601,11 @@ sandbox:
# # Optional: Prefix for container names (default: deer-flow-sandbox)
# # container_prefix: deer-flow-sandbox
#
# # Optional: Automatically restart crashed sandbox containers (default: true)
# # When enabled, a dead container is detected on the next tool call and
# # transparently replaced with a fresh one. Set to false to disable.
# # auto_restart: true
#
# # Optional: Additional mount directories from host to container
# # NOTE: Skills directory is automatically mounted from skills.path to skills.container_path
# # mounts:
@@ -848,9 +876,25 @@ skill_evolution:
#
# Postgres mode: put your connection URL in .env as DATABASE_URL,
# then reference it here with $DATABASE_URL.
# Install the driver first:
# Local: uv sync --extra postgres
# Docker: UV_EXTRAS=postgres docker compose build
#
# Install the driver — Issue #2754 fix lands `UV_EXTRAS` in every code path:
# Local `make dev` auto-detects from `database.backend: postgres` below
# and passes `--extra postgres` to `uv sync` on every restart, so
# the extra is no longer wiped. To opt in explicitly (or layer
# extras like `postgres,ollama`), set in project-root .env:
# UV_EXTRAS=postgres
# Docker dev `make docker-start` reads `UV_EXTRAS` from project-root .env via
# `env_file`. Set:
# UV_EXTRAS=postgres
# Multiple extras (`postgres,ollama`) supported here too — see
# docker/dev-entrypoint.sh.
# Docker img build-arg `UV_EXTRAS=postgres docker compose build` — single
# extra only at build time (backend/Dockerfile passes the value
# as one token to `--extra`).
#
# First-time bootstrap (before `make dev`):
# cd backend && uv sync --all-packages --extra postgres
# (--all-packages propagates the extra into workspace members — see PR #2584)
#
# NOTE: When both `checkpointer` and `database` are configured,
# `checkpointer` takes precedence for LangGraph state persistence.
+85
View File
@@ -0,0 +1,85 @@
#!/usr/bin/env sh
#
# DeerFlow gateway dev entrypoint — runs inside the docker-compose-dev gateway
# container. Extracted from docker/docker-compose-dev.yaml's inline `command:`
# (PR #2767, addressing review on Issue #2754).
#
# Responsibilities:
# 1. Resolve `--extra X` flags from UV_EXTRAS (comma- or whitespace-separated,
# mirroring scripts/detect_uv_extras.py for parity with local `make dev`).
# 2. Validate each extra against [A-Za-z][A-Za-z0-9_-]* so a stray shell
# metacharacter in `.env` cannot reach `uv sync`.
# 3. `uv sync --all-packages` so workspace member extras (deerflow-harness's
# postgres extra in particular) are installed — see PR #2584.
# 4. Self-heal: if the first sync fails, recreate .venv and retry once.
# 5. Hand off to uvicorn with reload, replacing this shell so uvicorn becomes
# PID 1 inside the container.
#
# Anchored at /bin/sh (not bash) since alpine-based base images may not ship
# bash. Uses POSIX-only constructs throughout.
set -e
# `--print-extras` is a dry-run hook: parse + validate UV_EXTRAS, print the
# resulting `--extra X` flags to stdout, and exit. Used by the unit test in
# backend/tests/test_dev_entrypoint.py and useful for ad-hoc debugging.
PRINT_EXTRAS_ONLY=0
if [ "${1:-}" = "--print-extras" ]; then
PRINT_EXTRAS_ONLY=1
fi
# Mirror the legacy command's behavior: redirect both stdout and stderr to the
# host-mounted log file (../logs/gateway.log → /app/logs/gateway.log). Skip
# the redirect under --print-extras so the test runner can capture stdout.
if [ "$PRINT_EXTRAS_ONLY" = "0" ]; then
exec >/app/logs/gateway.log 2>&1
fi
# ── Resolve extras ──────────────────────────────────────────────────────────
EXTRAS_FLAGS=""
if [ -n "${UV_EXTRAS:-}" ]; then
# Normalize comma → space, then split on whitespace via the unquoted `for`.
for raw in $(printf '%s' "$UV_EXTRAS" | tr ',' ' '); do
[ -z "$raw" ] && continue
# Reject anything that does not look like an identifier.
# Two patterns: leading non-letter, or any non-[A-Za-z0-9_-] character.
case "$raw" in
[!A-Za-z]* | *[!A-Za-z0-9_-]*)
echo "[startup] UV_EXTRAS entry '$raw' is invalid (must match [A-Za-z][A-Za-z0-9_-]*) — aborting" >&2
exit 1
;;
esac
EXTRAS_FLAGS="$EXTRAS_FLAGS --extra $raw"
done
fi
if [ "$PRINT_EXTRAS_ONLY" = "1" ]; then
# Trim leading space for tidier output, then exit.
printf '%s\n' "${EXTRAS_FLAGS# }"
exit 0
fi
if [ -n "$EXTRAS_FLAGS" ]; then
echo "[startup] uv extras:$EXTRAS_FLAGS"
fi
# ── Sync dependencies (with self-heal) ──────────────────────────────────────
cd /app/backend
# `--all-packages` propagates extras into workspace members (PR #2584).
# `$EXTRAS_FLAGS` intentionally unquoted so each `--extra X` becomes its own arg.
# shellcheck disable=SC2086 # word-splitting is intentional here
if ! uv sync --all-packages $EXTRAS_FLAGS; then
echo "[startup] uv sync failed; recreating .venv and retrying once"
uv venv --allow-existing .venv
# shellcheck disable=SC2086
uv sync --all-packages $EXTRAS_FLAGS
fi
# ── Hand off to uvicorn ─────────────────────────────────────────────────────
PYTHONPATH=. exec uv run uvicorn app.gateway.app:app \
--host 0.0.0.0 --port 8001 \
--reload --reload-include='*.yaml .env'

Some files were not shown because too many files have changed in this diff Show More