Compare commits

34 Commits

Author SHA1 Message Date
greatmengqi 2eb45e9bb5 fix: thread app config through client and sync providers 2026-05-02 12:07:26 +08:00
greatmengqi 8ba01dfd83 refactor: thread app_config through lead and subagent task path (#2666)
* refactor: thread app config through lead prompt

* fix: honor explicit app config across runtime paths

* style: format subagent executor tests

* fix: thread resolved app config and guard subagents-only fallback

Address two PR review findings:

1. _create_summarization_middleware passed the original (possibly None)
   app_config into create_chat_model, forcing the model factory back to
   ambient get_app_config() and risking config drift between the
   middleware's resolved view and the model's view. Pass the resolved
   AppConfig instance through end-to-end.

2. get_available_subagent_names accepted Any-typed config and forwarded
   it to is_host_bash_allowed, which reads ``.sandbox``. A
   SubagentsAppConfig (also accepted upstream as a sum-type input) has
   no ``.sandbox`` attribute and would be silently treated as "no
   sandbox configured", incorrectly disabling the bash subagent. Guard
   on hasattr and fall back to ambient lookup otherwise.

Adds regression tests for both paths.

* chore: simplify hasattr guard and tighten regression tests

- Collapse if/else into ternary in get_available_subagent_names; hasattr(None, ...) is False so the explicit None check was redundant.
- Drop comments that narrate the change rather than explain non-obvious WHY (test names already convey intent).
- Replace stringly-typed sentinel "no-arg" in regression test with direct args tuple comparison.
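The collapsed guard can be sketched as follows. This is a minimal illustration, not the project's code: the `AppConfig`, `SubagentsAppConfig`, and `SandboxConfig` classes are stand-ins, and `get_ambient_config` and `pick_sandbox_source` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SandboxConfig:
    use: str


@dataclass
class AppConfig:  # stand-in: exposes .sandbox
    sandbox: SandboxConfig


@dataclass
class SubagentsAppConfig:  # stand-in: deliberately has no .sandbox attribute
    max_concurrent: int = 1


def get_ambient_config() -> AppConfig:
    """Hypothetical ambient lookup used as the fallback."""
    return AppConfig(sandbox=SandboxConfig(use="ambient"))


def pick_sandbox_source(config: Optional[Any]) -> Any:
    # hasattr(None, "sandbox") is False, so None needs no separate branch.
    return config if hasattr(config, "sandbox") else get_ambient_config()
```

Both `None` and a config object without `.sandbox` take the fallback branch, so neither is mistaken for "no sandbox configured".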

---------

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
2026-05-02 06:37:49 +08:00
Willem Jiang 189b82405c fix(sandbox): pass no_change_timeout to exec_command to prevent 120s premature termination (#2685)
* fix(sandbox): pass no_change_timeout to exec_command to prevent 120s premature termination

  The agent_sandbox library's shell API defaults no_change_timeout to 120
  seconds. When AioSandbox.execute_command() called exec_command() without
  this parameter, commands producing no output for 120s would return with
  NO_CHANGE_TIMEOUT status even though the script was still running.

  Pass no_change_timeout=600 to all exec_command calls (matching the
  client-level HTTP timeout) so long-running commands are not cut short.

  Fixes #2668

* test(sandbox): add assertions for no_change_timeout in execute_command and list_dir

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/2f37bc72-0826-4443-a6ba-e5b78c22fb5a

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-05-01 22:27:02 +08:00
Nan Gao 487c1d939f fix(subagents): use model override for tools and middleware (#2641)
* fix(subagents): use model override for tools and middleware

* fix(config): resolve effective subagent model

* fix(subagents): defer app config loading

* fix(subagents): fully defer config.yaml load in executor __init__

The previous attempt only relocated the explicit get_app_config() call,
but left resolve_subagent_model_name(...) running eagerly in __init__.
That helper has its own internal get_app_config() fallback, which still
fired when both app_config and parent_model were None and
config.model == "inherit" — exactly the path unit tests hit, breaking
21 tests in CI with FileNotFoundError: config.yaml.

Skip the eager resolve in __init__ when it would require loading the
config file, and defer to _create_agent (which already has the
app_config or get_app_config() fallback).
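
The deferral described above might look roughly like this. It is a sketch under assumptions: `LazyModelResolver` and its `load_default_model` callback are hypothetical stand-ins for the executor and its config lookup.

```python
from typing import Callable, Optional


class LazyModelResolver:
    """Hypothetical stand-in for an executor that defers config loading."""

    def __init__(
        self,
        model: str,
        parent_model: Optional[str],
        load_default_model: Callable[[], str],
    ) -> None:
        self._load_default_model = load_default_model  # may read a config file
        self._resolved: Optional[str] = None
        # Resolve eagerly only when no config file access is required.
        if model != "inherit":
            self._resolved = model
        elif parent_model is not None:
            self._resolved = parent_model

    def resolve(self) -> str:
        # Called at agent-build time; the loader runs at most once.
        if self._resolved is None:
            self._resolved = self._load_default_model()
        return self._resolved
```

Constructing the object never touches the loader, so a unit test without a config file only fails if it actually builds an agent on the "inherit with no parent" path.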
2026-05-01 22:21:10 +08:00
Nan Gao c09c334544 fix(harness): resolve runtime paths from project root (#2642)
* fix(harness): resolve runtime paths from project root

* docs(config): update

* fix(config): address runtime path review feedback

* test(config): fix skills path e2e root

* test(config): cover legacy config fallback when project root lacks config files

Verifies that when DEER_FLOW_PROJECT_ROOT is unset and cwd has no
config.yaml/extensions_config.json, AppConfig and ExtensionsConfig fall back
to the legacy backend/repo-root candidates — the backward-compat path
requested in PR #2642 review.
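
A candidate search of this kind can be sketched as below. The function name, its parameters, and the rule that an explicit `DEER_FLOW_PROJECT_ROOT` is searched exclusively are assumptions made for illustration; only the environment variable name and the cwd-then-legacy fallback come from the commit message.

```python
from pathlib import Path
from typing import Mapping, Optional, Sequence


def find_config_file(
    env: Mapping[str, str],
    cwd: Path,
    legacy_dirs: Sequence[Path],
    filename: str = "config.yaml",
) -> Optional[Path]:
    """Return the first existing candidate, or None if nothing matches."""
    root = env.get("DEER_FLOW_PROJECT_ROOT")
    # Assumption: an explicit root is authoritative; otherwise try cwd,
    # then the legacy directories in order.
    search_dirs = [Path(root)] if root else [cwd, *legacy_dirs]
    for directory in search_dirs:
        candidate = directory / filename
        if candidate.is_file():
            return candidate
    return None
```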

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-01 22:19:50 +08:00
KiteEater 8939ccaed2 fix(uploads): enforce streaming upload limits in gateway (#2589)
* fix: enforce gateway upload limits

* fix: acquire sandbox before upload writes

* Fix upload limit config wiring

* Sanitize upload size error filenames

* test: call upload routes unwrapped

* fix: guard upload limits endpoint

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-01 20:19:30 +08:00
JerryLee 83938cf35a fix(subagents): propagate user context across threaded execution (#2676) 2026-05-01 16:27:18 +08:00
yangzheli 78633c69ac fix(agents): propagate agent_name into ToolRuntime.context for setup_agent (#2679)
* fix(agents): propagate agent_name into ToolRuntime.context for setup_agent (#2677)

When creating a custom agent via the web UI, SOUL.md was always written
to the global base_dir/SOUL.md instead of agents/<name>/SOUL.md.

Root cause: the bootstrap flow sends agent_name via body.context, but
two layers were broken:

1. services.py only forwarded body.context keys into config["configurable"];
   config["context"] was never populated.
2. worker.py constructed the parent Runtime with a hard-coded
   {thread_id, run_id} context, ignoring config["context"] entirely.

After the langgraph >= 1.1.9 bump (#98a5b34f), ToolRuntime.context no
longer falls back to configurable, so setup_agent's
runtime.context.get("agent_name") returned None and the tool's silent
agent_name=None -> base_dir fallback kicked in, overwriting the global
SOUL.md.

Fix:
- services.py: extract merge_run_context_overrides() and write the
  whitelisted context keys into both configurable (legacy readers) and
  context (langgraph 1.1+ ToolRuntime consumers).
- worker.py: extract _build_runtime_context() and merge config["context"]
  into the Runtime's context (without letting callers override
  thread_id/run_id).

The base_dir fallback in setup_agent_tool.py is left in place because
the IM /bootstrap channel command depends on it. That code path can
be tightened in a follow-up.

Adds regression tests covering both helpers.
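
The two-sided merge can be sketched as follows. The function names, the whitelist contents, and the dictionary shapes are illustrative assumptions, not the real `merge_run_context_overrides()` or `_build_runtime_context()` implementations.

```python
from typing import Any, Dict, Mapping

ALLOWED_CONTEXT_KEYS = frozenset({"agent_name"})  # illustrative whitelist
RESERVED_KEYS = frozenset({"thread_id", "run_id"})


def merge_context_overrides(
    config: Dict[str, Any], body_context: Mapping[str, Any]
) -> Dict[str, Any]:
    """Copy whitelisted keys into both lookup locations."""
    allowed = {k: v for k, v in body_context.items() if k in ALLOWED_CONTEXT_KEYS}
    config.setdefault("configurable", {}).update(allowed)  # legacy readers
    config.setdefault("context", {}).update(allowed)  # runtime-context readers
    return config


def build_runtime_context(
    thread_id: str, run_id: str, config: Mapping[str, Any]
) -> Dict[str, Any]:
    """Merge config["context"] without letting callers override reserved ids."""
    extra = {
        k: v
        for k, v in (config.get("context") or {}).items()
        if k not in RESERVED_KEYS
    }
    return {**extra, "thread_id": thread_id, "run_id": run_id}
```

Writing the reserved ids last means a caller-supplied `thread_id` or `run_id` can never win, even if one slips into the context dictionary.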

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-05-01 16:00:11 +08:00
greatmengqi 8b61c94e1d fix: keep lead agent graph factory signature compatible (#2678)
Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
2026-05-01 15:43:28 +08:00
Xun 1ad1420e31 refactor(skills): Unified skill storage capability (#2613) 2026-05-01 13:23:26 +08:00
He Wang eba3b9e18d fix(config): unify log_level from config.yaml across Gateway and debug entry points (#2601)
Centralize log level parsing in `logging_level_from_config()` and
application in `apply_logging_level()` within `deerflow.config.app_config`.

- Gateway lifespan applies configured log level on startup
- `debug.py` uses shared helpers instead of local duplicates
- `apply_logging_level()` targets only `deerflow`/`app` logger hierarchies
  so third-party library verbosity is not affected; root handler levels
  are only lowered (never raised) to allow configured loggers through
  without suppressing third-party output; root logger level is not modified
- Config field description updated to clarify scope
- Tests save/restore global logging state to avoid test pollution
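
The scoping rule for the log level can be sketched with the standard `logging` module. `apply_level_sketch` is a hypothetical name and is not the project's `apply_logging_level()`; it only illustrates the "set the named hierarchies, lower root handlers, leave the root logger alone" behaviour described above.

```python
import logging

TARGET_LOGGERS = ("deerflow", "app")


def apply_level_sketch(level: int) -> None:
    # Only the application's own logger hierarchies get the configured level.
    for name in TARGET_LOGGERS:
        logging.getLogger(name).setLevel(level)
    # Root handlers are lowered when they would filter out the configured
    # level, and never raised. The root logger's own level is untouched.
    for handler in logging.getLogger().handlers:
        if handler.level > level:
            handler.setLevel(level)
```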

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-30 22:27:14 +08:00
Willem Jiang c0da278269 fix(memory): replace short-lived asyncio.run() with persistent event loop (#2627)
* fix(memory): replace short-lived asyncio.run() with persistent event loop to prevent zombie httpx connections

  The memory updater used asyncio.run() inside daemon threads, creating
  and destroying short-lived event loops on every update. Langchain
  providers (e.g. langchain-anthropic) cache httpx AsyncClient instances
  globally via @lru_cache, so SSL connections created on a loop that is
  subsequently destroyed become zombie connections in the shared pool.
  When the main agent's lead run later reuses one of these connections,
  httpx/anyio triggers RuntimeError: Event loop is closed during
  connection cleanup.

  Replace the ThreadPoolExecutor + asyncio.run() pattern with a
  _MemoryLoopRunner that maintains a single persistent event loop in a
  daemon thread for the process lifetime. Since the loop never closes,
  connections bound to it never become invalid. The _run_async_update_sync
  function now submits coroutines to this persistent loop via
  run_coroutine_threadsafe instead of creating throwaway loops.
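
  The persistent-loop pattern can be sketched with the standard library alone. `LoopRunner` and `current_loop_id` are illustrative names, not the project's `_MemoryLoopRunner`; shutdown handling is omitted.

```python
import asyncio
import threading
from typing import Any, Coroutine, Optional


class LoopRunner:
    """Illustrative runner: one event loop, kept alive in a daemon thread."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def run(self, coro: Coroutine[Any, Any, Any], timeout: Optional[float] = None) -> Any:
        # Submit from any thread and block until the coroutine finishes.
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result(timeout)


async def current_loop_id() -> int:
    return id(asyncio.get_running_loop())
```

  Every submitted coroutine runs on the same loop object, so anything bound to that loop (for example a cached async HTTP client) is never left pointing at a closed loop.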

* update the code to address the review comments

* Fix the review comments of 2615

- P1 (user_id forwarded through the sync path): Added a user_id parameter to _prepare_update_prompt, _finalize_update, and _do_update_memory_sync, and forwarded it to get_memory_data(agent_name, user_id=user_id) and save(..., user_id=user_id). The update_memory() entry point now passes user_id through both the executor.submit path and the direct call path. Added TestUserIdForwarding with two regression tests (sync + async) verifying that get_memory_data and save receive the correct user_id.

- P2 (aupdate_memory() delegates to sync): Replaced the model.ainvoke() call with asyncio.to_thread(self._do_update_memory_sync, ...). This eliminates the unsafe async provider client path entirely: all memory updater entry points now use the isolated sync model.invoke() path. Updated the test from asserting that ainvoke is awaited to asserting that invoke is called and ainvoke is not.

- Nit (duplicate comment removed): Removed the duplicated "# Matches sentences..." comment on line 230.

* Chore(test): update the code of test_memory_updater

---------

Co-authored-by: rayhpeng <rayhpeng@gmail.com>
2026-04-30 17:59:57 +08:00
KiteEater 7dea1666ce fix: avoid temporary event loops in async subagent execution (#2414)
* fix: avoid temporary event loops in async subagent execution

* Rename isolated subagent loop globals

* Harden isolated subagent loop shutdown and logging

* Sort subagent executor imports

* Format subagent executor

* Remove isolated loop pool from subagent executor

* Format subagent executor cleanup

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-30 15:29:17 +08:00
Chincherry93 88d47f677f fix(nginx): add catch-all /api/ location for auth routes (#2657)
The recent refactor (7bf618de) removed the langgraph upstream and added
individual /api/* prefix locations for models, memory, mcp, skills,
agents, threads, and sandboxes. However, /api/v1/auth/* routes (login,
register, setup-status, etc.) were not covered by any explicit location
block, causing them to fall through to the frontend catch-all.

In Docker production builds, Next.js rewrites are baked at build time
with http://127.0.0.1:8001 (the gateway is unreachable from the
frontend container's localhost), resulting in ECONNREFUSED errors for
all auth operations.

Adding a catch-all `location /api/` block after the more specific
prefix/regex locations ensures all gateway API routes are properly
proxied. nginx's location matching priority means the more specific
locations above still take precedence for /api/langgraph/, /api/models,
/api/memory, /api/mcp, /api/skills, /api/agents, /api/threads/*, and
/api/sandboxes.

Co-authored-by: Chincherry93 <Chincherry93@users.noreply.github.com>
2026-04-30 15:21:22 +08:00
greatmengqi 38714b6ceb refactor: thread app_config through middleware factories (#2652)
* refactor: thread app_config through middleware factories

Continues the incremental config-refactor sequence (#2611 root, #2612 lead
path) one layer deeper into the middleware factories. Two ambient lookups
inside _build_runtime_middlewares are eliminated and the LLMErrorHandling
band-aid removed:

- _build_runtime_middlewares / build_lead_runtime_middlewares /
  build_subagent_runtime_middlewares now require app_config: AppConfig.
- get_guardrails_config() inside the factory is replaced with
  app_config.guardrails (semantically identical — same default-factory
  GuardrailsConfig — verified by direct equality check).
- LLMErrorHandlingMiddleware.__init__ now requires app_config and reads
  circuit_breaker fields directly. The class-level
  circuit_failure_threshold / circuit_recovery_timeout_sec defaults are
  removed along with the try/except (FileNotFoundError, RuntimeError):
  pass band-aid — the let-it-crash invariant the rest of the refactor
  enforces.

Caller chain (already-resolved app_config sources):
- _build_middlewares in lead_agent/agent.py: reorder so
  resolved_app_config = app_config or get_app_config() is computed BEFORE
  build_lead_runtime_middlewares is called, then passed as kwarg.
- SubagentExecutor: optional app_config parameter (mirrors the lead-agent
  pattern); _create_agent does the same `or get_app_config()` fallback at
  agent-build time, so task_tool callers don't need to plumb app_config
  through yet (typed-context plumbing for tool runtimes is a separate
  refactor).
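
The "resolve once, then pass explicitly" shape of the caller chain can be sketched as follows. All classes and functions here are simplified stand-ins with hypothetical names; the real factories return middleware objects rather than strings.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(frozen=True)
class Guardrails:
    enabled: bool = False


@dataclass(frozen=True)
class Config:
    guardrails: Guardrails = field(default_factory=Guardrails)


_AMBIENT = Config()


def get_ambient_config() -> Config:
    """Hypothetical ambient lookup, used only at the outermost caller."""
    return _AMBIENT


def build_middlewares(app_config: Config) -> List[str]:
    # app_config is required: the factory performs no ambient lookup itself.
    return ["guardrails"] if app_config.guardrails.enabled else []


def build_agent(app_config: Optional[Config] = None) -> List[str]:
    # Resolve exactly once, before any factory is called, then pass it down.
    resolved = app_config or get_ambient_config()
    return build_middlewares(resolved)
```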

Tests:
- test_llm_error_handling_middleware: _make_app_config helper using
  AppConfig(sandbox=SandboxConfig(use="test")) — same minimal-config
  pattern conftest already uses. Three direct LLMErrorHandlingMiddleware()
  calls each followed by post-construction circuit_breaker mutation fold
  cleanly into _build_middleware(circuit_failure_threshold=...,
  circuit_recovery_timeout_sec=...).

Verification:
- tests/test_llm_error_handling_middleware.py — 14 passed
- tests/test_subagent_executor.py — 28 passed
- tests/test_tool_error_handling_middleware.py — 6 passed
- tests/test_task_tool_core_logic.py — 18 passed (verifies task_tool
  unchanged behavior)
- Full suite: 2697 passed, 3 skipped. The single intermittent failure in
  tests/test_client_e2e.py::test_tool_call_produces_events is pre-existing
  LLM flakiness (the test asserts the model decided to call a tool;
  reproduces 1/3 on unchanged main as well).

* fix: address middleware app config review comments

* fix: satisfy app config annotation lint

* test: cover explicit app config middleware wiring

---------

Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
2026-04-30 12:41:09 +08:00
Hinotobi 74081a85a6 [security] fix(sandbox): bind local Docker ports to loopback (#2633)
* fix(sandbox): bind local Docker ports to loopback

* fix(sandbox): preserve IPv6 loopback Docker binds

* fix(sandbox): log Docker bind host selection
2026-04-30 11:40:28 +08:00
Jsonz 24a5a00679 fix: avoid duplicate call to extractReasoningContentFromMessage (#2661)
In convertToSteps(), the extractReasoningContentFromMessage function was
called twice for the same message - once to check if reasoning exists and
again to assign it to the step object. Reuse the already-extracted value
from the local variable instead.
2026-04-30 11:33:49 +08:00
He Wang 08afdcb907 feat(channels): add DingTalk channel integration (#2628)
* feat(channels): add DingTalk channel integration

Add a new DingTalk messaging channel using the dingtalk-stream SDK
with Stream Push (WebSocket), requiring no public IP. Supports both
plain sampleMarkdown replies and optional AI Card streaming for a
typewriter effect when card_template_id is configured.

- Add DingTalkChannel implementation with token management, message
  routing, allowed_users filtering, and markdown adaptation
- Register dingtalk in channel service registry and capability map
- Propagate inbound metadata to outbound messages in ChannelManager
  for DingTalk sender context (sender_staff_id, conversation_type)
- Add dingtalk-stream dependency to pyproject.toml
- Add configuration examples in config.example.yaml and .env.example
- Update all README translations with setup instructions
- Add comprehensive test suite (test_dingtalk_channel.py) and
  metadata propagation test in test_channels.py
- Update backend CLAUDE.md to document DingTalk channel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(channels): address PR review feedback for DingTalk integration

- Replace runtime mutation of CHANNEL_CAPABILITIES with a
  `supports_streaming` property on the Channel base class, overridden
  by DingTalkChannel, FeishuChannel, and WeComChannel
- Store stream client reference and attempt graceful disconnect in
  stop(); guard _on_chatbot_message with _running check to prevent
  post-stop message processing
- Use msg.chat_id as the primary routing key in send/send_file via
  a shared _resolve_routing helper, with metadata as fallback
- Fix process() return type annotation from tuple[str, str] to
  tuple[int, str] to match AckMessage.STATUS_OK
- Protect _incoming_messages with threading.Lock for cross-thread
  safety between the Stream Push thread and the asyncio loop
- Re-add Docker Compose URL guidance removed during DingTalk setup
  docs addition in README.md
- Fix incomplete sentence in README_zh.md (missing verb "启用")
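
The capability property from the first bullet can be sketched like this. The class bodies are stand-ins (the real channels carry far more state), and `StreamingChannelExample` is a hypothetical subclass name.

```python
class Channel:
    """Stand-in base class: channels do not stream unless they say so."""

    @property
    def supports_streaming(self) -> bool:
        return False


class StreamingChannelExample(Channel):
    """Hypothetical subclass that opts in by overriding the property."""

    @property
    def supports_streaming(self) -> bool:
        return True
```

Because the answer lives on the instance, a caller can ask the registered channel directly instead of consulting a shared capability map that something mutated at runtime.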

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(docs): restore plain paragraph format for Docker Compose note

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(channels): fix isinstance TypeError and add file size guard in DingTalk channel

Use tuple syntax for isinstance() type check to avoid runtime TypeError
with PEP 604 union types. Add upload size limit (20MB) before reading
files into memory. Narrow exception handlers to specific types.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(channels): propagate markdown fallback errors and validate access token response

- Re-raise exceptions in _send_markdown_fallback to prevent partial
  deliveries (files sent without accompanying text)
- Validate _get_access_token response: reject non-dict bodies, empty
  tokens, and coerce invalid expireIn to a safe default
- Add tests for both fixes
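
The token-response checks in the second bullet can be sketched as below. The function name, the default lifetime, and the `accessToken` key are assumptions for illustration; only `expireIn` is named in the commit message.

```python
from typing import Any, Tuple

DEFAULT_EXPIRES_IN = 7200  # illustrative fallback, in seconds


def parse_token_response(body: Any) -> Tuple[str, int]:
    """Validate a decoded token response and return (token, lifetime)."""
    if not isinstance(body, dict):
        raise ValueError("token response is not a JSON object")
    token = body.get("accessToken")  # assumed field name
    if not isinstance(token, str) or not token.strip():
        raise ValueError("token response has no usable access token")
    expires = body.get("expireIn")
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(expires, bool) or not isinstance(expires, int) or expires <= 0:
        expires = DEFAULT_EXPIRES_IN
    return token, expires
```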

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(channels): validate upload response and broaden send_file exception handling

- Validate _upload_media JSON response: handle JSONDecodeError and
  non-dict payloads gracefully by returning None
- Broaden send_file exception tuple to include TypeError and
  AttributeError for unexpected JSON shapes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(channels): fix streaming race on channel registration and slim outbound metadata

- Register channel in service before calling start() to avoid race
  where background receiver publishes inbound before registration,
  causing manager to fall back to static CHANNEL_CAPABILITIES
- Strip known-large metadata keys (raw_message, ref_msg) from outbound
  messages to prevent memory bloat from propagated inbound payloads

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update service.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update CLAUDE.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-30 11:25:33 +08:00
sunsine 0691c4dda3 fix(security): allow disabling API docs in production via GATEWAY_ENABLE_DOCS (#2651)
* fix(security): allow disabling API docs in production via GATEWAY_ENABLE_DOCS

Expose /docs, /redoc, and /openapi.json only when GATEWAY_ENABLE_DOCS=true
(default). Setting GATEWAY_ENABLE_DOCS=false disables all three endpoints,
preventing unauthorized API surface discovery in production deployments.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test(security): add unit tests and docs for GATEWAY_ENABLE_DOCS

Add 7 tests covering default behavior, env var parsing (case-insensitive,
fail-closed), endpoint visibility, and health endpoint independence.
Update CONFIGURATION.md and CLAUDE.md with the new toggle.
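
The parsing rule described above (enabled by default, case-insensitive, fail-closed) can be sketched as follows. `docs_enabled` is a hypothetical helper name; how the gateway wires the result into its routes is not shown.

```python
from typing import Mapping


def docs_enabled(env: Mapping[str, str]) -> bool:
    """Return True only when the docs endpoints should be exposed."""
    raw = env.get("GATEWAY_ENABLE_DOCS")
    if raw is None:
        return True  # unset: docs stay enabled by default
    # Fail closed: anything other than "true" (any casing) disables docs.
    return raw.strip().lower() == "true"
```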

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style(security): apply ruff formatting to gateway app.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-30 10:58:32 +08:00
yangzheli f7b10d42e4 fix(frontend): create thread on first submit in new-agent page (#2656)
The new-agent page pre-generates a thread UUID and passed it directly
to useThreadStream, which made the LangGraph SDK POST to
/threads/{uuid}/runs/stream against a thread the backend had never
created. After PR #2566 introduced multi-tenant owner checks on the
runs endpoints, that request now 404s with "Thread not found".

Pass threadId: undefined to useThreadStream so the SDK takes the
create-then-run path. The pre-generated UUID is still forwarded via
SubmitOptions.threadId in sendMessage, so the new thread is created
with that exact id and onCreated rebinds the hook to it.

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-30 06:41:54 +08:00
Willem Jiang 4a9f1d547b Merge pull request #2566 from bytedance/release/2.0-rc
introduces a complete authentication/authorization system, SQL persistence layer, run event history, user data isolation, and extensive documentation in both English and Chinese. The core additions are:

Auth system: JWT-based auth with local email/password provider, CSRF protection, rate limiting
Persistence layer: SQLAlchemy 2.0 async ORM with SQLite backend (users, threads, runs, events, feedback)
User isolation: Per-user data scoping via contextvars sentinel pattern
Frontend: Login/setup pages, AuthProvider, CSRF-aware fetcher
Documentation: Comprehensive EN/ZH docs for harness, application, tutorials
2026-04-28 21:56:35 +08:00
Willem Jiang 11afd32459 Fix the log injection error in skills.py 2026-04-28 21:42:38 +08:00
Willem Jiang 64f4dc1639 fixed the CI build errors 2026-04-28 19:01:36 +08:00
Willem Jiang 844ad8e528 Merge branch 'main' into release/2.0-rc 2026-04-28 15:44:02 +08:00
pyp0327 395c14357b chore(adaptor): Adapt MindIE engine model and improve testing and fixes (#2523)
* feat(models): adapt models for the MindIE engine

* test: add unit tests for MindIEChatModel adapter and fix PR review comments

* chore: update uv.lock with pytest-asyncio

* build: add pytest-asyncio to test dependencies

* fix: address PR review comments (lazy import, cache clients, safe newline escape, strict xml regex)

* fix(mindie): preserve string args without JSON quotes in XML tool call serialization

* fix(mindie): preserve string args without JSON quotes in XML tool call serialization

* test_mindie_provider:format

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix(mindie): prevent nested tool_call params from leaking into outer args

* fixed by escaping XML entities in _fix_messages and unescaping during parse, with regression tests added.

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-28 15:09:31 +08:00
greatmengqi e82940c03d refactor: thread release config through lead path (#2612)
Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
2026-04-28 14:53:18 +08:00
DanielWalnut 6bd88fe14c fix(sandbox): block host bash traversal escapes (#2560)
* fix(sandbox): block host bash traversal escapes

Fixes #2535

* fix(sandbox): harden local bash path guards

* fix(sandbox): avoid bash cd argument false positives

* Fix the lint error

Add function to resolve and validate user data path.
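
A resolve-then-check helper of this kind can be sketched as below. `resolve_inside` is a hypothetical name and the real guard may apply further rules; the sketch only shows the containment check after symlinks and `..` segments are resolved.

```python
from pathlib import Path


def resolve_inside(base: Path, user_path: str) -> Path:
    """Resolve user_path under base, rejecting anything that escapes it."""
    base_resolved = base.resolve()
    # resolve() collapses ".." segments and follows symlinks; an absolute
    # user_path replaces the base entirely and is caught by the check below.
    candidate = (base_resolved / user_path).resolve()
    if candidate != base_resolved and base_resolved not in candidate.parents:
        raise ValueError(f"path escapes base directory: {user_path}")
    return candidate
```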

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-28 12:18:41 +08:00
DanielWalnut 39c5da94f3 fix(sandbox): prevent local custom mount symlink escapes (#2558)
* fix(sandbox): prevent local custom mount symlink escapes

Fixes #2506

* fix(sandbox): harden custom mount symlink handling

* fix(sandbox): format internal symlink directory listings
2026-04-28 11:59:46 +08:00
DanielWalnut 707ed328dd fix(skills): scan skill archives before install (#2561)
* fix(skills): scan skill archives before install

Fixes #2536

* fix(skills): scan archive support files before install

* style(skills): format archive installer

* fix(skills): address archive install review comments
2026-04-28 11:56:11 +08:00
DanielWalnut f7dfb88a30 fix(aio-sandbox): redact env values in container logs (#2562)
* fix(aio-sandbox): redact env values in container logs

Fixes #2534

* fix(aio-sandbox): address env log review comments
2026-04-28 11:47:56 +08:00
Willem Jiang 69649d8aae Fix the issues when reviewing 2566 persistent part (#2604)
* Fix the code review comments on journal & event store P0, P1 issues

* Fix the code review comments on journal & event store P2 issues

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update backend/packages/harness/deerflow/runtime/journal.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Refactor logger debug message formatting

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-28 11:44:40 +08:00
Willem Jiang 4e4e4f92a0 fix(security): harden auth system and fix run journal logic bug (#2593)
* fix(security): harden auth system and fix run journal logic bug

  - Fix inverted condition in RunJournal.on_chat_model_start that prevented
    first human message capture (not messages → messages)
  - Pre-hash passwords with SHA-256 before bcrypt to avoid silent 72-byte
    truncation vulnerability
  - Move load_dotenv() from module scope into get_auth_config() to prevent
    import-time os.environ mutation breaking test isolation
  - Return generic ‘Invalid token’ instead of exposing specific error
    variants (expired, malformed, invalid_signature) to clients
  - Make @require_auth independently enforce 401 instead of silently
    passing through when AuthMiddleware is absent
  - Rate-limit /setup-status endpoint with per-IP cooldown to mitigate
    initialization-state information leak
  - Document in-process rate limiter limitation for multi-worker deployments

* fix(security): return 429+Retry-After on setup-status rate limit, bound cooldown dict
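
A per-key cooldown with a bounded table can be sketched as follows. The class, its parameters, and the oldest-first eviction policy are assumptions for illustration; the returned number is what a handler could place in a `Retry-After` header alongside a 429 response.

```python
import math
import time
from collections import OrderedDict
from typing import Callable, Optional


class CooldownLimiter:
    """Illustrative in-process limiter: one request per key per cooldown."""

    def __init__(
        self,
        cooldown: float,
        max_entries: int,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._cooldown = cooldown
        self._max_entries = max_entries
        self._clock = clock
        self._last_seen: "OrderedDict[str, float]" = OrderedDict()

    def retry_after(self, key: str) -> Optional[int]:
        """Return None if allowed, else whole seconds the caller must wait."""
        now = self._clock()
        last = self._last_seen.get(key)
        if last is not None and now - last < self._cooldown:
            return max(1, math.ceil(self._cooldown - (now - last)))
        self._last_seen[key] = now
        self._last_seen.move_to_end(key)
        # Bound memory: evict the least recently admitted keys first.
        while len(self._last_seen) > self._max_entries:
            self._last_seen.popitem(last=False)
        return None
```

As the original commit notes for its own limiter, state like this lives in one process, so each worker in a multi-worker deployment would keep a separate table.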

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/070d0be8-99a5-46c8-85bb-6b81b5284021

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* fix(security): add versioned password hashes with auto-migration on login

  The SHA-256 pre-hash change silently broke verification for any existing
  bcrypt-only password hashes. Introduce a <N>$ prefix scheme so hashes
  are self-describing:

  - v2 (current): bcrypt(b64(sha256(password))) with $ prefix
  - v1 (legacy): plain bcrypt, prefixed $ or bare (no prefix)

  verify_password auto-detects the version and falls back to v1 for older
  hashes. LocalAuthProvider.authenticate() now rehashes legacy hashes to v2
  on successful login via needs_rehash(), so existing users upgrade
  transparently without a dedicated migration step.
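
  The pre-hash step of the v2 scheme can be sketched with the standard library. This covers only the `b64(sha256(password))` part; the bcrypt call and the version prefix are left out, and `prehash` and `BCRYPT_MAX_INPUT_BYTES` are illustrative names.

```python
import base64
import hashlib

BCRYPT_MAX_INPUT_BYTES = 72  # bcrypt ignores input beyond this length


def prehash(password: str) -> bytes:
    """Map a password of any length to a fixed 44-byte ASCII value."""
    digest = hashlib.sha256(password.encode("utf-8")).digest()
    # Base64 keeps the value printable and free of NUL bytes.
    return base64.b64encode(digest)
```

  Two passwords that share their first 72 bytes produce different pre-hashes, so the difference survives into the bcrypt input instead of being truncated away.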

* fix(auth): harden verify_password, best-effort rehash, update require_auth docstring, downgrade journal logging

- password.py: wrap bcrypt.checkpw in try/except → return False for malformed/corrupt hashes instead of crashing
- local_provider.py: wrap auto-rehash update_user() in try/except so transient DB errors don't fail valid logins
- authz.py: update require_auth docstring to reflect independent 401 enforcement
- journal.py: downgrade on_chat_model_start from INFO to DEBUG, log only metadata (batch_count, message_counts) instead of full serialized/messages content

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/48c5cf31-a4ab-418a-982a-6343c37bb299

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* fix(auth): address code review - narrow ValueError catch, add rehash warning log, rename num_batches

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/48c5cf31-a4ab-418a-982a-6343c37bb299

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-04-28 11:34:07 +08:00
DanielWalnut af8c0cfb78 fix(harness): constrain view_image to thread data paths (#2557)
* fix(harness): constrain view_image to thread data paths

Fixes #2530

* fix(harness): address view_image review findings

* style(harness): format view_image changes

* fix(harness): address view_image review comments
2026-04-28 11:13:17 +08:00
greatmengqi b8bc4826d8 refactor: root release config in gateway runtime (#2611)
Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
2026-04-28 00:13:04 +08:00
128 changed files with 9108 additions and 1717 deletions
+5
@@ -40,3 +40,8 @@ INFOQUEST_API_KEY=your-infoquest-api-key
#
# WECOM_BOT_ID=your-wecom-bot-id
# WECOM_BOT_SECRET=your-wecom-bot-secret
# DINGTALK_CLIENT_ID=your-dingtalk-client-id
# DINGTALK_CLIENT_SECRET=your-dingtalk-client-secret
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
# GATEWAY_ENABLE_DOCS=false
+21 -1
@@ -251,7 +251,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
If you prefer running services locally:
-Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root (can be overridden via `DEER_FLOW_CONFIG_PATH`). Run `make doctor` to verify your setup before starting.
+Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root. Set `DEER_FLOW_PROJECT_ROOT` to define that root explicitly, or `DEER_FLOW_CONFIG_PATH` to point at a specific config file. Runtime state defaults to `.deer-flow` under the project root and can be moved with `DEER_FLOW_HOME`; skills default to `skills/` under the project root and can be moved with `DEER_FLOW_SKILLS_PATH`. Run `make doctor` to verify your setup before starting.
On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
1. **Check prerequisites**:
@@ -345,6 +345,7 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when
| Feishu / Lark | WebSocket | Moderate |
| WeChat | Tencent iLink (long-polling) | Moderate |
| WeCom | WebSocket | Moderate |
| DingTalk | Stream Push (WebSocket) | Moderate |
**Configuration in `config.yaml`:**
@@ -414,6 +415,13 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # Client ID of your DingTalk application
client_secret: $DINGTALK_CLIENT_SECRET # Client Secret of your DingTalk application
allowed_users: [] # empty = allow all
card_template_id: "" # Optional: AI Card template ID for streaming typewriter effect
```
Notes:
@@ -442,6 +450,10 @@ WECHAT_ILINK_BOT_ID=your_ilink_bot_id
# WeCom
WECOM_BOT_ID=your_bot_id
WECOM_BOT_SECRET=your_bot_secret
# DingTalk
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Telegram Setup**
@@ -480,6 +492,14 @@ WECOM_BOT_SECRET=your_bot_secret
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
**DingTalk Setup**
1. Create a DingTalk application in the [DingTalk Developer Console](https://open.dingtalk.com/) and enable **Robot** capability.
2. Set the message receiving mode to **Stream Mode** in the robot configuration page.
3. Copy the `Client ID` and `Client Secret`, set `DINGTALK_CLIENT_ID` and `DINGTALK_CLIENT_SECRET` in `.env`, and enable the channel in `config.yaml`.
4. *(Optional)* To enable streaming AI Card replies (typewriter effect), create an **AI Card** template on the [DingTalk Card Platform](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), then set `card_template_id` in `config.yaml` to the template ID. You also need to apply for the `Card.Streaming.Write` and `Card.Instance.Write` permissions.
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://gateway:8001/api` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
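The Docker Compose caveat above can be checked mechanically before the channels start. The following is a minimal sketch, not part of DeerFlow: `points_at_loopback` and `LOOPBACK_HOSTS` are hypothetical names, and the set of hosts treated as loopback is an assumption.

```python
from urllib.parse import urlparse

# Hosts treated as "this container" (assumed list, extend as needed).
LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}


def points_at_loopback(url: str) -> bool:
    """Return True when the URL's host is a loopback name or address."""
    host = urlparse(url).hostname  # lowercased, brackets stripped for IPv6
    return host in LOOPBACK_HOSTS


if __name__ == "__main__":
    for candidate in ("http://localhost:8001/api", "http://gateway:8001/api"):
        status = "loopback" if points_at_loopback(candidate) else "ok"
        print(f"{candidate}: {status}")
```

A startup check could log a warning when either `channels.langgraph_url` or `channels.gateway_url` is flagged while running under Docker Compose.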
**Commands**
+19
@@ -290,6 +290,7 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca
| Telegram | Bot API (long-polling) | Facile |
| Slack | Socket Mode | Modérée |
| Feishu / Lark | WebSocket | Modérée |
| DingTalk | Stream Push (WebSocket) | Modérée |
**Configuration dans `config.yaml` :**
@@ -341,6 +342,13 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # ClientId depuis DingTalk Open Platform
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret depuis DingTalk Open Platform
allowed_users: [] # vide = tout le monde autorisé
card_template_id: "" # Optionnel : ID de modèle AI Card pour l'effet machine à écrire en streaming
```
Définissez les clés API correspondantes dans votre fichier `.env` :
@@ -356,6 +364,10 @@ SLACK_APP_TOKEN=xapp-...
# Feishu / Lark
FEISHU_APP_ID=cli_xxxx
FEISHU_APP_SECRET=your_app_secret
# DingTalk
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Configuration Telegram**
@@ -378,6 +390,13 @@ FEISHU_APP_SECRET=your_app_secret
3. Dans **Events**, abonnez-vous à `im.message.receive_v1` et sélectionnez le mode **Long Connection**.
4. Copiez l'App ID et l'App Secret. Définissez `FEISHU_APP_ID` et `FEISHU_APP_SECRET` dans `.env` et activez le canal dans `config.yaml`.
**Configuration DingTalk**
1. Créez une application sur [DingTalk Open Platform](https://open.dingtalk.com/) et activez la capacité **Robot**.
2. Dans la page de configuration du robot, définissez le mode de réception des messages sur **Stream**.
3. Copiez le `Client ID` et le `Client Secret`. Définissez `DINGTALK_CLIENT_ID` et `DINGTALK_CLIENT_SECRET` dans `.env` et activez le canal dans `config.yaml`.
4. *(Optionnel)* Pour activer les réponses en streaming AI Card (effet machine à écrire), créez un modèle **AI Card** sur la [plateforme de cartes DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), puis définissez `card_template_id` dans `config.yaml` avec l'ID du modèle. Vous devez également demander les permissions `Card.Streaming.Write` et `Card.Instance.Write`.
**Commandes**
Une fois un canal connecté, vous pouvez interagir avec DeerFlow directement depuis le chat :
+19
@@ -243,6 +243,7 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート
| Telegram | Bot API(ロングポーリング) | 簡単 |
| Slack | Socket Mode | 中程度 |
| Feishu / Lark | WebSocket | 中程度 |
| DingTalk | Stream Push(WebSocket) | 中程度 |
**`config.yaml`での設定:**
@@ -294,6 +295,13 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # DingTalk Open PlatformのClientId
client_secret: $DINGTALK_CLIENT_SECRET # DingTalk Open PlatformのClientSecret
allowed_users: [] # 空 = 全員許可
card_template_id: "" # オプション:ストリーミングタイプライター効果用のAIカードテンプレートID
```
対応するAPIキーを`.env`ファイルに設定します:
@@ -309,6 +317,10 @@ SLACK_APP_TOKEN=xapp-...
# Feishu / Lark
FEISHU_APP_ID=cli_xxxx
FEISHU_APP_SECRET=your_app_secret
# DingTalk
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Telegramのセットアップ**
@@ -331,6 +343,13 @@ FEISHU_APP_SECRET=your_app_secret
3. **イベント**で`im.message.receive_v1`を購読し、**ロングコネクション**モードを選択。
4. App IDとApp Secretをコピー。`.env`に`FEISHU_APP_ID`と`FEISHU_APP_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
**DingTalkのセットアップ**
1. [DingTalk Open Platform](https://open.dingtalk.com/)でアプリを作成し、**ロボット**機能を有効化します。
2. ロボット設定ページでメッセージ受信モードを**Streamモード**に設定します。
3. `Client ID`と`Client Secret`をコピー。`.env`に`DINGTALK_CLIENT_ID`と`DINGTALK_CLIENT_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
4. *(オプション)* ストリーミングAIカード返信(タイプライター効果)を有効にするには、[DingTalkカードプラットフォーム](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)で**AIカード**テンプレートを作成し、`config.yaml`の`card_template_id`にテンプレートIDを設定します。`Card.Streaming.Write` および `Card.Instance.Write` 権限の申請も必要です。
**コマンド**
チャネル接続後、チャットから直接DeerFlowと対話できます:
+15
@@ -256,6 +256,7 @@ DeerFlow принимает задачи прямо из мессенджеро
| Telegram | Bot API (long-polling) | Просто |
| Slack | Socket Mode | Средне |
| Feishu / Lark | WebSocket | Средне |
| DingTalk | Stream Push (WebSocket) | Средне |
**Конфигурация в `config.yaml`:**
@@ -278,6 +279,13 @@ channels:
enabled: true
bot_token: $TELEGRAM_BOT_TOKEN
allowed_users: []
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # ClientId с DingTalk Open Platform
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret с DingTalk Open Platform
allowed_users: [] # пусто = разрешить всем
card_template_id: "" # Опционально: ID шаблона AI Card для потокового эффекта печатной машинки
```
**Настройка Telegram**
@@ -285,6 +293,13 @@ channels:
1. Напишите [@BotFather](https://t.me/BotFather), отправьте `/newbot` и скопируйте HTTP API-токен.
2. Укажите `TELEGRAM_BOT_TOKEN` в `.env` и включите канал в `config.yaml`.
**Настройка DingTalk**
1. Создайте приложение на [DingTalk Open Platform](https://open.dingtalk.com/) и включите возможность **Робот**.
2. На странице настроек робота установите режим приёма сообщений на **Stream**.
3. Скопируйте `Client ID` и `Client Secret`. Укажите `DINGTALK_CLIENT_ID` и `DINGTALK_CLIENT_SECRET` в `.env` и включите канал в `config.yaml`.
4. *(Опционально)* Для включения потоковых ответов AI Card (эффект печатной машинки) создайте шаблон **AI Card** на [платформе карточек DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), затем укажите `card_template_id` в `config.yaml` с ID шаблона. Также необходимо запросить разрешения `Card.Streaming.Write` и `Card.Instance.Write`.
**Доступные команды**
| Команда | Описание |
+20 -1
@@ -194,7 +194,7 @@ make down # 停止并移除容器
如果你更希望直接在本地启动各个服务:
前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`,也可以通过 `DEER_FLOW_CONFIG_PATH` 覆盖。
前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`。可以用 `DEER_FLOW_PROJECT_ROOT` 显式指定项目根目录,也可以 `DEER_FLOW_CONFIG_PATH` 指向某个具体配置文件。运行期状态默认写到项目根目录下的 `.deer-flow`,可用 `DEER_FLOW_HOME` 覆盖;skills 默认读取项目根目录下的 `skills/`,可用 `DEER_FLOW_SKILLS_PATH` 覆盖。
在 Windows 上,请使用 Git Bash 运行本地开发流程。基于 bash 的服务脚本不支持直接在原生 `cmd.exe` 或 PowerShell 中执行,且 WSL 也不保证可用,因为部分脚本依赖 Git for Windows 的 `cygpath` 等工具。
1. **检查依赖环境**
@@ -248,6 +248,7 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应
| Slack | Socket Mode | 中等 |
| Feishu / Lark | WebSocket | 中等 |
| 企业微信智能机器人 | WebSocket | 中等 |
| 钉钉 | Stream Push(WebSocket) | 中等 |
**`config.yaml` 中的配置示例:**
@@ -304,6 +305,13 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # 钉钉开放平台 ClientId
client_secret: $DINGTALK_CLIENT_SECRET # 钉钉开放平台 ClientSecret
allowed_users: [] # 留空表示允许所有人
card_template_id: "" # 可选:AI 卡片模板 ID,用于流式打字机效果
```
说明:
@@ -327,6 +335,10 @@ FEISHU_APP_SECRET=your_app_secret
# 企业微信智能机器人
WECOM_BOT_ID=your_bot_id
WECOM_BOT_SECRET=your_bot_secret
# 钉钉
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Telegram 配置**
@@ -357,6 +369,13 @@ WECOM_BOT_SECRET=your_bot_secret
4. 安装后端依赖时确保包含 `wecom-aibot-python-sdk`,渠道会通过 WebSocket 长连接接收消息,无需公网回调地址。
5. 当前支持文本、图片和文件入站消息;agent 生成的最终图片/文件也会回传到企业微信会话中。
**钉钉配置**
1. 在 [钉钉开放平台](https://open.dingtalk.com/) 创建应用,并启用 **机器人** 能力。
2. 在机器人配置页面设置消息接收模式为 **Stream模式**。
3. 复制 `Client ID` 和 `Client Secret`,在 `.env` 中设置 `DINGTALK_CLIENT_ID` 和 `DINGTALK_CLIENT_SECRET`,并在 `config.yaml` 中启用该渠道。
4. *(可选)* 如需开启流式 AI 卡片回复(打字机效果),请在[钉钉卡片平台](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)创建 **AI 卡片**模板,然后在 `config.yaml` 中将 `card_template_id` 设为该模板 ID。同时需要申请 `Card.Streaming.Write` 和 `Card.Instance.Write` 权限。
**命令**
渠道连接完成后,你可以直接在聊天窗口里和 DeerFlow 交互:
+10 -7
@@ -112,7 +112,7 @@ CI runs these regression tests for every pull request via [.github/workflows/bac
The backend is split into two layers with a strict dependency direction:
- **Harness** (`packages/harness/deerflow/`): Publishable agent framework package (`deerflow-harness`). Import prefix: `deerflow.*`. Contains agent orchestration, tools, sandbox, models, MCP, skills, config — everything needed to build and run agents.
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram).
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram, DingTalk).
**Dependency rule**: App imports deerflow, but deerflow never imports app. This boundary is enforced by `tests/test_harness_boundary.py` which runs in CI.
@@ -205,7 +205,7 @@ Configuration priority:
### Gateway API (`app/gateway/`)
FastAPI application on port 8001 with health check at `GET /health`.
FastAPI application on port 8001 with health check at `GET /health`. Set `GATEWAY_ENABLE_DOCS=false` to disable `/docs`, `/redoc`, and `/openapi.json` in production (default: enabled).
**Routers**:
@@ -312,7 +312,8 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
### IM Channels System (`app/channels/`)
Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via Gateway's LangGraph-compatible API.
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via the LangGraph Server.
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
@@ -322,7 +323,7 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
- `slack.py` / `feishu.py` / `telegram.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place)
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
**Message Flow**:
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
@@ -331,14 +332,16 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
7. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
8. Outbound → channel callbacks → platform reply
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
9. Outbound → channel callbacks → platform reply
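The DingTalk branch of the message flow above can be illustrated with a small, self-contained sketch. It is a simplified policy, not the code in `dingtalk.py`: `deliver_stream`, `update_card`, and `send_markdown` are hypothetical names, and the sketch assumes that once a card update fails, later chunks are skipped and only the final text is sent as markdown.

```python
import asyncio
from typing import Awaitable, Callable


async def deliver_stream(
    updates: list[tuple[str, bool]],
    update_card: Callable[[str, bool], Awaitable[None]],
    send_markdown: Callable[[str], Awaitable[None]],
) -> str:
    """Push (text, is_final) updates to a card; fall back to markdown on failure.

    Returns "card" if every update reached the card, otherwise "markdown".
    """
    card_alive = True
    for text, is_final in updates:
        if card_alive:
            try:
                await update_card(text, is_final)
            except Exception:
                # Stop streaming to the card after the first failure.
                card_alive = False
        if not card_alive and is_final:
            # Deliver only the complete answer to avoid duplicate messages.
            await send_markdown(text)
    return "card" if card_alive else "markdown"


if __name__ == "__main__":
    async def demo() -> None:
        async def ok_card(text: str, is_final: bool) -> None:
            print("card:", text, "(final)" if is_final else "")

        async def markdown(text: str) -> None:
            print("markdown:", text)

        print(await deliver_stream([("Hel", False), ("Hello", True)], ok_card, markdown))

    asyncio.run(demo())
```

Sending only the final text after a failure keeps the conversation free of partial duplicates, at the cost of losing the typewriter effect for that reply.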
**Configuration** (`config.yaml` -> `channels`):
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
- `gateway_url` - Gateway API URL for auxiliary commands (default: `http://localhost:8001`)
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token)
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
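As an illustration of how the two URLs might be resolved, here is a minimal sketch. The defaults and environment variable names come from the list above; the function name `resolve_channel_urls` is hypothetical, and the precedence (environment variable, then `config.yaml`, then default) is an assumption rather than confirmed DeerFlow behavior.

```python
from __future__ import annotations

import os
from collections.abc import Mapping

# Defaults as documented for the `channels` section.
DEFAULT_URLS = {
    "langgraph_url": "http://localhost:8001/api",
    "gateway_url": "http://localhost:8001",
}

# Environment variables that override each key.
ENV_OVERRIDES = {
    "langgraph_url": "DEER_FLOW_CHANNELS_LANGGRAPH_URL",
    "gateway_url": "DEER_FLOW_CHANNELS_GATEWAY_URL",
}


def resolve_channel_urls(
    channels_cfg: Mapping[str, str],
    env: Mapping[str, str] | None = None,
) -> dict[str, str]:
    """Pick each URL from env, then config, then the documented default (assumed order)."""
    env = os.environ if env is None else env
    return {
        key: env.get(ENV_OVERRIDES[key]) or channels_cfg.get(key) or default
        for key, default in DEFAULT_URLS.items()
    }


if __name__ == "__main__":
    # Docker Compose style override for one of the two URLs.
    print(resolve_channel_urls({}, env={"DEER_FLOW_CHANNELS_GATEWAY_URL": "http://gateway:8001"}))
```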
### Memory System (`packages/harness/deerflow/agents/memory/`)
+4
@@ -31,6 +31,10 @@ class Channel(ABC):
    def is_running(self) -> bool:
        return self._running

    @property
    def supports_streaming(self) -> bool:
        return False
# -- lifecycle ---------------------------------------------------------
@abstractmethod
+740
@@ -0,0 +1,740 @@
"""DingTalk channel implementation."""
from __future__ import annotations
import asyncio
import json
import logging
import re
import threading
import time
from pathlib import Path
from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
DINGTALK_API_BASE = "https://api.dingtalk.com"
_TOKEN_REFRESH_MARGIN_SECONDS = 300
_CONVERSATION_TYPE_P2P = "1"
_CONVERSATION_TYPE_GROUP = "2"
_MAX_UPLOAD_SIZE_BYTES = 20 * 1024 * 1024
def _normalize_conversation_type(raw: Any) -> str:
    """Normalize ``conversationType`` to ``"1"`` (P2P) or ``"2"`` (group).

    Stream payloads may send int or string values.
    """
    if raw is None:
        return _CONVERSATION_TYPE_P2P
    s = str(raw).strip()
    if s == _CONVERSATION_TYPE_GROUP:
        return _CONVERSATION_TYPE_GROUP
    return _CONVERSATION_TYPE_P2P


def _normalize_allowed_users(allowed_users: Any) -> set[str]:
    if allowed_users is None:
        return set()
    if isinstance(allowed_users, str):
        values = [allowed_users]
    elif isinstance(allowed_users, (list, tuple, set)):
        values = allowed_users
    else:
        logger.warning(
            "DingTalk allowed_users should be a list of user IDs; treating %s as one string value",
            type(allowed_users).__name__,
        )
        values = [allowed_users]
    return {str(uid) for uid in values if str(uid)}


def _is_dingtalk_command(text: str) -> bool:
    if not text.startswith("/"):
        return False
    return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS


def _extract_text_from_rich_text(rich_text_list: list) -> str:
    parts: list[str] = []
    for item in rich_text_list:
        if isinstance(item, dict) and "text" in item:
            parts.append(item["text"])
    return " ".join(parts)
_FENCED_CODE_BLOCK_RE = re.compile(r"```(\w*)\n(.*?)```", re.DOTALL)
_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
_HORIZONTAL_RULE_RE = re.compile(r"^-{3,}$", re.MULTILINE)
_TABLE_SEPARATOR_RE = re.compile(r"^\|[-:| ]+\|$", re.MULTILINE)
def _convert_markdown_table(text: str) -> str:
    # DingTalk sampleMarkdown does not render pipe-delimited tables.
    lines = text.split("\n")
    result: list[str] = []
    i = 0
    while i < len(lines):
        line = lines[i]
        # Detect table: header row followed by separator row
        if i + 1 < len(lines) and line.strip().startswith("|") and _TABLE_SEPARATOR_RE.match(lines[i + 1].strip()):
            headers = [h.strip() for h in line.strip().strip("|").split("|")]
            i += 2  # skip header + separator
            while i < len(lines) and lines[i].strip().startswith("|"):
                cells = [c.strip() for c in lines[i].strip().strip("|").split("|")]
                for h, c in zip(headers, cells):
                    result.append(f"> **{h}**: {c}")
                result.append("")
                i += 1
        else:
            result.append(line)
            i += 1
    return "\n".join(result)


def _adapt_markdown_for_dingtalk(text: str) -> str:
    """Adapt markdown for DingTalk's limited sampleMarkdown renderer."""

    def _code_block_to_quote(match: re.Match) -> str:
        lang = match.group(1)
        code = match.group(2).rstrip("\n")
        prefix = f"> **{lang}**\n" if lang else ""
        quoted_lines = "\n".join(f"> {line}" for line in code.split("\n"))
        return f"{prefix}{quoted_lines}\n"

    text = _FENCED_CODE_BLOCK_RE.sub(_code_block_to_quote, text)
    text = _INLINE_CODE_RE.sub(r"**\1**", text)
    text = _convert_markdown_table(text)
    text = _HORIZONTAL_RULE_RE.sub("───────────", text)
    return text
class DingTalkChannel(Channel):
    """DingTalk IM channel using Stream Push (WebSocket, no public IP needed)."""

    def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
        super().__init__(name="dingtalk", bus=bus, config=config)
        self._thread: threading.Thread | None = None
        self._main_loop: asyncio.AbstractEventLoop | None = None
        self._client_id: str = ""
        self._client_secret: str = ""
        self._allowed_users: set[str] = _normalize_allowed_users(config.get("allowed_users"))
        self._cached_token: str = ""
        self._token_expires_at: float = 0.0
        self._token_lock = asyncio.Lock()
        self._card_template_id: str = config.get("card_template_id", "")
        self._card_track_ids: dict[str, str] = {}
        self._dingtalk_client: Any = None
        self._stream_client: Any = None
        self._incoming_messages: dict[str, Any] = {}
        self._incoming_messages_lock = threading.Lock()
        self._card_repliers: dict[str, Any] = {}

    @property
    def supports_streaming(self) -> bool:
        return bool(self._card_template_id)
    async def start(self) -> None:
        if self._running:
            return
        try:
            import dingtalk_stream  # noqa: F401
        except ImportError:
            logger.error("dingtalk-stream is not installed. Install it with: uv add dingtalk-stream")
            return
        client_id = self.config.get("client_id", "")
        client_secret = self.config.get("client_secret", "")
        if not client_id or not client_secret:
            logger.error("DingTalk channel requires client_id and client_secret")
            return
        self._client_id = client_id
        self._client_secret = client_secret
        self._main_loop = asyncio.get_running_loop()
        if self._card_template_id:
            logger.info("[DingTalk] AI Card mode enabled (template=%s)", self._card_template_id)
        self._running = True
        self.bus.subscribe_outbound(self._on_outbound)
        self._thread = threading.Thread(
            target=self._run_stream,
            args=(client_id, client_secret),
            daemon=True,
        )
        self._thread.start()
        logger.info("DingTalk channel started")
    async def stop(self) -> None:
        self._running = False
        self.bus.unsubscribe_outbound(self._on_outbound)
        stream_client = self._stream_client
        if stream_client is not None:
            try:
                if hasattr(stream_client, "disconnect"):
                    stream_client.disconnect()
            except Exception:
                logger.debug("[DingTalk] error disconnecting stream client", exc_info=True)
        self._dingtalk_client = None
        self._stream_client = None
        with self._incoming_messages_lock:
            self._incoming_messages.clear()
        self._card_repliers.clear()
        self._card_track_ids.clear()
        if self._thread:
            self._thread.join(timeout=5)
            self._thread = None
        logger.info("DingTalk channel stopped")
    def _resolve_routing(self, msg: OutboundMessage) -> tuple[str, str, str]:
        """Return (conversation_type, sender_staff_id, conversation_id).

        Uses msg.chat_id as the primary routing key; metadata as fallback.
        """
        conversation_type = _normalize_conversation_type(msg.metadata.get("conversation_type"))
        sender_staff_id = msg.metadata.get("sender_staff_id", "")
        conversation_id = msg.metadata.get("conversation_id", "")
        if conversation_type == _CONVERSATION_TYPE_GROUP:
            conversation_id = msg.chat_id or conversation_id
        else:
            sender_staff_id = msg.chat_id or sender_staff_id
        return conversation_type, sender_staff_id, conversation_id
    async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
        conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
        robot_code = self._client_id

        # Card mode: stream update to existing AI card
        source_key = self._make_card_source_key_from_outbound(msg)
        out_track_id = self._card_track_ids.get(source_key)

        # ``card_template_id`` enables ``runs.stream`` (non-final + final outbounds).
        # If card creation failed, skip non-final chunks to avoid duplicate messages.
        if self._card_template_id and not out_track_id and not msg.is_final:
            return

        if out_track_id:
            try:
                await self._stream_update_card(
                    out_track_id,
                    msg.text,
                    is_finalize=msg.is_final,
                )
            except Exception:
                logger.warning("[DingTalk] card stream failed, falling back to sampleMarkdown")
                if msg.is_final:
                    self._card_track_ids.pop(source_key, None)
                    self._card_repliers.pop(out_track_id, None)
                await self._send_markdown_fallback(robot_code, conversation_type, sender_staff_id, conversation_id, msg.text)
                return
            if msg.is_final:
                self._card_track_ids.pop(source_key, None)
                self._card_repliers.pop(out_track_id, None)
            return

        # Non-card mode: send sampleMarkdown with retry
        last_exc: Exception | None = None
        for attempt in range(_max_retries):
            try:
                if conversation_type == _CONVERSATION_TYPE_GROUP:
                    await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
                else:
                    await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
                return
            except Exception as exc:
                last_exc = exc
                if attempt < _max_retries - 1:
                    delay = 2**attempt
                    logger.warning(
                        "[DingTalk] send failed (attempt %d/%d), retrying in %ds: %s",
                        attempt + 1,
                        _max_retries,
                        delay,
                        exc,
                    )
                    await asyncio.sleep(delay)
        logger.error("[DingTalk] send failed after %d attempts: %s", _max_retries, last_exc)
        if last_exc is None:
            raise RuntimeError("DingTalk send failed without an exception from any attempt")
        raise last_exc
    async def _send_markdown_fallback(
        self,
        robot_code: str,
        conversation_type: str,
        sender_staff_id: str,
        conversation_id: str,
        text: str,
    ) -> None:
        try:
            if conversation_type == _CONVERSATION_TYPE_GROUP:
                await self._send_group_message(robot_code, conversation_id, text)
            else:
                await self._send_p2p_message(robot_code, sender_staff_id, text)
        except Exception:
            logger.exception("[DingTalk] markdown fallback also failed")
            raise
    async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
        if attachment.size > _MAX_UPLOAD_SIZE_BYTES:
            logger.warning("[DingTalk] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
            return False
        conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
        robot_code = self._client_id
        try:
            media_id = await self._upload_media(attachment.actual_path, "image" if attachment.is_image else "file")
            if not media_id:
                return False
            if attachment.is_image:
                msg_key = "sampleImageMsg"
                msg_param = json.dumps({"photoURL": media_id})
            else:
                msg_key = "sampleFile"
                msg_param = json.dumps(
                    {
                        "fileUrl": media_id,
                        "fileName": attachment.filename,
                        "fileSize": str(attachment.size),
                    }
                )
            token = await self._get_access_token()
            async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
                if conversation_type == _CONVERSATION_TYPE_GROUP:
                    response = await client.post(
                        f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
                        headers=self._api_headers(token),
                        json={
                            "msgKey": msg_key,
                            "msgParam": msg_param,
                            "robotCode": robot_code,
                            "openConversationId": conversation_id,
                        },
                    )
                else:
                    response = await client.post(
                        f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
                        headers=self._api_headers(token),
                        json={
                            "msgKey": msg_key,
                            "msgParam": msg_param,
                            "robotCode": robot_code,
                            "userIds": [sender_staff_id],
                        },
                    )
                response.raise_for_status()
            logger.info("[DingTalk] file sent: %s", attachment.filename)
            return True
        except (httpx.HTTPError, OSError, ValueError, TypeError, AttributeError):
            logger.exception("[DingTalk] failed to send file: %s", attachment.filename)
            return False
    # -- stream client (runs in dedicated thread) --------------------------

    def _run_stream(self, client_id: str, client_secret: str) -> None:
        try:
            import dingtalk_stream

            credential = dingtalk_stream.Credential(client_id, client_secret)
            client = dingtalk_stream.DingTalkStreamClient(credential)
            self._stream_client = client
            client.register_callback_handler(
                dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
                _DingTalkMessageHandler(self),
            )
            client.start_forever()
        except Exception:
            if self._running:
                logger.exception("DingTalk Stream Push error")
        finally:
            self._stream_client = None
    def _on_chatbot_message(self, message: Any) -> None:
        if not self._running:
            return
        try:
            sender_staff_id = message.sender_staff_id or ""
            conversation_type = _normalize_conversation_type(message.conversation_type)
            conversation_id = message.conversation_id or ""
            msg_id = message.message_id or ""
            sender_nick = message.sender_nick or ""
            if self._allowed_users and sender_staff_id not in self._allowed_users:
                logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
                return
            text = self._extract_text(message)
            if not text:
                logger.info("[DingTalk] empty text, ignoring message")
                return
            logger.info(
                "[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text=%r",
                conversation_type,
                msg_id,
                sender_staff_id,
                sender_nick,
                text[:100],
            )
            if _is_dingtalk_command(text):
                msg_type = InboundMessageType.COMMAND
            else:
                msg_type = InboundMessageType.CHAT
            # P2P: topic_id=None (single thread per user, like Telegram private chat)
            # Group: topic_id=msg_id (each new message starts a new topic, like Feishu)
            topic_id: str | None = msg_id if conversation_type == _CONVERSATION_TYPE_GROUP else None
            # chat_id uses conversation_id for groups, sender_staff_id for P2P
            chat_id = conversation_id if conversation_type == _CONVERSATION_TYPE_GROUP else sender_staff_id
            inbound = self._make_inbound(
                chat_id=chat_id,
                user_id=sender_staff_id,
                text=text,
                msg_type=msg_type,
                thread_ts=msg_id,
                metadata={
                    "conversation_type": conversation_type,
                    "conversation_id": conversation_id,
                    "sender_staff_id": sender_staff_id,
                    "sender_nick": sender_nick,
                    "message_id": msg_id,
                },
            )
            inbound.topic_id = topic_id
            if self._card_template_id:
                source_key = self._make_card_source_key(inbound)
                with self._incoming_messages_lock:
                    self._incoming_messages[source_key] = message
            if self._main_loop and self._main_loop.is_running():
                logger.info("[DingTalk] publishing inbound message to bus (type=%s, msg_id=%s)", msg_type.value, msg_id)
                fut = asyncio.run_coroutine_threadsafe(
                    self._prepare_inbound(chat_id, inbound),
                    self._main_loop,
                )
                fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "prepare_inbound", mid))
            else:
                logger.warning("[DingTalk] main loop not running, cannot publish inbound message")
        except Exception:
            logger.exception("[DingTalk] error processing chatbot message")
    @staticmethod
    def _extract_text(message: Any) -> str:
        msg_type = message.message_type
        if msg_type == "text" and message.text:
            return message.text.content.strip()
        if msg_type == "richText" and message.rich_text_content:
            return _extract_text_from_rich_text(message.rich_text_content.rich_text_list).strip()
        return ""

    async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
        # Running reply must finish before publish_inbound so AI card tracks are
        # registered before the manager emits streaming outbounds.
        await self._send_running_reply(chat_id, inbound)
        await self.bus.publish_inbound(inbound)
    async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
        conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
        sender_staff_id = inbound.metadata.get("sender_staff_id", "")
        conversation_id = inbound.metadata.get("conversation_id", "")
        text = "\u23f3 Working on it..."
        try:
            if self._card_template_id:
                source_key = self._make_card_source_key(inbound)
                with self._incoming_messages_lock:
                    chatbot_message = self._incoming_messages.pop(source_key, None)
                out_track_id = await self._create_and_deliver_card(
                    text,
                    chatbot_message=chatbot_message,
                )
                if out_track_id:
                    self._card_track_ids[source_key] = out_track_id
                    logger.info("[DingTalk] AI card running reply sent for chat=%s", chat_id)
                    return
            robot_code = self._client_id
            if conversation_type == _CONVERSATION_TYPE_GROUP:
                await self._send_text_message_to_group(robot_code, conversation_id, text)
            else:
                await self._send_text_message_to_user(robot_code, sender_staff_id, text)
            logger.info("[DingTalk] 'Working on it...' reply sent for chat=%s", chat_id)
        except Exception:
            logger.exception("[DingTalk] failed to send running reply for chat=%s", chat_id)
    # -- DingTalk API helpers ----------------------------------------------

    async def _get_access_token(self) -> str:
        if self._cached_token and time.monotonic() < self._token_expires_at:
            return self._cached_token
        async with self._token_lock:
            if self._cached_token and time.monotonic() < self._token_expires_at:
                return self._cached_token
            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
                response = await client.post(
                    f"{DINGTALK_API_BASE}/v1.0/oauth2/accessToken",
                    json={"appKey": self._client_id, "appSecret": self._client_secret},  # DingTalk API field names
                )
                response.raise_for_status()
                data = response.json()
            if not isinstance(data, dict):
                raise ValueError(f"DingTalk access token response must be a JSON object, got {type(data).__name__}")
            access_token = data.get("accessToken")
            if not isinstance(access_token, str) or not access_token.strip():
                raise ValueError("DingTalk access token response did not contain a usable accessToken")
            raw_expires_in = data.get("expireIn", 7200)
            try:
                expires_in = int(raw_expires_in)
            except (TypeError, ValueError):
                logger.warning("[DingTalk] invalid expireIn value %r, using default 7200s", raw_expires_in)
                expires_in = 7200
            self._cached_token = access_token.strip()
            self._token_expires_at = time.monotonic() + expires_in - _TOKEN_REFRESH_MARGIN_SECONDS
            return self._cached_token

    @staticmethod
    def _api_headers(token: str) -> dict[str, str]:
        return {
            "x-acs-dingtalk-access-token": token,
            "Content-Type": "application/json",
        }
    async def _send_text_message_to_user(self, robot_code: str, user_id: str, text: str) -> None:
        token = await self._get_access_token()
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
            response = await client.post(
                f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
                headers=self._api_headers(token),
                json={
                    "msgKey": "sampleText",
                    "msgParam": json.dumps({"content": text}),
                    "robotCode": robot_code,
                    "userIds": [user_id],
                },
            )
            response.raise_for_status()

    async def _send_text_message_to_group(self, robot_code: str, conversation_id: str, text: str) -> None:
        token = await self._get_access_token()
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
            response = await client.post(
                f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
                headers=self._api_headers(token),
                json={
                    "msgKey": "sampleText",
                    "msgParam": json.dumps({"content": text}),
                    "robotCode": robot_code,
                    "openConversationId": conversation_id,
                },
            )
            response.raise_for_status()
    async def _send_p2p_message(self, robot_code: str, user_id: str, text: str) -> None:
        text = _adapt_markdown_for_dingtalk(text)
        token = await self._get_access_token()
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
            response = await client.post(
                f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
                headers=self._api_headers(token),
                json={
                    "msgKey": "sampleMarkdown",
                    "msgParam": json.dumps({"title": "DeerFlow", "text": text}),
                    "robotCode": robot_code,
                    "userIds": [user_id],
                },
            )
            response.raise_for_status()
            data = response.json()
            if data.get("processQueryKey"):
                logger.info("[DingTalk] P2P message sent to user=%s", user_id)
            else:
                logger.warning("[DingTalk] P2P send response: %s", data)

    async def _send_group_message(
        self,
        robot_code: str,
        conversation_id: str,
        text: str,
        *,
        at_user_ids: list[str] | None = None,  # noqa: ARG002
    ) -> None:
        # at_user_ids accepted for call-site compatibility but not passed to the API
        # (sampleMarkdown does not support @mentions).
        text = _adapt_markdown_for_dingtalk(text)
        token = await self._get_access_token()
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
            response = await client.post(
                f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
                headers=self._api_headers(token),
                json={
                    "msgKey": "sampleMarkdown",
                    "msgParam": json.dumps({"title": "DeerFlow", "text": text}),
                    "robotCode": robot_code,
                    "openConversationId": conversation_id,
                },
            )
            response.raise_for_status()
            data = response.json()
            if data.get("processQueryKey"):
                logger.info("[DingTalk] group message sent to conversation=%s", conversation_id)
            else:
                logger.warning("[DingTalk] group send response: %s", data)
# -- AI Card streaming helpers -------------------------------------------
def _make_card_source_key(self, inbound: InboundMessage) -> str:
m = inbound.metadata
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{m.get('message_id', '')}"
def _make_card_source_key_from_outbound(self, msg: OutboundMessage) -> str:
m = msg.metadata
correlation_id = m.get("message_id") or msg.thread_ts or ""
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{correlation_id}"
async def _create_and_deliver_card(
self,
initial_text: str,
*,
chatbot_message: Any = None,
) -> str | None:
if self._dingtalk_client is None or chatbot_message is None:
logger.warning("[DingTalk] SDK client or chatbot_message unavailable, skipping AI card")
return None
try:
from dingtalk_stream.card_replier import AICardReplier
except ImportError:
logger.warning("[DingTalk] dingtalk-stream card_replier not available")
return None
try:
replier = AICardReplier(self._dingtalk_client, chatbot_message)
card_instance_id = await replier.async_create_and_deliver_card(
card_template_id=self._card_template_id,
card_data={"content": initial_text},
)
if not card_instance_id:
return None
self._card_repliers[card_instance_id] = replier
logger.info("[DingTalk] AI card created: outTrackId=%s", card_instance_id)
return card_instance_id
except Exception:
logger.exception("[DingTalk] failed to create AI card")
return None
async def _stream_update_card(
self,
out_track_id: str,
content: str,
*,
is_finalize: bool = False,
is_error: bool = False,
) -> None:
replier = self._card_repliers.get(out_track_id)
if not replier:
raise RuntimeError(f"No AICardReplier found for track ID {out_track_id}")
await replier.async_streaming(
card_instance_id=out_track_id,
content_key="content",
content_value=content,
append=False,
finished=is_finalize,
failed=is_error,
)
# -- media upload --------------------------------------------------------
async def _upload_media(self, file_path: str | Path, media_type: str) -> str | None:
try:
file_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/files/upload",
headers={"x-acs-dingtalk-access-token": token},
files={"file": ("upload", file_bytes)},
data={"type": media_type},
)
response.raise_for_status()
try:
payload = response.json()
except json.JSONDecodeError:
logger.exception("[DingTalk] failed to decode upload response JSON: %s", file_path)
return None
if not isinstance(payload, dict):
logger.warning("[DingTalk] unexpected upload response type %s for %s", type(payload).__name__, file_path)
return None
return payload.get("mediaId")
except (httpx.HTTPError, OSError):
logger.exception("[DingTalk] failed to upload media: %s", file_path)
return None
@staticmethod
def _log_future_error(fut: Any, name: str, msg_id: str) -> None:
try:
exc = fut.exception()
if exc:
logger.error("[DingTalk] %s failed for msg_id=%s: %s", name, msg_id, exc)
except (asyncio.CancelledError, asyncio.InvalidStateError):
pass
class _DingTalkMessageHandler:
"""Callback handler registered with dingtalk-stream."""
def __init__(self, channel: DingTalkChannel) -> None:
self._channel = channel
def pre_start(self) -> None:
if hasattr(self, "dingtalk_client") and self.dingtalk_client is not None:
self._channel._dingtalk_client = self.dingtalk_client
async def raw_process(self, callback_message: Any) -> Any:
import dingtalk_stream
from dingtalk_stream.frames import Headers
code, message = await self.process(callback_message)
ack_message = dingtalk_stream.AckMessage()
ack_message.code = code
ack_message.headers.message_id = callback_message.headers.message_id
ack_message.headers.content_type = Headers.CONTENT_TYPE_APPLICATION_JSON
ack_message.data = {"response": message}
return ack_message
async def process(self, callback: Any) -> tuple[int, str]:
import dingtalk_stream
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
self._channel._on_chatbot_message(incoming_message)
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
+4
@@ -63,6 +63,10 @@ class FeishuChannel(Channel):
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
@property
def supports_streaming(self) -> bool:
return True
async def start(self) -> None:
if self._running:
return
+20
@@ -38,6 +38,7 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
CHANNEL_CAPABILITIES = {
"dingtalk": {"supports_streaming": False},
"discord": {"supports_streaming": False},
"feishu": {"supports_streaming": True},
"slack": {"supports_streaming": False},
@@ -48,6 +49,13 @@ CHANNEL_CAPABILITIES = {
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
_METADATA_DROP_KEYS = frozenset({"raw_message", "ref_msg"})
def _slim_metadata(meta: dict[str, Any]) -> dict[str, Any]:
"""Return a shallow copy of *meta* with known-large keys removed."""
return {k: v for k, v in meta.items() if k not in _METADATA_DROP_KEYS}
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
@@ -543,6 +551,13 @@ class ChannelManager:
@staticmethod
def _channel_supports_streaming(channel_name: str) -> bool:
from .service import get_channel_service
service = get_channel_service()
if service:
channel = service.get_channel(channel_name)
if channel is not None:
return channel.supports_streaming
return CHANNEL_CAPABILITIES.get(channel_name, {}).get("supports_streaming", False)
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -772,6 +787,7 @@ class ChannelManager:
artifacts=artifacts,
attachments=attachments,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
await self.bus.publish_outbound(outbound)
@@ -833,6 +849,7 @@ class ChannelManager:
text=latest_text,
is_final=False,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
)
last_published_text = latest_text
@@ -877,6 +894,7 @@ class ChannelManager:
attachments=attachments,
is_final=True,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
)
@@ -935,6 +953,7 @@ class ChannelManager:
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=reply,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
@@ -968,5 +987,6 @@ class ChannelManager:
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=error_text,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
+19 -8
@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
@@ -13,8 +13,12 @@ from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
# Channel name → import path for lazy loading
_CHANNEL_REGISTRY: dict[str, str] = {
"dingtalk": "app.channels.dingtalk:DingTalkChannel",
"discord": "app.channels.discord:DiscordChannel",
"feishu": "app.channels.feishu:FeishuChannel",
"slack": "app.channels.slack:SlackChannel",
@@ -25,6 +29,7 @@ _CHANNEL_REGISTRY: dict[str, str] = {
# Keys that indicate a user has configured credentials for a channel.
_CHANNEL_CREDENTIAL_KEYS: dict[str, list[str]] = {
"dingtalk": ["client_id", "client_secret"],
"discord": ["bot_token"],
"feishu": ["app_id", "app_secret"],
"slack": ["bot_token", "app_token"],
@@ -75,14 +80,15 @@ class ChannelService:
self._running = False
@classmethod
def from_app_config(cls) -> ChannelService:
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
"""Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config
if app_config is None:
from deerflow.config.app_config import get_app_config
config = get_app_config()
app_config = get_app_config()
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {}
extra = app_config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
@@ -162,11 +168,16 @@ class ChannelService:
try:
channel = channel_cls(bus=self.bus, config=config)
await channel.start()
self._channels[name] = channel
await channel.start()
if not channel.is_running:
self._channels.pop(name, None)
logger.error("Channel %s did not enter a running state after start()", name)
return False
logger.info("Channel %s started", name)
return True
except Exception:
self._channels.pop(name, None)
logger.exception("Failed to start channel %s", name)
return False
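The hunk above moves channel registration ahead of `start()` and removes the entry again if startup fails or the channel never reports itself as running. A minimal sketch of that pattern follows; `FakeChannel` and `start_channel` are hypothetical stand-ins for illustration, not the project's `Channel` or `ChannelService` code.

```python
import asyncio


class FakeChannel:
    """Hypothetical stand-in for a channel; not the project's Channel class."""

    def __init__(self, *, comes_up: bool) -> None:
        self._comes_up = comes_up
        self.is_running = False

    async def start(self) -> None:
        # A real channel would connect to its platform here.
        self.is_running = self._comes_up


async def start_channel(registry: dict[str, FakeChannel], name: str, channel: FakeChannel) -> bool:
    registry[name] = channel  # register before start(), as in the hunk above
    try:
        await channel.start()
        if not channel.is_running:
            # start() returned without raising, but the channel is not up.
            registry.pop(name, None)
            return False
        return True
    except Exception:
        # Never leave a half-started channel in the registry.
        registry.pop(name, None)
        return False


registry: dict[str, FakeChannel] = {}
ok = asyncio.run(start_channel(registry, "good", FakeChannel(comes_up=True)))
bad = asyncio.run(start_channel(registry, "bad", FakeChannel(comes_up=False)))
```

After both calls only the channel that actually came up remains registered, so a status report built from the registry would not list the failed one.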
@@ -201,12 +212,12 @@ def get_channel_service() -> ChannelService | None:
return _channel_service
async def start_channel_service() -> ChannelService:
async def start_channel_service(app_config: AppConfig | None = None) -> ChannelService:
"""Create and start the global ChannelService from app config."""
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config()
_channel_service = ChannelService.from_app_config(app_config)
await _channel_service.start()
return _channel_service
+4
@@ -29,6 +29,10 @@ class WeComChannel(Channel):
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
@property
def supports_streaming(self) -> bool:
return True
def _clear_ws_context(self, thread_ts: str | None) -> None:
if not thread_ts:
return
+12 -7
@@ -28,9 +28,13 @@ from app.gateway.routers import (
threads,
uploads,
)
from deerflow.config.app_config import get_app_config
from deerflow.config import app_config as deerflow_app_config
from deerflow.config.app_config import apply_logging_level
# Configure logging
AppConfig = deerflow_app_config.AppConfig
get_app_config = deerflow_app_config.get_app_config
# Default logging; lifespan overrides from config.yaml log_level.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -160,7 +164,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Load config and check necessary environment variables at startup
try:
get_app_config()
app.state.config = get_app_config()
apply_logging_level(app.state.config.log_level)
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
@@ -181,7 +186,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service()
channel_service = await start_channel_service(app.state.config)
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")
@@ -213,6 +218,8 @@ def create_app() -> FastAPI:
Returns:
Configured FastAPI application instance.
"""
config = get_gateway_config()
docs_kwargs = {"docs_url": "/docs", "redoc_url": "/redoc", "openapi_url": "/openapi.json"} if config.enable_docs else {"docs_url": None, "redoc_url": None, "openapi_url": None}
app = FastAPI(
title="DeerFlow API Gateway",
@@ -237,9 +244,7 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
""",
version="0.1.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
**docs_kwargs,
openapi_tags=[
{
"name": "models",
+3 -3
@@ -4,11 +4,8 @@ import logging
import os
import secrets
from dotenv import load_dotenv
from pydantic import BaseModel, Field
load_dotenv()
logger = logging.getLogger(__name__)
@@ -37,6 +34,9 @@ def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
if _auth_config is None:
from dotenv import load_dotenv
load_dotenv()
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)
+14 -1
@@ -1,10 +1,14 @@
"""Local email/password authentication provider."""
import logging
from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.password import hash_password_async, needs_rehash, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
logger = logging.getLogger(__name__)
class LocalAuthProvider(AuthProvider):
"""Email/password authentication provider using local database."""
@@ -43,6 +47,15 @@ class LocalAuthProvider(AuthProvider):
if not await verify_password_async(password, user.password_hash):
return None
if needs_rehash(user.password_hash):
try:
user.password_hash = await hash_password_async(password)
await self._repo.update_user(user)
except Exception:
# Rehash is an opportunistic upgrade; a transient DB error must not
# prevent an otherwise-valid login from succeeding.
logger.warning("Failed to rehash password for user %s; login will still succeed", user.email, exc_info=True)
return user
async def get_user(self, user_id: str) -> User | None:
+53 -5
@@ -1,18 +1,66 @@
"""Password hashing utilities using bcrypt directly."""
"""Password hashing utilities with versioned hash format.
Hash format: ``$dfv<N>$<bcrypt_hash>`` where ``<N>`` is the version.
- **v1** (legacy): ``bcrypt(password)`` — plain bcrypt, susceptible to
72-byte silent truncation.
- **v2** (current): ``bcrypt(b64(sha256(password)))`` — SHA-256 pre-hash
avoids the 72-byte truncation limit so the full password contributes
to the hash.
Verification auto-detects the version and falls back to v1 for hashes
without a prefix, so existing deployments upgrade transparently on next
login.
"""
import asyncio
import base64
import hashlib
import bcrypt
_CURRENT_VERSION = 2
_PREFIX_V2 = "$dfv2$"
_PREFIX_V1 = "$dfv1$"
def _pre_hash_v2(password: str) -> bytes:
"""SHA-256 pre-hash to bypass bcrypt's 72-byte limit."""
return base64.b64encode(hashlib.sha256(password.encode("utf-8")).digest())
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
"""Hash a password (current version: v2 — SHA-256 + bcrypt)."""
raw = bcrypt.hashpw(_pre_hash_v2(password), bcrypt.gensalt()).decode("utf-8")
return f"{_PREFIX_V2}{raw}"
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
"""Verify a password, auto-detecting the hash version.
Accepts v2 (``$dfv2$…``), v1 (``$dfv1$…``), and bare bcrypt hashes
(treated as v1 for backward compatibility with pre-versioning data).
"""
try:
if hashed_password.startswith(_PREFIX_V2):
bcrypt_hash = hashed_password[len(_PREFIX_V2) :]
return bcrypt.checkpw(_pre_hash_v2(plain_password), bcrypt_hash.encode("utf-8"))
if hashed_password.startswith(_PREFIX_V1):
bcrypt_hash = hashed_password[len(_PREFIX_V1) :]
else:
bcrypt_hash = hashed_password
return bcrypt.checkpw(plain_password.encode("utf-8"), bcrypt_hash.encode("utf-8"))
except ValueError:
# bcrypt raises ValueError for malformed or corrupt hashes (e.g., invalid salt).
# Fail closed rather than crashing the request.
return False
def needs_rehash(hashed_password: str) -> bool:
"""Return True if the hash uses an older version and should be rehashed."""
return not hashed_password.startswith(_PREFIX_V2)
async def hash_password_async(password: str) -> str:
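The module docstring above motivates v2 by the 72-byte truncation limit. The sketch below uses only the standard library to show why the SHA-256 pre-hash helps: two long passwords that agree on their first 72 bytes yield different pre-hash values that both fit within the limit. The bcrypt step itself is deliberately left out, and `detect_version` is a hypothetical helper that mirrors the prefix checks in `verify_password` and `needs_rehash`.

```python
import base64
import hashlib

BCRYPT_INPUT_LIMIT = 72  # the limit named in the docstring above


def pre_hash_v2(password: str) -> bytes:
    """Same transformation as the v2 pre-hash: base64(sha256(password))."""
    return base64.b64encode(hashlib.sha256(password.encode("utf-8")).digest())


def detect_version(stored: str) -> int:
    """Hypothetical helper: classify a stored hash by its prefix."""
    if stored.startswith("$dfv2$"):
        return 2
    return 1  # "$dfv1$" and bare bcrypt hashes are both handled as v1


long_a = "x" * 80 + "A"
long_b = "x" * 80 + "B"

# The first 72 bytes are identical, so a hash that only reads 72 bytes cannot tell them apart.
same_prefix = long_a.encode("utf-8")[:BCRYPT_INPUT_LIMIT] == long_b.encode("utf-8")[:BCRYPT_INPUT_LIMIT]

# After the pre-hash the inputs differ and are well under the limit.
digest_a = pre_hash_v2(long_a)
digest_b = pre_hash_v2(long_b)
```

A SHA-256 digest is 32 bytes, which base64-encodes to 44 bytes, so the value passed to bcrypt has a fixed length regardless of the password length.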
+10 -2
@@ -145,7 +145,11 @@ async def _authenticate(request: Request) -> AuthContext:
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and sets AuthContext.
"""Decorator that authenticates the request and enforces authentication.
Independently raises HTTP 401 for unauthenticated requests, regardless of
whether ``AuthMiddleware`` is present in the ASGI stack. Sets the resolved
``AuthContext`` on ``request.state.auth`` for downstream handlers.
Must be placed ABOVE other decorators (executes after them).
@@ -158,7 +162,8 @@ def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
...
Raises:
ValueError: If 'request' parameter is missing
HTTPException: 401 if the request is unauthenticated.
ValueError: If 'request' parameter is missing.
"""
@functools.wraps(func)
@@ -181,6 +186,9 @@ def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
auth_context = await _authenticate(request)
request.state.auth = auth_context
if not auth_context.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
return await func(*args, **kwargs)
return wrapper
+2
@@ -9,6 +9,7 @@ class GatewayConfig(BaseModel):
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
port: int = Field(default=8001, description="Port to bind the gateway server")
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
enable_docs: bool = Field(default=True, description="Enable Swagger/ReDoc/OpenAPI endpoints")
_gateway_config: GatewayConfig | None = None
@@ -23,5 +24,6 @@ def get_gateway_config() -> GatewayConfig:
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
port=int(os.getenv("GATEWAY_PORT", "8001")),
cors_origins=cors_origins_str.split(","),
enable_docs=os.getenv("GATEWAY_ENABLE_DOCS", "true").lower() == "true",
)
return _gateway_config
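The `enable_docs` flag is read from `GATEWAY_ENABLE_DOCS` here and turned into FastAPI keyword arguments in `create_app()`. A minimal sketch of the combined behaviour is shown below; `docs_kwargs_from_env` is a hypothetical helper and takes the environment as a plain mapping so it can be exercised without touching `os.environ`.

```python
import os
from typing import Mapping, Optional


def docs_kwargs_from_env(environ: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical helper combining the env parsing and the docs_kwargs expression."""
    env = os.environ if environ is None else environ
    # Same comparison as the diff: only the literal "true" (any case) enables docs.
    enabled = env.get("GATEWAY_ENABLE_DOCS", "true").lower() == "true"
    if enabled:
        return {"docs_url": "/docs", "redoc_url": "/redoc", "openapi_url": "/openapi.json"}
    return {"docs_url": None, "redoc_url": None, "openapi_url": None}


default_kwargs = docs_kwargs_from_env({})
disabled_kwargs = docs_kwargs_from_env({"GATEWAY_ENABLE_DOCS": "false"})
numeric_kwargs = docs_kwargs_from_env({"GATEWAY_ENABLE_DOCS": "1"})
```

Because the check is a strict string comparison, values such as `1` or `yes` disable the documentation endpoints rather than enabling them.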
+19 -8
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, TypeVar, cast
from fastapi import FastAPI, HTTPException, Request
from langgraph.types import Checkpointer
from deerflow.config.app_config import AppConfig
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.runtime import RunContext, RunManager, StreamBridge
from deerflow.runtime.events.store.base import RunEventStore
@@ -29,6 +30,14 @@ if TYPE_CHECKING:
T = TypeVar("T")
def get_config(request: Request) -> AppConfig:
"""Return the app-scoped ``AppConfig`` stored on ``app.state``."""
config = getattr(request.app.state, "config", None)
if config is None:
raise HTTPException(status_code=503, detail="Configuration not available")
return config
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons.
@@ -38,22 +47,24 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
yield
"""
from deerflow.config import get_app_config
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
config = getattr(app.state, "config", None)
if config is None:
raise RuntimeError("langgraph_runtime() requires app.state.config to be initialized")
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
# Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend).
config = get_app_config()
await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
app.state.store = await stack.enter_async_context(make_store(config))
# Initialize repositories — one get_session_factory() call for all.
sf = get_session_factory()
@@ -130,14 +141,14 @@ def get_run_context(request: Request) -> RunContext:
Returns a *base* context with infrastructure dependencies.
"""
from deerflow.config import get_app_config
config = get_config(request)
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None),
run_events_config=getattr(config, "run_events", None),
thread_store=get_thread_store(request),
app_config=config,
)
+1 -1
@@ -73,7 +73,7 @@ async def authenticate(request):
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail=f"Token error: {payload.value}",
detail="Invalid token",
)
user = await get_local_provider().get_user(payload.sub)
+36 -2
@@ -146,7 +146,13 @@ def _set_session_cookie(response: Response, token: str, request: Request) -> Non
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
# In-process dict — not shared across workers.
#
# **Limitation**: with multi-worker deployments (e.g., gunicorn -w N), each
# worker maintains its own lockout table, so an attacker effectively gets
# N × _MAX_LOGIN_ATTEMPTS guesses before being locked out everywhere. For
# production multi-worker setups, replace this with a shared store (Redis,
# database-backed counter) to enforce a true per-IP limit.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
@@ -376,9 +382,37 @@ async def get_me(request: Request):
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
_SETUP_STATUS_COOLDOWN: dict[str, float] = {}
_SETUP_STATUS_COOLDOWN_SECONDS = 60
_MAX_TRACKED_SETUP_STATUS_IPS = 10000
@router.get("/setup-status")
async def setup_status():
async def setup_status(request: Request):
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
client_ip = _get_client_ip(request)
now = time.time()
last_check = _SETUP_STATUS_COOLDOWN.get(client_ip, 0)
elapsed = now - last_check
if elapsed < _SETUP_STATUS_COOLDOWN_SECONDS:
retry_after = max(1, int(_SETUP_STATUS_COOLDOWN_SECONDS - elapsed))
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Setup status check is rate limited",
headers={"Retry-After": str(retry_after)},
)
# Evict stale entries when dict grows too large to bound memory usage.
if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
cutoff = now - _SETUP_STATUS_COOLDOWN_SECONDS
stale = [k for k, t in _SETUP_STATUS_COOLDOWN.items() if t < cutoff]
for k in stale:
del _SETUP_STATUS_COOLDOWN[k]
# If still too large after evicting expired entries, remove oldest half.
if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
by_time = sorted(_SETUP_STATUS_COOLDOWN.items(), key=lambda kv: kv[1])
for k, _ in by_time[: len(by_time) // 2]:
del _SETUP_STATUS_COOLDOWN[k]
_SETUP_STATUS_COOLDOWN[client_ip] = now
admin_count = await get_local_provider().count_admin_users()
return {"needs_setup": admin_count == 0}
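The per-IP cooldown and two-stage eviction in `setup_status` can be isolated into a small structure for testing. The sketch below is one way to do that, assuming a hypothetical `CooldownTable` class with an injected clock value; it is not part of the diff and leaves out the HTTP layer.

```python
from typing import Optional


class CooldownTable:
    """Hypothetical per-key cooldown tracker with a bounded number of entries."""

    def __init__(self, cooldown_seconds: float = 60.0, max_entries: int = 10000) -> None:
        self.cooldown_seconds = cooldown_seconds
        self.max_entries = max_entries
        self.last_seen: dict = {}

    def check(self, key: str, now: float) -> Optional[int]:
        """Return None if the call is allowed, else whole seconds for Retry-After."""
        last = self.last_seen.get(key)
        if last is not None:
            elapsed = now - last
            if elapsed < self.cooldown_seconds:
                return max(1, int(self.cooldown_seconds - elapsed))
        if len(self.last_seen) >= self.max_entries:
            self._evict(now)
        self.last_seen[key] = now
        return None

    def _evict(self, now: float) -> None:
        # Stage 1: drop entries whose cooldown has already expired.
        cutoff = now - self.cooldown_seconds
        for k in [k for k, t in self.last_seen.items() if t < cutoff]:
            del self.last_seen[k]
        # Stage 2: if the table is still full, drop the oldest half.
        if len(self.last_seen) >= self.max_entries:
            by_time = sorted(self.last_seen.items(), key=lambda kv: kv[1])
            for k, _ in by_time[: len(by_time) // 2]:
                del self.last_seen[k]


table = CooldownTable(cooldown_seconds=60.0, max_entries=4)
first = table.check("10.0.0.1", now=0.0)    # allowed
second = table.check("10.0.0.1", now=10.0)  # rejected, 50 seconds remaining
```

Like the dict in the diff, this state lives in one process, so each worker in a multi-worker deployment would keep its own table.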
+5 -6
@@ -1,7 +1,8 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"])
@@ -36,7 +37,7 @@ class ModelsListResponse(BaseModel):
summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.",
)
async def list_models() -> ModelsListResponse:
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
"""List all available models from configuration.
Returns model information suitable for frontend display,
@@ -72,7 +73,6 @@ async def list_models() -> ModelsListResponse:
}
```
"""
config = get_app_config()
models = [
ModelResponse(
name=model.name,
@@ -96,7 +96,7 @@ async def list_models() -> ModelsListResponse:
summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.",
)
async def get_model(model_name: str) -> ModelResponse:
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
"""Get a specific model by name.
Args:
@@ -118,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
}
```
"""
config = get_app_config()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+69 -79
@@ -1,30 +1,20 @@
import errno
import json
import logging
import shutil
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
append_history,
atomic_write,
custom_skill_exists,
ensure_custom_skill_is_editable,
get_custom_skill_dir,
get_custom_skill_file,
get_skill_history_file,
read_custom_skill_content,
read_history,
validate_skill_markdown_content,
)
from deerflow.skills import Skill
from deerflow.skills.installer import SkillAlreadyExistsError
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import SKILL_MD_FILE, SkillCategory
logger = logging.getLogger(__name__)
@@ -37,7 +27,7 @@ class SkillResponse(BaseModel):
name: str = Field(..., description="Name of the skill")
description: str = Field(..., description="Description of what the skill does")
license: str | None = Field(None, description="License information")
category: str = Field(..., description="Category of the skill (public or custom)")
category: SkillCategory = Field(..., description="Category of the skill (public or custom)")
enabled: bool = Field(default=True, description="Whether this skill is enabled")
@@ -101,9 +91,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -116,10 +106,10 @@ async def list_skills() -> SkillsListResponse:
summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
)
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
async def install_skill(request: SkillInstallRequest, config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path)
result = await get_or_new_skill_storage(app_config=config).ainstall_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async()
return SkillInstallResponse(**result)
except FileNotFoundError as e:
@@ -136,9 +126,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
skills = [skill for skill in get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False) if skill.category == SkillCategory.CUSTOM]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -146,13 +136,14 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
skill_name = skill_name.replace("\r", "").replace("\n", "")
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == SkillCategory.CUSTOM), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=get_or_new_skill_storage(app_config=config).read_custom_skill(skill_name))
except HTTPException:
raise
except Exception as e:
@@ -161,30 +152,31 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
storage = get_or_new_skill_storage(app_config=config)
storage.ensure_custom_skill_is_editable(skill_name)
storage.validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
prev_content = storage.read_custom_skill(skill_name)
storage.write_custom_skill(skill_name, SKILL_MD_FILE, request.content)
storage.append_history(
skill_name,
{
"action": "human_edit",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"file_path": SKILL_MD_FILE,
"prev_content": prev_content,
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
return await get_custom_skill(skill_name, config)
except HTTPException:
raise
except FileNotFoundError as e:
@@ -197,29 +189,22 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
try:
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
shutil.rmtree(skill_dir)
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
storage = get_or_new_skill_storage(app_config=config)
storage.delete_custom_skill(
skill_name,
history_meta={
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": SKILL_MD_FILE,
"prev_content": None,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
await refresh_skills_system_prompt_cache_async()
return {"success": True}
except FileNotFoundError as e:
@@ -232,11 +217,13 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
storage = get_or_new_skill_storage(app_config=config)
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
return CustomSkillHistoryResponse(history=storage.read_history(skill_name))
except HTTPException:
raise
except Exception as e:
@@ -245,38 +232,39 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
storage = get_or_new_skill_storage(app_config=config)
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
history = storage.read_history(skill_name)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
target_content = record.get("prev_content")
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
storage.validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
skill_file = storage.get_custom_skill_file(skill_name)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"file_path": SKILL_MD_FILE,
"prev_content": current_content,
"new_content": target_content,
"rollback_from_ts": record.get("ts"),
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
storage.append_history(skill_name, history_entry)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
storage.write_custom_skill(skill_name, SKILL_MD_FILE, target_content)
storage.append_history(skill_name, history_entry)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
return await get_custom_skill(skill_name, config)
except HTTPException:
raise
except IndexError:
@@ -296,9 +284,10 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -318,9 +307,10 @@ async def get_skill(skill_name: str) -> SkillResponse:
summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.",
)
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
async def update_skill(skill_name: str, request: SkillUpdateRequest, config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -346,7 +336,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
reload_extensions_config()
await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False)
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:
+10 -3
@@ -1,11 +1,13 @@
import json
import logging
from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -100,7 +102,12 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
async def generate_suggestions(
thread_id: str,
body: SuggestionsRequest,
request: Request,
config: AppConfig = Depends(get_config),
) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[])
@@ -122,7 +129,7 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=body.model_name, thinking_enabled=False)
model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=config)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
+123 -18
@@ -4,11 +4,12 @@ import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from app.gateway.authz import require_permission
from deerflow.config.app_config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
@@ -29,6 +30,11 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
UPLOAD_CHUNK_SIZE = 8192
DEFAULT_MAX_FILES = 10
DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
class UploadResponse(BaseModel):
"""Response model for file upload."""
@@ -38,6 +44,14 @@ class UploadResponse(BaseModel):
message: str
class UploadLimits(BaseModel):
"""Application-level upload limits exposed to clients."""
max_files: int
max_file_size: int
max_total_size: int
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
"""Ensure uploaded files remain writable when mounted into non-local sandboxes.
@@ -60,23 +74,78 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
uploads_cfg = getattr(app_config, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
def _get_upload_limit(app_config: AppConfig, key: str, default: int, *, legacy_key: str | None = None) -> int:
try:
value = _get_uploads_config_value(app_config, key, None)
if value is None and legacy_key is not None:
value = _get_uploads_config_value(app_config, legacy_key, None)
if value is None:
value = default
limit = int(value)
if limit <= 0:
raise ValueError
return limit
except Exception:
logger.warning("Invalid uploads.%s value; falling back to %d", key, default)
return default
def _get_upload_limits(app_config: AppConfig) -> UploadLimits:
return UploadLimits(
max_files=_get_upload_limit(app_config, "max_files", DEFAULT_MAX_FILES, legacy_key="max_file_count"),
max_file_size=_get_upload_limit(app_config, "max_file_size", DEFAULT_MAX_FILE_SIZE, legacy_key="max_single_file_size"),
max_total_size=_get_upload_limit(app_config, "max_total_size", DEFAULT_MAX_TOTAL_SIZE),
)
def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
for path in reversed(paths):
try:
os.unlink(path)
except FileNotFoundError:
pass
except Exception:
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
async def _write_upload_file_streaming(
file: UploadFile,
file_path: os.PathLike[str] | str,
*,
display_filename: str,
max_single_file_size: int,
max_total_size: int,
total_size: int,
) -> tuple[int, int]:
file_size = 0
with open(file_path, "wb") as output:
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
file_size += len(chunk)
total_size += len(chunk)
if file_size > max_single_file_size:
raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
if total_size > max_total_size:
raise HTTPException(status_code=413, detail="Total upload size too large")
output.write(chunk)
return file_size, total_size
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
@@ -90,17 +159,25 @@ async def upload_files(
thread_id: str,
request: Request,
files: list[UploadFile] = File(...),
config: AppConfig = Depends(get_config),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
limits = _get_upload_limits(config)
if len(files) > limits.max_files:
raise HTTPException(status_code=413, detail=f"Too many files: maximum is {limits.max_files}")
try:
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
written_paths = []
sandbox_sync_targets = []
total_size = 0
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
@@ -108,7 +185,9 @@ async def upload_files(
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
if sandbox is None:
raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
auto_convert_documents = _auto_convert_documents_enabled(config)
for file in files:
if not file.filename:
@@ -121,35 +200,41 @@ async def upload_files(
continue
try:
content = await file.read()
file_path = uploads_dir / safe_filename
file_path.write_bytes(content)
written_paths.append(file_path)
file_size, total_size = await _write_upload_file_streaming(
file,
file_path,
display_filename=safe_filename,
max_single_file_size=limits.max_file_size,
max_total_size=limits.max_total_size,
total_size=total_size,
)
virtual_path = upload_virtual_path(safe_filename)
if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
if sync_to_sandbox:
sandbox_sync_targets.append((file_path, virtual_path))
file_info = {
"filename": safe_filename,
"size": str(len(content)),
"size": str(file_size),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower()
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
written_paths.append(md_path)
md_virtual_path = upload_virtual_path(md_path.name)
if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
if sync_to_sandbox:
sandbox_sync_targets.append((md_path, md_virtual_path))
file_info["markdown_file"] = md_path.name
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
@@ -158,10 +243,19 @@ async def upload_files(
uploaded_files.append(file_info)
except HTTPException as e:
_cleanup_uploaded_paths(written_paths)
raise e
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
_cleanup_uploaded_paths(written_paths)
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
if sync_to_sandbox:
for file_path, virtual_path in sandbox_sync_targets:
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
return UploadResponse(
success=True,
files=uploaded_files,
@@ -169,6 +263,17 @@ async def upload_files(
)
@router.get("/limits", response_model=UploadLimits)
@require_permission("threads", "read", owner_check=True)
async def get_upload_limits(
thread_id: str,
request: Request,
config: AppConfig = Depends(get_config),
) -> UploadLimits:
"""Return upload limits used by the gateway for this thread."""
return _get_upload_limits(config)
@router.get("/list", response_model=dict)
@require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
+40 -18
@@ -98,6 +98,44 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
_DEFAULT_ASSISTANT_ID = "lead_agent"
# Whitelist of run-context keys that the langgraph-compat layer forwards from
# ``body.context`` into the run config. ``config["context"]`` exists in
# LangGraph >=0.6, but these values must be written to both ``configurable``
# (for legacy ``_get_runtime_config`` consumers) and ``context`` because
# LangGraph >=1.1.9 no longer makes ``ToolRuntime.context`` fall back to
# ``configurable`` for consumers like ``setup_agent``.
_CONTEXT_CONFIGURABLE_KEYS: frozenset[str] = frozenset(
{
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
"agent_name",
"is_bootstrap",
}
)
def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, Any] | None) -> None:
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
and ``config['context']`` so they are visible to legacy configurable readers and
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
see issue #2677)."""
if not context:
return
configurable = config.setdefault("configurable", {})
runtime_context = config.setdefault("context", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
if isinstance(configurable, dict):
configurable.setdefault(key, context[key])
if isinstance(runtime_context, dict):
runtime_context.setdefault(key, context[key])
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
@@ -245,27 +283,11 @@ async def start_run(
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Merge DeerFlow-specific context overrides into configurable.
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
context = getattr(body, "context", None)
if context:
_CONTEXT_CONFIGURABLE_KEYS = {
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
"agent_name",
"is_bootstrap",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
merge_run_context_overrides(config, getattr(body, "context", None))
stream_modes = normalize_stream_modes(body.stream_mode)
+16 -24
@@ -34,50 +34,42 @@ _LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
def _logging_level_from_config(name: str) -> int:
"""Map ``config.yaml`` ``log_level`` string to a ``logging`` level constant."""
mapping = logging.getLevelNamesMapping()
return mapping.get((name or "info").strip().upper(), logging.INFO)
def _setup_logging(log_level: int = logging.INFO) -> None:
"""Route logs to ``debug.log`` using *log_level* for the initial root/file setup.
This configures the root logger and the ``debug.log`` file handler so logs do
not print on the interactive console. It is idempotent: any pre-existing
handlers on the root logger (e.g. installed by ``logging.basicConfig`` in
transitively imported modules) are removed so the debug session output only
lands in ``debug.log``.
def _setup_logging(log_level: str) -> None:
"""Send application logs to ``debug.log`` at *log_level*; do not print them on the console.
Idempotent: any pre-existing handlers on the root logger (e.g. installed by
``logging.basicConfig`` in transitively imported modules) are removed so the
debug session output only lands in ``debug.log``.
Note: later config-driven logging adjustments may change named logger
verbosity without raising the root logger or file-handler thresholds set
here, so the eventual contents of ``debug.log`` may not be filtered solely by
this function's ``log_level`` argument.
"""
level = _logging_level_from_config(log_level)
root = logging.root
for h in list(root.handlers):
root.removeHandler(h)
h.close()
root.setLevel(level)
root.setLevel(log_level)
file_handler = logging.FileHandler("debug.log", mode="a", encoding="utf-8")
file_handler.setLevel(level)
file_handler.setLevel(log_level)
file_handler.setFormatter(logging.Formatter(_LOG_FMT, datefmt=_LOG_DATEFMT))
root.addHandler(file_handler)
def _update_logging_level(log_level: str) -> None:
"""Update the root logger and existing handlers to *log_level*."""
level = _logging_level_from_config(log_level)
root = logging.root
root.setLevel(level)
for handler in root.handlers:
handler.setLevel(level)
async def main():
# Install file logging first so warnings emitted while loading config do not
# leak onto the interactive terminal via Python's lastResort handler.
_setup_logging("info")
_setup_logging()
from deerflow.config import get_app_config
from deerflow.config.app_config import apply_logging_level
app_config = get_app_config()
_update_logging_level(app_config.log_level)
apply_logging_level(app_config.log_level)
# Delay the rest of the deerflow imports until *after* logging is installed
# so that any import-time side effects (e.g. deerflow.agents starts a
+13 -6
@@ -259,6 +259,8 @@ sandbox:
When you configure `sandbox.mounts`, DeerFlow exposes those `container_path` values in the agent prompt so the agent can discover and operate on mounted directories directly instead of assuming everything must live under `/mnt/user-data`.
For bare-metal Docker sandbox runs that use localhost, DeerFlow binds the sandbox HTTP port to `127.0.0.1` by default so it is not exposed on every host interface. Docker-outside-of-Docker deployments that connect through `host.docker.internal` keep the broad legacy bind for compatibility. Set `DEER_FLOW_SANDBOX_BIND_HOST` explicitly if your deployment needs a different bind address.
### Skills
Configure the skills directory for specialized workflows:
@@ -319,11 +321,16 @@ models:
- `DEEPSEEK_API_KEY` - DeepSeek API key
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
- `TAVILY_API_KEY` - Tavily search API key
- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
- `DEER_FLOW_CONFIG_PATH` - Custom config file path
- `DEER_FLOW_EXTENSIONS_CONFIG_PATH` - Custom extensions config file path
- `DEER_FLOW_HOME` - Runtime state directory (defaults to `.deer-flow` under the project root)
- `DEER_FLOW_SKILLS_PATH` - Skills directory when `skills.path` is omitted
- `GATEWAY_ENABLE_DOCS` - Set to `false` to disable Swagger UI (`/docs`), ReDoc (`/redoc`), and OpenAPI schema (`/openapi.json`) endpoints (default: `true`)
## Configuration Location
The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`), not in the backend directory.
The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`). Set `DEER_FLOW_PROJECT_ROOT` when the process may start from another working directory, or set `DEER_FLOW_CONFIG_PATH` to point at a specific file.
## Configuration Priority
@@ -331,12 +338,12 @@ DeerFlow searches for configuration in this order:
1. Path specified in code via `config_path` argument
2. Path from `DEER_FLOW_CONFIG_PATH` environment variable
3. `config.yaml` in current working directory (typically `backend/` when running)
4. `config.yaml` in parent directory (project root: `deer-flow/`)
3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or under the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset
4. Legacy backend/repository-root locations for monorepo compatibility
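The search order above can be sketched roughly as follows. This is a minimal illustration rather than DeerFlow's actual resolver: `resolve_config_path` is a hypothetical name, the sketch assumes that steps 1 and 2 are returned without an existence check, and the legacy fallback locations of step 4 are omitted because they are not listed here.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_config_path(
    config_path: Optional[str] = None,
    env: Optional[Mapping[str, str]] = None,
    cwd: Optional[Path] = None,
) -> Optional[Path]:
    """Hypothetical helper mirroring steps 1-3 of the documented search order."""
    env = os.environ if env is None else env
    cwd = Path.cwd() if cwd is None else cwd

    # 1. An explicit argument wins (assumption: no existence check here).
    if config_path:
        return Path(config_path)

    # 2. DEER_FLOW_CONFIG_PATH points at a specific file.
    if env.get("DEER_FLOW_CONFIG_PATH"):
        return Path(env["DEER_FLOW_CONFIG_PATH"])

    # 3. config.yaml under DEER_FLOW_PROJECT_ROOT, or under the working
    #    directory when that variable is unset.
    project_root = env.get("DEER_FLOW_PROJECT_ROOT")
    root = Path(project_root) if project_root else cwd
    candidate = root / "config.yaml"

    # Step 4 (legacy locations) is intentionally not modelled.
    return candidate if candidate.exists() else None
```

Passing `env` and `cwd` explicitly keeps the sketch easy to test without touching the real process environment.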
## Best Practices
1. **Place `config.yaml` in project root** - Not in `backend/` directory
1. **Place `config.yaml` in project root** - Set `DEER_FLOW_PROJECT_ROOT` if the runtime starts elsewhere
2. **Never commit `config.yaml`** - It's already in `.gitignore`
3. **Use environment variables for secrets** - Don't hardcode API keys
4. **Keep `config.example.yaml` updated** - Document all new options
@@ -347,7 +354,7 @@ DeerFlow searches for configuration in this order:
### "Config file not found"
- Ensure `config.yaml` exists in the **project root** directory (`deer-flow/config.yaml`)
- The backend searches parent directory by default, so root location is preferred
- If the runtime starts outside the project root, set `DEER_FLOW_PROJECT_ROOT`
- Alternatively, set `DEER_FLOW_CONFIG_PATH` environment variable to custom location
### "Invalid API key"
@@ -357,7 +364,7 @@ DeerFlow searches for configuration in this order:
### "Skills not loading"
- Check that `deer-flow/skills/` directory exists
- Verify skills have valid `SKILL.md` files
- Check `skills.path` configuration if using custom path
- Check `skills.path` or `DEER_FLOW_SKILLS_PATH` if using a custom path
### "Docker sandbox fails to start"
- Ensure Docker is running
+20 -2
@@ -22,6 +22,8 @@ POST /api/threads/{thread_id}/uploads
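**Request body:** `multipart/form-data`
- `files`: one or more files
The gateway enforces upload limits at the application layer. By default a request may contain at most 10 files, each file may be at most 50 MiB, and the request total may be at most 100 MiB. Adjust these with `uploads.max_files`, `uploads.max_file_size`, and `uploads.max_total_size` in `config.yaml`. The frontend reads the same limits and warns the user at file-selection time; when a limit is exceeded, the backend returns `413 Payload Too Large`.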
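Size limits like these can be checked while the body is streamed, so an oversized upload is rejected before it is fully written to disk. The following is a simplified, synchronous sketch of that idea using only the standard library. `copy_with_limits` and `UploadTooLarge` are hypothetical names; the gateway itself works asynchronously and responds with HTTP 413 rather than raising this exception.

```python
import io
from typing import BinaryIO

CHUNK_SIZE = 8192


class UploadTooLarge(Exception):
    """Hypothetical stand-in for an HTTP 413 response."""


def copy_with_limits(
    src: BinaryIO,
    dst: BinaryIO,
    *,
    max_file_size: int,
    max_total_size: int,
    total_size: int = 0,
) -> tuple[int, int]:
    """Copy src to dst chunk by chunk and return (file_size, new_total_size)."""
    file_size = 0
    while chunk := src.read(CHUNK_SIZE):
        file_size += len(chunk)
        total_size += len(chunk)
        # Check both limits before the chunk is written.
        if file_size > max_file_size:
            raise UploadTooLarge("File too large")
        if total_size > max_total_size:
            raise UploadTooLarge("Total upload size too large")
        dst.write(chunk)
    return file_size, total_size
```

The running `total_size` is passed in and returned so that one request-wide counter can be carried across several files.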
**Response:**
```json
{
@@ -48,7 +50,23 @@ POST /api/threads/{thread_id}/uploads
- `virtual_path`: the virtual path the agent uses inside the sandbox
- `artifact_url`: the URL the frontend uses to access the file over HTTP
### 2. List Uploaded Files
### 2. Get Upload Limits
```
GET /api/threads/{thread_id}/uploads/limits
```
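Returns the upload limits currently in effect on the gateway, so the frontend can warn the user and block an invalid selection at file-selection time.
**Response:**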
```json
{
"max_files": 10,
"max_file_size": 52428800,
"max_total_size": 104857600
}
```
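A client could use this response to reject a selection before sending anything. The sketch below assumes the JSON has already been fetched and parsed; `check_selection` and its messages are hypothetical and are not part of the gateway or the frontend.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class UploadLimits:
    max_files: int
    max_file_size: int
    max_total_size: int


def check_selection(sizes: list[int], limits: UploadLimits) -> Optional[str]:
    """Return an error message for an invalid selection, or None if it fits."""
    if len(sizes) > limits.max_files:
        return f"Too many files: maximum is {limits.max_files}"
    if any(size > limits.max_file_size for size in sizes):
        return "File too large"
    if sum(sizes) > limits.max_total_size:
        return "Total upload size too large"
    return None


# Limits as returned in the example response above.
limits = UploadLimits(max_files=10, max_file_size=52428800, max_total_size=104857600)
```

A client-side check of this kind is only a convenience for the user; the server-side limits still apply.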
### 3. List Uploaded Files
```
GET /api/threads/{thread_id}/uploads/list
```
@@ -71,7 +89,7 @@ GET /api/threads/{thread_id}/uploads/list
}
```
### 3. Delete File
### 4. Delete File
```
DELETE /api/threads/{thread_id}/uploads/{filename}
```
-343
@@ -1,343 +0,0 @@
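# DeerFlow Backend Split Design Document: Harness + App
> Status: Draft
> Author: DeerFlow Team
> Date: 2026-03-13
## 1. Background and Motivation
The DeerFlow backend is currently a single Python package (`src.*`) that contains everything from low-level agent orchestration to the user-facing product. As the project has grown, this structure has caused several problems:
- **Hard to reuse**: other products (CLI tools, Slack bots, third-party integrations) that want the agent capabilities must depend on the entire backend, including FastAPI, IM SDKs, and other dependencies they do not need
- **Blurred responsibilities**: agent orchestration logic and user product logic are mixed under the same `src/`, with no clear boundary
- **Dependency bloat**: the LangGraph Server runtime does not need FastAPI/uvicorn/Slack SDK, yet all dependencies currently have to be installed
This document proposes splitting the backend into two parts: **deerflow-harness** (a publishable agent framework package) and **app** (user product code that is not packaged).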
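## 2. Core Concepts
### 2.1 Harness (Framework Layer)
Harness is the framework for building and orchestrating agents. It answers the question **"how do we build and run an agent?"**:
- Agent factory and lifecycle management
- Middleware pipeline
- Tool system (built-in tools + MCP + community tools)
- Sandboxed execution environment
- Subagent delegation
- Memory system
- Skill loading and injection
- Model factory
- Configuration system
**Harness is a publishable Python package** (`deerflow-harness`) that can be installed and used on its own.
**Design principle of Harness**: it is completely unaware of the application layer above it. It neither knows nor cares who is calling it, whether that is a web app, a CLI, a Slack bot, or a unit test.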
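### 2.2 App (Application Layer)
App is the user-facing product code. It answers the question **"how do we present the agent to users?"**:
- Gateway API (FastAPI REST interface)
- IM Channels (Feishu, Slack, and Telegram integrations)
- CRUD management for Custom Agents
- HTTP endpoints for file upload and download
**App is neither packaged nor published.** It is application code internal to the DeerFlow project and is run directly.
**App depends on Harness, but Harness does not depend on App.**
### 2.3 Boundary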
| Module | Layer | Notes |
|------|------|------|
| `config/` | Harness | The configuration system is infrastructure |
| `reflection/` | Harness | Dynamic module loading utilities |
| `utils/` | Harness | General-purpose utility functions |
| `agents/` | Harness | Agent factory, middleware, state, memory |
| `subagents/` | Harness | Subagent delegation system |
| `sandbox/` | Harness | Sandboxed execution environment |
| `tools/` | Harness | Tool registration and discovery |
| `mcp/` | Harness | MCP protocol integration |
| `skills/` | Harness | Skill loading, parsing, and schema definition |
| `models/` | Harness | LLM model factory |
| `community/` | Harness | Community tools (tavily, jina, etc.) |
| `client.py` | Harness | Embedded Python client |
| `gateway/` | App | FastAPI REST API |
| `channels/` | App | IM platform integrations |
**About Custom Agents**: the agent definition format (`config.yaml` + `SOUL.md` schema) is defined by `config/agents_config.py` in the Harness layer, while file storage, CRUD, and discovery are handled by `gateway/routers/agents.py` in the App layer.
## 3. Target Architecture
### 3.1 Directory Structure
```
backend/
├── packages/
│   └── harness/
│       ├── pyproject.toml         # deerflow-harness package definition
│       └── deerflow/              # Python package root (import prefix: deerflow.*)
│           ├── __init__.py
│           ├── config/
│           ├── reflection/
│           ├── utils/
│           ├── agents/
│           │   ├── lead_agent/
│           │   ├── middlewares/
│           │   ├── memory/
│           │   ├── checkpointer/
│           │   └── thread_state.py
│           ├── subagents/
│           ├── sandbox/
│           ├── tools/
│           ├── mcp/
│           ├── skills/
│           ├── models/
│           ├── community/
│           └── client.py
├── app/                           # not packaged (import prefix: app.*)
│   ├── __init__.py
│   ├── gateway/
│   │   ├── __init__.py
│   │   ├── app.py
│   │   ├── config.py
│   │   ├── path_utils.py
│   │   └── routers/
│   └── channels/
│       ├── __init__.py
│       ├── base.py
│       ├── manager.py
│       ├── service.py
│       ├── store.py
│       ├── message_bus.py
│       ├── feishu.py
│       ├── slack.py
│       └── telegram.py
├── pyproject.toml                 # uv workspace root
├── langgraph.json
├── tests/
├── docs/
└── Makefile
```
### 3.2 Import Rules
The two layers use different import prefixes, which makes the responsibility boundary clear at a glance:
```python
# ---------------------------------------------------------------
# Imports within Harness (deerflow.* prefix)
# ---------------------------------------------------------------
from deerflow.agents import make_lead_agent
from deerflow.models import create_chat_model
from deerflow.config import get_app_config
from deerflow.tools import get_available_tools
# ---------------------------------------------------------------
# Imports within App (app.* prefix)
# ---------------------------------------------------------------
from app.gateway.app import app
from app.gateway.routers.uploads import upload_files
from app.channels.service import start_channel_service
# ---------------------------------------------------------------
# App calling Harness (one-way dependency; Harness never imports app)
# ---------------------------------------------------------------
from deerflow.agents import make_lead_agent
from deerflow.models import create_chat_model
from deerflow.skills import load_skills
from deerflow.config.extensions_config import get_extensions_config
```
**Example of App calling Harness: starting an agent in the Gateway**
```python
# app/gateway/routers/chat.py
from deerflow.agents.lead_agent.agent import make_lead_agent
from deerflow.models import create_chat_model
from deerflow.config import get_app_config
async def create_chat_session(thread_id: str, model_name: str):
    config = get_app_config()
    model = create_chat_model(name=model_name)
    agent = make_lead_agent(config=...)
    # ... use the agent to handle user messages
```
**Example of App calling Harness: querying skills in a Channel**
```python
# app/channels/manager.py
from deerflow.skills import load_skills
from deerflow.agents.memory.updater import get_memory_data
def handle_status_command():
    skills = load_skills(enabled_only=True)
    memory = get_memory_data()
    return f"Skills: {len(skills)}, Memory facts: {len(memory.get('facts', []))}"
```
**Forbidden direction**: Harness code must never contain `from app.` or `import app.`.
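### 3.3 Why App Is Not Packaged
| Aspect | Packaged (under packages/) | Not packaged (under backend/app/) |
|------|------------------------|--------------------------|
| Namespace | Requires merging via pkgutil `extend_path`, or a separate prefix | Naturally separate: `app.*` vs `deerflow.*` |
| Publishing need | None; App is project-internal code | No pyproject.toml needed |
| Complexity | Builds, versions, and dependency declarations must be managed for two packages | Runs directly, zero extra configuration |
| How to run | `pip install deerflow-app` | `PYTHONPATH=. uvicorn app.gateway.app:app` |
The only consumer of App is the DeerFlow project itself, so there is no need to publish it independently. It lives under `backend/app/` as an ordinary Python package, and Python finds it through `PYTHONPATH` or an editable install.
### 3.4 Dependencies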
```
┌─────────────────────────────────────┐
│  app/ (not packaged, run directly)  │
│  ├── fastapi, uvicorn               │
│  ├── slack-sdk, lark-oapi, ...      │
│  └── import deerflow.*              │
└──────────────┬──────────────────────┘
               │
               ▼
┌─────────────────────────────────────┐
│  deerflow-harness (publishable)     │
│  ├── langgraph, langchain           │
│  ├── markitdown, pydantic, ...      │
│  └── zero app dependencies          │
└─────────────────────────────────────┘
```
**Dependency classification**:
| Category | Packages |
|------|--------|
| Harness only | agent-sandbox, langchain*, langgraph*, markdownify, markitdown, pydantic, pyyaml, readabilipy, tavily-python, firecrawl-py, tiktoken, ddgs, duckdb, httpx, kubernetes, dotenv |
| App only | fastapi, uvicorn, sse-starlette, python-multipart, lark-oapi, slack-sdk, python-telegram-bot, markdown-to-mrkdwn |
| Shared | langgraph-sdk (used by channels as an HTTP client), pydantic, httpx |
### 3.5 Workspace Configuration
`backend/pyproject.toml` (workspace root):
```toml
[project]
name = "deer-flow"
version = "0.1.0"
requires-python = ">=3.12"
dependencies = ["deerflow-harness"]
[dependency-groups]
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
# App's extra dependencies (fastapi, etc.) are also declared at the workspace root because app is not packaged
app = ["fastapi", "uvicorn", "sse-starlette", "python-multipart"]
channels = ["lark-oapi", "slack-sdk", "python-telegram-bot"]
[tool.uv.workspace]
members = ["packages/harness"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
```
## 4. Current Cross-Layer Dependency Issues
Before the split, two reverse dependencies from harness to app in `client.py` must be resolved:
### 4.1 `_validate_skill_frontmatter`
```python
# client.py: harness imports app-layer code
from src.gateway.routers.skills import _validate_skill_frontmatter
```
**Solution**: Extract the function into `deerflow/skills/validation.py`. It is a pure-logic function (it parses YAML frontmatter and validates fields) with no dependency on FastAPI.
### 4.2 `CONVERTIBLE_EXTENSIONS` + `convert_file_to_markdown`
```python
# client.py: harness imports app-layer code
from src.gateway.routers.uploads import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown
```
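**Solution**: Extract them into `deerflow/utils/file_conversion.py`. They depend only on `markitdown` + `pathlib` and are general-purpose utility functions.
## 5. Infrastructure Changes
### 5.1 LangGraph Server
LangGraph Server only needs the harness package. `langgraph.json` is updated to: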
```json
{
  "dependencies": ["./packages/harness"],
  "graphs": {
    "lead_agent": "deerflow.agents:make_lead_agent"
  },
  "checkpointer": {
    "path": "./packages/harness/deerflow/runtime/checkpointer/async_provider.py:make_checkpointer"
  }
}
```
### 5.2 Gateway API
```bash
# serve.sh / Makefile
# PYTHONPATH includes the backend/ root so that both app.* and deerflow.* can be found
PYTHONPATH=. uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
```
### 5.3 Nginx
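No changes needed (Nginx only does URL routing and does not involve Python module paths).
### 5.4 Docker
Module references in the Dockerfile change from `src.` to `deerflow.` / `app.`, and the `COPY` commands must cover the `packages/` and `app/` directories.
## 6. Implementation Plan
The work is carried out incrementally in 3 PRs: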
### PR 1: Extract shared utility functions (Low Risk)
1. Create `src/skills/validation.py` and extract `_validate_skill_frontmatter` from `gateway/routers/skills.py`
2. Create `src/utils/file_conversion.py` and extract the file conversion logic from `gateway/routers/uploads.py`
3. Update the imports in `client.py`, `gateway/routers/skills.py`, and `gateway/routers/uploads.py`
4. Run the full test suite to confirm there are no regressions
### PR 2: Rename + physical split (High Risk, atomic operation)
1. Create the `packages/harness/` directory and its `pyproject.toml`
2. `git mv` the harness modules from `src/` into `packages/harness/deerflow/`
3. `git mv` the app modules from `src/` into `app/`
4. Globally replace imports:
   - harness modules: `src.*` → `deerflow.*` (all `.py` files, `langgraph.json`, tests, docs)
   - app modules: `src.gateway.*` → `app.gateway.*`, `src.channels.*` → `app.channels.*`
5. Update the workspace root `pyproject.toml`
6. Update `langgraph.json`, `Makefile`, and `Dockerfile`
7. `uv sync` + full test suite + manually verify that the services start
### PR 3: Boundary check + docs (Low Risk)
1. Add a lint rule that checks harness does not import app modules
2. Update `CLAUDE.md` and `README.md`
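
The boundary check from PR 3 could be sketched as below. This is only an illustration under assumptions: the function name `find_app_imports` is hypothetical and not part of the codebase, and the real lint rule may well be implemented differently (for example as a linter configuration). The sketch uses the standard-library `ast` module to list absolute imports of `app` or `app.*` in a piece of harness source.

```python
import ast


def find_app_imports(source: str) -> list[str]:
    """Return the app-layer module names imported by the given source (hypothetical helper)."""
    violations: list[str] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            # `import app.gateway` style
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            # `from app.gateway import x` style; relative imports (level > 0) are ignored
            names = [node.module]
        else:
            continue
        violations.extend(n for n in names if n == "app" or n.startswith("app."))
    return violations
```

A test could call this helper on every `.py` file under `packages/harness/` and fail when the returned list is non-empty. Parsing the source, rather than searching for the text `from app.`, avoids false alarms on comments, docstrings, and modules whose names merely start with `app`.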
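## 7. Risks and Mitigations
| Risk | Impact | Mitigation |
|------|------|----------|
| Global rename hits unintended text | `src` inside strings is replaced by mistake | Match precisely with the regex `\bsrc\.`, then review the diff |
| LangGraph Server cannot find the module | Service fails to start | `dependencies` in `langgraph.json` points to the correct harness package path |
| `PYTHONPATH` missing for App | Import errors when Gateway/Channel starts | Set `PYTHONPATH=.` uniformly in Makefile/Docker |
| `use` fields in `config.yaml` reference old paths | Module resolution fails at runtime | Update the `use` fields in `config.yaml` to `deerflow.*` at the same time |
| Confused `sys.path` in tests | Tests fail | Use an editable install (`uv sync`) so deerflow is importable, and add `app/` to `sys.path` in `conftest.py` |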
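
The rename mitigation in the first row can be illustrated with a small sketch. The function name `rewrite_module_paths` is hypothetical, and the two-pass order (app modules first, then everything else) is an assumption about how the mapping in PR 2 might be applied, not the project's actual migration script.

```python
import re

# App-layer packages move to app.*; every other src.* path becomes deerflow.*.
_APP_PATTERN = re.compile(r"\bsrc\.(gateway|channels)\b")
_HARNESS_PATTERN = re.compile(r"\bsrc\.")


def rewrite_module_paths(text: str) -> str:
    """Rewrite dotted src.* module paths; leave other occurrences of 'src' alone."""
    # Handle app modules first so the generic pattern does not claim them.
    text = _APP_PATTERN.sub(r"app.\1", text)
    return _HARNESS_PATTERN.sub("deerflow.", text)
```

The `\b` word boundary keeps identifiers such as `my_src.value` intact, and requiring the trailing dot leaves filesystem paths such as `src/gateway` untouched. A dotted path that merely contains `.src.` would still match, which is why the table pairs the regex with a manual diff review.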
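## 8. Future Evolution
- **Independent release**: harness can be published to an internal PyPI so that other projects can simply `pip install deerflow-harness`
- **Pluggable apps**: different apps (web, CLI, bot) can each stand alone while depending on the same harness
- **Finer-grained split**: if the modules inside harness keep growing, it can be split further (for example `deerflow-sandbox`, `deerflow-mcp`)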
+13 -7
@@ -23,6 +23,9 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r
# Option A: Set environment variables (recommended)
export OPENAI_API_KEY="your-key-here"
# Optional: pin the project root when running from another directory
export DEER_FLOW_PROJECT_ROOT="/path/to/deer-flow"
# Option B: Edit config.yaml directly
vim config.yaml # or your preferred editor
```
@@ -35,17 +38,20 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r
## Important Notes
- **Location**: `config.yaml` should be in `deer-flow/` (project root), not `deer-flow/backend/`
- **Location**: `config.yaml` should be in `deer-flow/` (project root)
- **Git**: `config.yaml` is automatically ignored by git (contains secrets)
- **Priority**: If both `backend/config.yaml` and `../config.yaml` exist, backend version takes precedence
- **Runtime root**: Set `DEER_FLOW_PROJECT_ROOT` if DeerFlow may start from outside the project root
- **Runtime data**: State defaults to `.deer-flow` under the project root; set `DEER_FLOW_HOME` to move it
- **Skills**: Skills default to `skills/` under the project root; set `DEER_FLOW_SKILLS_PATH` or `skills.path` to move them
## Configuration File Locations
The backend searches for `config.yaml` in this order:
1. `DEER_FLOW_CONFIG_PATH` environment variable (if set)
2. `backend/config.yaml` (current directory when running from backend/)
3. `deer-flow/config.yaml` (parent directory - **recommended location**)
1. Explicit `config_path` argument from code
2. `DEER_FLOW_CONFIG_PATH` environment variable (if set)
3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset
4. Legacy backend/repository-root locations for monorepo compatibility
**Recommended**: Place `config.yaml` in project root (`deer-flow/config.yaml`).
@@ -77,8 +83,8 @@ python -c "from deerflow.config.app_config import AppConfig; print(AppConfig.res
If it can't find the config:
1. Ensure you've copied `config.example.yaml` to `config.yaml`
2. Verify you're in the correct directory
3. Check the file exists: `ls -la ../config.yaml`
2. Verify you're in the project root, or set `DEER_FLOW_PROJECT_ROOT`
3. Check the file exists: `ls -la config.yaml`
### Permission denied
@@ -254,9 +254,11 @@ def _assemble_from_features(
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
chain.append(ViewImageMiddleware())
from deerflow.tools.builtins import view_image_tool
extra_tools.append(view_image_tool)
if feat.sandbox is not False:
from deerflow.tools.builtins import view_image_tool
extra_tools.append(view_image_tool)
# --- [11] Subagent ---
if feat.subagent is not False:
@@ -18,9 +18,7 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -35,9 +33,9 @@ def _get_runtime_config(config: RunnableConfig) -> dict:
return cfg
def _resolve_model_name(requested_model_name: str | None = None) -> str:
def _resolve_model_name(requested_model_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
app_config = app_config or get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -50,9 +48,10 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
resolved_app_config = app_config or get_app_config()
config = resolved_app_config.summarization
if not config.enabled:
return None
@@ -73,9 +72,9 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config)
else:
model = create_chat_model(thinking_enabled=False)
model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
@@ -92,17 +91,13 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
kwargs["summary_prompt"] = config.summary_prompt
hooks: list[BeforeSummarizationHook] = []
if get_memory_config().enabled:
if resolved_app_config.memory.enabled:
hooks.append(memory_flush_hook)
# The logic below relies on two assumptions holding true: this factory is
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup.
try:
skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
except Exception:
logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills"
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
return DeerFlowSummarizationMiddleware(
**kwargs,
@@ -240,7 +235,14 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
def _build_middlewares(
config: RunnableConfig,
model_name: str | None,
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
*,
app_config: AppConfig | None = None,
):
"""Build middleware chain based on runtime configuration.
Args:
@@ -251,10 +253,11 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
Returns:
List of middleware instances.
"""
middlewares = build_lead_runtime_middlewares(lazy_init=True)
resolved_app_config = app_config or get_app_config()
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware()
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None:
middlewares.append(summarization_middleware)
@@ -266,24 +269,23 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled:
if resolved_app_config.token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware
middlewares.append(TitleMiddleware())
middlewares.append(TitleMiddleware(app_config=resolved_app_config))
# Add MemoryMiddleware (after TitleMiddleware)
middlewares.append(MemoryMiddleware(agent_name=agent_name))
middlewares.append(MemoryMiddleware(agent_name=agent_name, memory_config=resolved_app_config.memory))
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
model_config = resolved_app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
if app_config.tool_search.enabled:
if resolved_app_config.tool_search.enabled:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware())
@@ -307,11 +309,19 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
def make_lead_agent(config: RunnableConfig):
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
runtime_config = _get_runtime_config(config)
runtime_app_config = runtime_config.get("app_config")
return _make_lead_agent(config, app_config=runtime_app_config or get_app_config())
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
cfg = _get_runtime_config(config)
resolved_app_config = app_config
thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None)
@@ -327,10 +337,9 @@ def make_lead_agent(config: RunnableConfig):
agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
model_name = _resolve_model_name(requested_model_name or agent_model_name, app_config=resolved_app_config)
app_config = get_app_config()
model_config = app_config.get_model_config(model_name)
model_config = resolved_app_config.get_model_config(model_name)
if model_config is None:
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
@@ -369,20 +378,34 @@ def make_lead_agent(config: RunnableConfig):
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
available_skills=set(["bootstrap"]),
app_config=resolved_app_config,
),
state_schema=ThreadState,
)
# Default lead agent (unchanged behavior)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
tools=get_available_tools(
model_name=model_name,
groups=agent_config.tool_groups if agent_config else None,
subagent_enabled=subagent_enabled,
app_config=resolved_app_config,
),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=agent_name,
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
app_config=resolved_app_config,
),
state_schema=ThreadState,
)
@@ -1,14 +1,20 @@
from __future__ import annotations
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING
from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import Skill, SkillCategory
from deerflow.subagents import get_available_subagent_names
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
@@ -20,7 +26,7 @@ _enabled_skills_refresh_event = threading.Event()
def _load_enabled_skills_sync() -> list[Skill]:
return list(load_skills(enabled_only=True))
return list(get_or_new_skill_storage().load_skills(enabled_only=True))
def _start_enabled_skills_refresh_thread() -> None:
@@ -111,8 +117,21 @@ def _get_enabled_skills():
return []
def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]"
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
"""Return enabled skills using the caller's config source.
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
cache so the skill list and skill paths are resolved from the same config
object. This keeps request-scoped config injection consistent even while the
release branch still supports global fallback paths.
"""
if app_config is None:
return _get_enabled_skills()
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
def _skill_mutability_label(category: SkillCategory | str) -> str:
return "[custom, editable]" if category == SkillCategory.CUSTOM else "[built-in]"
def clear_skills_system_prompt_cache() -> None:
@@ -139,7 +158,7 @@ Skip simple one-off tasks.
"""
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
def _build_available_subagents_description(available_names: list[str], bash_available: bool, *, app_config: AppConfig | None = None) -> str:
"""Dynamically build subagent type descriptions from registry.
Mirrors Codex's pattern where agent_type_description is dynamically generated
@@ -161,7 +180,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
if name in builtin_descriptions:
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
else:
config = get_subagent_config(name)
config = get_subagent_config(name, app_config=app_config)
if config is not None:
desc = config.description.split("\n")[0].strip() # First line only for brevity
lines.append(f"- **{name}**: {desc}")
@@ -169,7 +188,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
return "\n".join(lines)
def _build_subagent_section(max_concurrent: int) -> str:
def _build_subagent_section(max_concurrent: int, *, app_config: AppConfig | None = None) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
Args:
@@ -179,12 +198,12 @@ def _build_subagent_section(max_concurrent: int) -> str:
Formatted subagent section string.
"""
n = max_concurrent
available_names = get_available_subagent_names()
available_names = get_available_subagent_names(app_config=app_config) if app_config is not None else get_available_subagent_names()
bash_available = "bash" in available_names
# Dynamically build subagent type descriptions from registry (aligned with Codex's
# agent_type_description pattern where all registered roles are listed in the tool spec).
available_subagents = _build_available_subagents_description(available_names, bash_available)
available_subagents = _build_available_subagents_description(available_names, bash_available, app_config=app_config)
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
direct_execution_example = (
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
@@ -511,21 +530,28 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
"""
def _get_memory_context(agent_name: str | None = None) -> str:
def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
"""Get memory context for injection into system prompt.
Args:
agent_name: If provided, loads per-agent memory. If None, loads global memory.
app_config: Explicit application config. When provided, memory options
are read from this value instead of the global config singleton.
Returns:
Formatted memory context string wrapped in XML tags, or empty string if disabled.
"""
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
config = get_memory_config()
if app_config is None:
from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
else:
config = app_config.memory
if not config.enabled or not config.injection_enabled:
return ""
@@ -539,8 +565,8 @@ def _get_memory_context(agent_name: str | None = None) -> str:
{memory_content}
</memory>
"""
except Exception as e:
logger.error("Failed to load memory context: %s", e)
except Exception:
logger.exception("Failed to load memory context")
return ""
@@ -576,19 +602,24 @@ You have access to skills that provide optimized workflows for specific tasks. E
</skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills()
skills = _get_enabled_skills_for_config(app_config)
try:
from deerflow.config import get_app_config
if app_config is None:
try:
from deerflow.config import get_app_config
config = get_app_config()
config = get_app_config()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
container_base_path = "/mnt/skills"
skill_evolution_enabled = False
else:
config = app_config
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
container_base_path = "/mnt/skills"
skill_evolution_enabled = False
if not skills and not skill_evolution_enabled:
return ""
@@ -612,7 +643,7 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def get_deferred_tools_prompt_section() -> str:
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists
@@ -621,12 +652,17 @@ def get_deferred_tools_prompt_section() -> str:
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
try:
from deerflow.config import get_app_config
if app_config is None:
try:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
config = get_app_config()
except Exception:
return ""
except Exception:
else:
config = app_config
if not config.tool_search.enabled:
return ""
registry = get_deferred_registry()
@@ -637,15 +673,19 @@ def get_deferred_tools_prompt_section() -> str:
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
def _build_acp_section() -> str:
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured."""
try:
from deerflow.config.acp_config import get_acp_agents
if app_config is None:
try:
from deerflow.config.acp_config import get_acp_agents
agents = get_acp_agents()
if not agents:
agents = get_acp_agents()
except Exception:
return ""
except Exception:
else:
agents = getattr(app_config, "acp_agents", {}) or {}
if not agents:
return ""
return (
@@ -657,15 +697,20 @@ def _build_acp_section() -> str:
)
def _build_custom_mounts_section() -> str:
def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
from deerflow.config import get_app_config
if app_config is None:
try:
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
config = get_app_config()
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
else:
config = app_config
mounts = config.sandbox.mounts or []
if not mounts:
return ""
@@ -679,13 +724,20 @@ def _build_custom_mounts_section() -> str:
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
def apply_prompt_template(
subagent_enabled: bool = False,
max_concurrent_subagents: int = 3,
*,
agent_name: str | None = None,
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name)
memory_context = _get_memory_context(agent_name, app_config=app_config)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
# Add subagent reminder to critical_reminders if enabled
subagent_reminder = (
@@ -706,14 +758,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
)
# Get skills section
skills_section = get_skills_prompt_section(available_skills)
skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
# Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section()
deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
# Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section()
custom_mounts_section = _build_custom_mounts_section()
acp_section = _build_acp_section(app_config=app_config)
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory
@@ -9,7 +9,6 @@ import logging
import math
import re
import uuid
from collections.abc import Awaitable
from typing import Any
from deerflow.agents.memory.prompt import (
@@ -26,6 +25,12 @@ from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
# Thread pool for offloading sync memory updates when called from an async
# context. Unlike the previous asyncio.run() approach, this runs *sync*
# model.invoke() calls — no event loop is created, so the langchain async
# httpx client pool (globally cached via @lru_cache) is never touched and
# cross-loop connection reuse is impossible.
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="memory-updater-sync",
@@ -222,39 +227,6 @@ def _extract_text(content: Any) -> str:
return str(content)
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
"""Run an async memory update from sync code, including nested-loop contexts."""
handed_off = False
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
handed_off = True
return future.result()
handed_off = True
return asyncio.run(coro)
except Exception:
if not handed_off:
close = getattr(coro, "close", None)
if callable(close):
try:
close()
except Exception:
logger.debug(
"Failed to close un-awaited memory update coroutine",
exc_info=True,
)
logger.exception("Failed to run async memory update from sync context")
return False
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
@@ -349,13 +321,14 @@ class MemoryUpdater:
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
user_id: str | None = None,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
current_memory = get_memory_data(agent_name, user_id=user_id)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
@@ -377,6 +350,7 @@ class MemoryUpdater:
response_content: Any,
thread_id: str | None,
agent_name: str | None,
user_id: str | None = None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
@@ -390,7 +364,7 @@ class MemoryUpdater:
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
async def aupdate_memory(
self,
@@ -399,28 +373,63 @@ class MemoryUpdater:
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
"""Update memory asynchronously by delegating to the sync path.
Uses ``asyncio.to_thread`` to run the *sync* ``model.invoke()`` path
in a worker thread so no second event loop is created and the
langchain async httpx client pool (shared with the lead agent) is
never touched. This eliminates the cross-loop connection-reuse bug
described in issue #2615.
"""
return await asyncio.to_thread(
self._do_update_memory_sync,
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
def _do_update_memory_sync(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Pure-sync memory update using ``model.invoke()``.
Uses the *sync* LLM call path so no event loop is created. This
guarantees that the langchain provider's globally cached async
httpx ``AsyncClient`` / connection pool (the one shared with the
lead agent) is never touched — no cross-loop connection reuse is
possible.
"""
try:
prepared = await asyncio.to_thread(
self._prepare_update_prompt,
prepared = self._prepare_update_prompt(
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
if prepared is None:
return False
current_memory, prompt = prepared
model = self._get_model()
response = await model.ainvoke(prompt, config={"run_name": "memory_agent"})
return await asyncio.to_thread(
self._finalize_update,
response = model.invoke(prompt, config={"run_name": "memory_agent"})
return self._finalize_update(
current_memory=current_memory,
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
user_id=user_id,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -438,7 +447,16 @@ class MemoryUpdater:
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Synchronously update memory via the async updater path.
"""Synchronously update memory using the sync LLM path.
Uses ``model.invoke()`` (sync HTTP) which operates on a completely
separate connection pool from the async ``AsyncClient`` shared by
the lead agent. This eliminates the cross-loop connection-reuse
bug described in issue #2615.
When called from within a running event loop (e.g. from a LangGraph
node), the blocking sync call is offloaded to a thread pool so the
caller's loop is not blocked.
Args:
messages: List of conversation messages.
@@ -451,14 +469,34 @@ class MemoryUpdater:
Returns:
True if update was successful, False otherwise.
"""
return _run_async_update_sync(
self.aupdate_memory(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
try:
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(
self._do_update_memory_sync,
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
return future.result()
except Exception:
logger.exception("Failed to offload memory update to executor")
return False
return self._do_update_memory_sync(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
def _apply_updates(
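Reviewer note: the hunks above replace the `asyncio.run()` hand-off with a purely synchronous worker. The pattern can be sketched in isolation as below. This is a minimal illustration, not the project's code: `blocking_work`, `run_sync_or_offload`, `async_entry`, `call_from_loop`, and `_EXECUTOR` are hypothetical stand-ins for `_do_update_memory_sync`, `update_memory`, `aupdate_memory`, a caller running on an event loop, and `_SYNC_MEMORY_UPDATER_EXECUTOR`.

```python
import asyncio
import concurrent.futures

# Hypothetical stand-in for _SYNC_MEMORY_UPDATER_EXECUTOR.
_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
    max_workers=2,
    thread_name_prefix="sketch-sync",
)


def blocking_work(value: int) -> int:
    """Stand-in for the sync model.invoke() path; it never creates an event loop."""
    return value * 2


def run_sync_or_offload(value: int) -> int:
    """Run the work inline, or on the pool when a loop is running in this thread."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: call the sync path directly.
        return blocking_work(value)
    # A loop is running here: hand the sync call to a worker thread.
    return _EXECUTOR.submit(blocking_work, value).result()


async def async_entry(value: int) -> int:
    """Async wrapper in the style of aupdate_memory: delegate via asyncio.to_thread."""
    return await asyncio.to_thread(blocking_work, value)


async def call_from_loop(value: int) -> int:
    """Call the sync entry point while an event loop is running in this thread."""
    return run_sync_or_offload(value)


inline_result = run_sync_or_offload(3)
async_result = asyncio.run(async_entry(4))
offloaded_result = asyncio.run(call_from_loop(5))
```

In this sketch `.result()` waits in the calling thread, so a caller on the event loop still blocks until the worker finishes; only the `asyncio.to_thread` wrapper yields control back to the loop while the work runs.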
@@ -20,7 +20,7 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@@ -70,20 +70,11 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
def __init__(self, *, app_config: AppConfig, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Load Circuit Breaker configs from app config if available, fall back to defaults
try:
app_config = get_app_config()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
# Gracefully fall back to class defaults in test environments
pass
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
# Circuit Breaker state
self._circuit_lock = threading.Lock()
@@ -1,7 +1,7 @@
"""Middleware for memory mechanism."""
import logging
from typing import override
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
@@ -13,6 +13,9 @@ from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
if TYPE_CHECKING:
from deerflow.config.memory_config import MemoryConfig
logger = logging.getLogger(__name__)
@@ -34,14 +37,17 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
state_schema = MemoryMiddlewareState
def __init__(self, agent_name: str | None = None):
def __init__(self, agent_name: str | None = None, *, memory_config: "MemoryConfig | None" = None):
"""Initialize the MemoryMiddleware.
Args:
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
memory_config: Explicit memory config. When omitted, legacy global
config fallback is used.
"""
super().__init__()
self._agent_name = agent_name
self._memory_config = memory_config
@override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
@@ -54,7 +60,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns:
None (no state changes needed from this middleware).
"""
config = get_memory_config()
config = self._memory_config or get_memory_config()
if not config.enabled:
return None
@@ -2,7 +2,7 @@
import logging
import re
from typing import Any, NotRequired, override
from typing import TYPE_CHECKING, Any, NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
@@ -12,6 +12,10 @@ from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
from deerflow.config.title_config import TitleConfig
logger = logging.getLogger(__name__)
@@ -26,6 +30,18 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
state_schema = TitleMiddlewareState
def __init__(self, *, app_config: "AppConfig | None" = None, title_config: "TitleConfig | None" = None):
super().__init__()
self._app_config = app_config
self._title_config = title_config
def _get_title_config(self):
if self._title_config is not None:
return self._title_config
if self._app_config is not None:
return self._app_config.title
return get_title_config()
def _normalize_content(self, content: object) -> str:
if isinstance(content, str):
return content
@@ -47,7 +63,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = get_title_config()
config = self._get_title_config()
if not config.enabled:
return False
@@ -72,7 +88,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config()
config = self._get_title_config()
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@@ -94,14 +110,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
config = self._get_title_config()
title_content = self._normalize_content(content)
title_content = self._strip_think_tags(title_content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
config = self._get_title_config()
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
@@ -135,14 +151,17 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
if not self._should_generate_title(state):
return None
config = get_title_config()
config = self._get_title_config()
prompt, user_msg = self._build_title_prompt(state)
try:
model_kwargs = {"thinking_enabled": False}
if self._app_config is not None:
model_kwargs["app_config"] = self._app_config
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
model = create_chat_model(name=config.model_name, **model_kwargs)
else:
model = create_chat_model(thinking_enabled=False)
model = create_chat_model(**model_kwargs)
response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if title:
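Reviewer note: `_get_title_config` resolves configuration in three steps (explicit `title_config`, then `app_config.title`, then the global getter). A minimal sketch of that precedence follows, assuming simplified dataclasses; `SketchTitleConfig`, `SketchAppConfig`, `sketch_global_title_config`, and `resolve_title_config` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class SketchTitleConfig:
    """Hypothetical stand-in for TitleConfig."""
    enabled: bool = True
    max_chars: int = 60


@dataclass
class SketchAppConfig:
    """Hypothetical stand-in for AppConfig; only the title field is modeled."""
    title: SketchTitleConfig = field(default_factory=SketchTitleConfig)


_GLOBAL_TITLE_CONFIG = SketchTitleConfig(max_chars=60)


def sketch_global_title_config() -> SketchTitleConfig:
    """Stand-in for the ambient get_title_config() fallback."""
    return _GLOBAL_TITLE_CONFIG


def resolve_title_config(
    title_config: Optional[SketchTitleConfig] = None,
    app_config: Optional[SketchAppConfig] = None,
) -> SketchTitleConfig:
    """Explicit title config wins, then the app config's title, then the global."""
    if title_config is not None:
        return title_config
    if app_config is not None:
        return app_config.title
    return sketch_global_title_config()


explicit = SketchTitleConfig(max_chars=10)
app = SketchAppConfig(title=SketchTitleConfig(max_chars=30))

from_explicit = resolve_title_config(title_config=explicit, app_config=app)
from_app = resolve_title_config(app_config=app)
from_global = resolve_title_config()
```

Checking `is not None` rather than truthiness keeps an explicitly passed config from being skipped by accident, which matches the checks in the hunk above.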
@@ -11,6 +11,8 @@ from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
@@ -67,6 +69,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_runtime_middlewares(
*,
app_config: AppConfig,
include_uploads: bool,
include_dangling_tool_call_patch: bool,
lazy_init: bool = True,
@@ -91,12 +94,10 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
middlewares.append(LLMErrorHandlingMiddleware(app_config=app_config))
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
guardrails_config = get_guardrails_config()
guardrails_config = app_config.guardrails
if guardrails_config.enabled and guardrails_config.provider:
import inspect
@@ -125,19 +126,42 @@ def _build_runtime_middlewares(
return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares(
app_config=app_config,
include_uploads=True,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
def build_subagent_runtime_middlewares(
*,
app_config: AppConfig | None = None,
model_name: str | None = None,
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
if app_config is None:
from deerflow.config import get_app_config
app_config = get_app_config()
middlewares = _build_runtime_middlewares(
app_config=app_config,
include_uploads=False,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
if model_name is None and app_config.models:
model_name = app_config.models[0].name
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
middlewares.append(ViewImageMiddleware())
return middlewares
@@ -41,7 +41,7 @@ from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.skills.installer import install_skill_from_archive
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.uploads.manager import (
claim_unique_filename,
delete_file_safe,
@@ -228,14 +228,21 @@ class DeerFlowClient:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"middleware": _build_middlewares(
config,
model_name=model_name,
agent_name=self._agent_name,
custom_middlewares=self._middlewares,
app_config=self._app_config,
),
"system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
available_skills=self._available_skills,
app_config=self._app_config,
),
"state_schema": ThreadState,
}
@@ -243,7 +250,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
if checkpointer is not None:
kwargs["checkpointer"] = checkpointer
@@ -251,12 +258,15 @@ class DeerFlowClient:
self._agent_config_key = key
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level."""
from deerflow.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
return get_available_tools(
model_name=model_name,
subagent_enabled=subagent_enabled,
app_config=self._app_config,
)
@staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]:
@@ -377,7 +387,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
thread_info_map = {}
@@ -432,7 +442,7 @@ class DeerFlowClient:
if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(app_config=self._app_config)
config = {"configurable": {"thread_id": thread_id}}
checkpoints = []
@@ -752,8 +762,6 @@ class DeerFlowClient:
Dict with "skills" key containing list of skill info dicts,
matching the Gateway API ``SkillsListResponse`` schema.
"""
from deerflow.skills.loader import load_skills
return {
"skills": [
{
@@ -763,7 +771,7 @@ class DeerFlowClient:
"category": s.category,
"enabled": s.enabled,
}
for s in load_skills(enabled_only=enabled_only)
for s in get_or_new_skill_storage().load_skills(enabled_only=enabled_only)
]
}
@@ -872,9 +880,9 @@ class DeerFlowClient:
Returns:
Skill info dict, or None if not found.
"""
from deerflow.skills.loader import load_skills
from deerflow.skills.storage import get_or_new_skill_storage
skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
skill = next((s for s in get_or_new_skill_storage().load_skills(enabled_only=False) if s.name == name), None)
if skill is None:
return None
return {
@@ -899,9 +907,9 @@ class DeerFlowClient:
ValueError: If the skill is not found.
OSError: If the config file cannot be written.
"""
from deerflow.skills.loader import load_skills
from deerflow.skills.storage import get_or_new_skill_storage
skills = load_skills(enabled_only=False)
skills = get_or_new_skill_storage().load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == name), None)
if skill is None:
raise ValueError(f"Skill '{name}' not found")
@@ -924,7 +932,7 @@ class DeerFlowClient:
self._agent_config_key = None
reload_extensions_config()
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
updated = next((s for s in get_or_new_skill_storage().load_skills(enabled_only=False) if s.name == name), None)
if updated is None:
raise RuntimeError(f"Skill '{name}' disappeared after update")
return {
@@ -948,7 +956,7 @@ class DeerFlowClient:
FileNotFoundError: If the file does not exist.
ValueError: If the file is invalid.
"""
return install_skill_from_archive(skill_path)
return get_or_new_skill_storage().install_skill_from_archive(skill_path)
# ------------------------------------------------------------------
# Public API — memory management
@@ -48,6 +48,12 @@ class AioSandbox(Sandbox):
self._home_dir = context.home_dir
return self._home_dir
# Default no_change_timeout for exec_command (seconds). Matches the
# client-level timeout so that long-running commands which produce no
# output are not prematurely terminated by the sandbox's built-in 120 s
# default.
_DEFAULT_NO_CHANGE_TIMEOUT = 600
def execute_command(self, command: str) -> str:
"""Execute a shell command in the sandbox.
@@ -66,13 +72,13 @@ class AioSandbox(Sandbox):
"""
with self._lock:
try:
result = self._client.shell.exec_command(command=command)
result = self._client.shell.exec_command(command=command, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
output = result.data.output if result.data else ""
if output and _ERROR_OBSERVATION_SIGNATURE in output:
logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
fresh_id = str(uuid.uuid4())
result = self._client.shell.exec_command(command=command, id=fresh_id)
result = self._client.shell.exec_command(command=command, id=fresh_id, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
output = result.data.output if result.data else ""
return output if output else "(no output)"
@@ -108,7 +114,7 @@ class AioSandbox(Sandbox):
"""
with self._lock:
try:
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500", no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
output = result.data.output if result.data else ""
if output:
return [line.strip() for line in output.strip().split("\n") if line.strip()]
@@ -9,6 +9,7 @@ from __future__ import annotations
import json
import logging
import os
import shlex
import subprocess
from datetime import datetime
@@ -86,6 +87,88 @@ def _format_container_mount(runtime: str, host_path: str, container_path: str, r
return ["-v", mount_spec]
def _redact_container_command_for_log(cmd: list[str]) -> list[str]:
"""Return a Docker/Container command with environment values redacted."""
redacted: list[str] = []
redact_next_env = False
for arg in cmd:
if redact_next_env:
if "=" in arg:
key = arg.split("=", 1)[0]
redacted.append(f"{key}=<redacted>" if key else "<redacted>")
else:
redacted.append(arg)
redact_next_env = False
continue
if arg in {"-e", "--env"}:
redacted.append(arg)
redact_next_env = True
continue
if arg.startswith("--env="):
value = arg.removeprefix("--env=")
if "=" in value:
key = value.split("=", 1)[0]
redacted.append(f"--env={key}=<redacted>" if key else "--env=<redacted>")
else:
redacted.append(arg)
continue
redacted.append(arg)
return redacted
def _format_container_command_for_log(cmd: list[str]) -> str:
if os.name == "nt":
return subprocess.list2cmdline(cmd)
return shlex.join(cmd)
def _normalize_sandbox_host(host: str) -> str:
return host.strip().lower()
def _is_ipv6_loopback_sandbox_host(host: str) -> bool:
return _normalize_sandbox_host(host) in {"::1", "[::1]"}
def _is_loopback_sandbox_host(host: str) -> bool:
return _normalize_sandbox_host(host) in {"", "localhost", "127.0.0.1", "::1", "[::1]"}
def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str | None = None) -> str:
"""Choose the host interface for legacy Docker ``-p`` sandbox publishing.
Bare-metal/local runs talk to sandboxes through localhost and should not
expose the sandbox HTTP API on every host interface. Docker-outside-of-
Docker deployments commonly use ``host.docker.internal`` from another
container; keep their legacy broad bind unless operators opt into a
narrower bind with ``DEER_FLOW_SANDBOX_BIND_HOST``. When operators choose
an IPv6 loopback sandbox host, bind Docker to IPv6 loopback as well so the
advertised sandbox URL and published socket use the same address family.
"""
explicit_bind = bind_host if bind_host is not None else os.environ.get("DEER_FLOW_SANDBOX_BIND_HOST")
if explicit_bind is not None:
explicit_bind = explicit_bind.strip()
if explicit_bind:
logger.debug("Docker sandbox bind: %s (explicit bind host override)", explicit_bind)
return explicit_bind
host = sandbox_host if sandbox_host is not None else os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
if _is_ipv6_loopback_sandbox_host(host):
logger.debug("Docker sandbox bind: [::1] (IPv6 loopback sandbox host)")
return "[::1]"
if _is_loopback_sandbox_host(host):
logger.debug("Docker sandbox bind: 127.0.0.1 (loopback default)")
return "127.0.0.1"
logger.debug("Docker sandbox bind: 0.0.0.0 (non-loopback sandbox host compatibility)")
return "0.0.0.0"
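Reviewer note: the bind-host decision above can be exercised as a pure function. The sketch below is a simplified illustration that takes both values as arguments instead of reading `DEER_FLOW_SANDBOX_BIND_HOST` and `DEER_FLOW_SANDBOX_HOST` from the environment and omits the debug logging; `sketch_bind_host` and the port `8081` are hypothetical.

```python
from typing import Optional


def sketch_bind_host(sandbox_host: str = "localhost", bind_host: Optional[str] = None) -> str:
    """Pick the host interface for a Docker ``-p`` mapping (simplified sketch)."""
    # A non-blank explicit override always wins.
    if bind_host is not None and bind_host.strip():
        return bind_host.strip()
    host = sandbox_host.strip().lower()
    # IPv6 loopback keeps the published socket in the same address family.
    if host in {"::1", "[::1]"}:
        return "[::1]"
    # IPv4 loopback and the empty default stay on 127.0.0.1.
    if host in {"", "localhost", "127.0.0.1"}:
        return "127.0.0.1"
    # Anything else keeps the legacy broad bind.
    return "0.0.0.0"


local_bind = sketch_bind_host("localhost")
ipv6_bind = sketch_bind_host("[::1]")
dood_bind = sketch_bind_host("host.docker.internal")
override_bind = sketch_bind_host("localhost", bind_host=" 10.0.0.5 ")
blank_override_bind = sketch_bind_host("localhost", bind_host="   ")

# The mapping string is built the same way as in the hunk that follows.
port_mapping = f"{local_bind}:8081:8080"
```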
class LocalContainerBackend(SandboxBackend):
"""Backend that manages sandbox containers locally using Docker or Apple Container.
@@ -424,12 +507,17 @@ class LocalContainerBackend(SandboxBackend):
if self._runtime == "docker":
cmd.extend(["--security-opt", "seccomp=unconfined"])
if self._runtime == "docker":
port_mapping = f"{_resolve_docker_bind_host()}:{port}:8080"
else:
port_mapping = f"{port}:8080"
cmd.extend(
[
"--rm",
"-d",
"-p",
f"{port}:8080",
port_mapping,
"--name",
container_name,
]
@@ -464,7 +552,8 @@ class LocalContainerBackend(SandboxBackend):
cmd.append(self._image)
logger.info(f"Starting container using {self._runtime}: {' '.join(cmd)}")
log_cmd = _format_container_command_for_log(_redact_container_command_for_log(cmd))
logger.info(f"Starting container using {self._runtime}: {log_cmd}")
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
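Reviewer note: the logging change above redacts environment values before the command is formatted. A condensed sketch of the idea is shown below. It is not the function from the diff: `sketch_redact_env` is a hypothetical, shortened variant that handles `-e KEY=VALUE`, `--env KEY=VALUE`, and `--env=KEY=VALUE`, and it does not special-case empty keys.

```python
import shlex
from typing import List


def sketch_redact_env(cmd: List[str]) -> List[str]:
    """Return a copy of cmd with environment values replaced by <redacted>."""
    redacted: List[str] = []
    redact_next = False
    for arg in cmd:
        if redact_next:
            # Value following -e/--env: keep the key, hide the value.
            key, sep, _ = arg.partition("=")
            redacted.append(f"{key}=<redacted>" if sep else arg)
            redact_next = False
        elif arg in {"-e", "--env"}:
            redacted.append(arg)
            redact_next = True
        elif arg.startswith("--env=") and "=" in arg[len("--env="):]:
            key = arg[len("--env="):].split("=", 1)[0]
            redacted.append(f"--env={key}=<redacted>")
        else:
            redacted.append(arg)
    return redacted


cmd = ["docker", "run", "-e", "API_KEY=secret", "--env=TOKEN=abc", "-e", "HOME", "image"]
safe_cmd = sketch_redact_env(cmd)
# shlex.join quotes each argument, so the logged line can be pasted into a POSIX shell.
log_line = shlex.join(safe_cmd)
```

A bare `-e HOME` (pass-through of a host variable) carries no value, so it is left unchanged.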
@@ -8,7 +8,7 @@ import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.database_config import DatabaseConfig
@@ -17,6 +17,7 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.runtime_paths import existing_project_file
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
@@ -46,17 +47,41 @@ class CircuitBreakerConfig(BaseModel):
recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit")
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
def _legacy_config_candidates() -> tuple[Path, ...]:
"""Return source-tree config.yaml locations for monorepo compatibility."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
return (backend_dir / "config.yaml", repo_root / "config.yaml")
def logging_level_from_config(name: str | None) -> int:
"""Map ``config.yaml`` ``log_level`` string to a :mod:`logging` level constant."""
mapping = logging.getLevelNamesMapping()
return mapping.get((name or "info").strip().upper(), logging.INFO)
def apply_logging_level(name: str | None) -> None:
"""Resolve *name* to a logging level and apply it to the ``deerflow``/``app`` logger hierarchies.
Only the ``deerflow`` and ``app`` logger levels are changed so that
third-party library verbosity (e.g. uvicorn, sqlalchemy) is not
affected. Root handler levels are lowered (never raised) so that
messages from the configured loggers can propagate through without
being filtered, while preserving handler thresholds that may be
intentionally restrictive for third-party log output.
"""
level = logging_level_from_config(name)
for logger_name in ("deerflow", "app"):
logging.getLogger(logger_name).setLevel(level)
for handler in logging.root.handlers:
if level < handler.level:
handler.setLevel(level)
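Reviewer note: the "lowered, never raised" handler rule in `apply_logging_level` can be checked on its own. The sketch below applies the same comparison to a local list of handlers instead of `logging.root.handlers`; `sketch_lower_handler_levels` is a hypothetical helper.

```python
import logging
from typing import Iterable


def sketch_lower_handler_levels(handlers: Iterable[logging.Handler], level: int) -> None:
    """Lower any handler threshold that would filter out records at `level`."""
    for handler in handlers:
        if level < handler.level:
            handler.setLevel(level)


verbose_handler = logging.StreamHandler()
verbose_handler.setLevel(logging.DEBUG)   # already lets INFO through

strict_handler = logging.StreamHandler()
strict_handler.setLevel(logging.ERROR)    # would drop INFO records

sketch_lower_handler_levels([verbose_handler, strict_handler], logging.INFO)
```

After the call the strict handler sits at `INFO`, while the verbose handler keeps its lower `DEBUG` threshold.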
class AppConfig(BaseModel):
"""Config for the DeerFlow application"""
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected")
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
sandbox: SandboxConfig = Field(description="Sandbox configuration")
@@ -70,10 +95,11 @@ class AppConfig(BaseModel):
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
agents_api: AgentsApiConfig = Field(default_factory=AgentsApiConfig, description="Custom-agent management API configuration")
acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP-compatible agent configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
@@ -86,7 +112,8 @@ class AppConfig(BaseModel):
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
3. Otherwise, search the caller project root.
4. Finally, search legacy backend/repository-root defaults for monorepo compatibility.
"""
if config_path:
path = Path(config_path)
@@ -99,10 +126,14 @@ class AppConfig(BaseModel):
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
return path
else:
for path in _default_config_candidates():
project_config = existing_project_file(("config.yaml",))
if project_config is not None:
return project_config
for path in _legacy_config_candidates():
if path.exists():
return path
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
raise FileNotFoundError("`config.yaml` file not found in the project root or legacy backend/repository root locations")
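Reviewer note: the new lookup order (argument, environment variable, project root, legacy locations) can be sketched as a pure function. This is an illustration under simplifying assumptions: `sketch_resolve_config_path` is hypothetical, and it receives the environment value, project root, and legacy directories as parameters rather than calling `os.environ`, `existing_project_file`, or `_legacy_config_candidates`.

```python
import tempfile
from pathlib import Path
from typing import Optional, Sequence


def sketch_resolve_config_path(
    explicit: Optional[str],
    env_value: Optional[str],
    project_root: Path,
    legacy_dirs: Sequence[Path],
    name: str = "config.yaml",
) -> Path:
    """Resolve a config file: argument, then env value, then project root, then legacy dirs."""
    for source, value in (("argument", explicit), ("environment variable", env_value)):
        if value:
            path = Path(value)
            if not path.exists():
                raise FileNotFoundError(f"Config file from {source} not found at {path}")
            return path
    for directory in (project_root, *legacy_dirs):
        candidate = directory / name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"`{name}` not found in the project root or legacy locations")


with tempfile.TemporaryDirectory() as tmp:
    project = Path(tmp) / "project"
    legacy = Path(tmp) / "legacy"
    project.mkdir()
    legacy.mkdir()

    (legacy / "config.yaml").write_text("log_level: info\n")
    # Only the legacy file exists, so the legacy location is used.
    legacy_hit = sketch_resolve_config_path(None, None, project, [legacy]) == legacy / "config.yaml"

    (project / "config.yaml").write_text("log_level: debug\n")
    # Once the project root has a file, it takes priority over legacy locations.
    project_hit = sketch_resolve_config_path(None, None, project, [legacy]) == project / "config.yaml"

    # An explicit path that does not exist is an error, not a fall-through.
    try:
        sketch_resolve_config_path(str(Path(tmp) / "missing.yaml"), None, project, [legacy])
        explicit_missing_raises = False
    except FileNotFoundError:
        explicit_missing_raises = True
```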
@classmethod
def from_file(cls, config_path: str | None = None) -> Self:
@@ -292,6 +323,9 @@ class AppConfig(BaseModel):
return next((group for group in self.tool_groups if group.name == name), None)
# Compatibility singleton layer for code paths that have not yet been
# migrated to explicit ``AppConfig`` threading. New composition roots should
# prefer constructing ``AppConfig`` once and passing it down directly.
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
@@ -7,6 +7,8 @@ from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.runtime_paths import existing_project_file
class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
@@ -73,8 +75,8 @@ class ExtensionsConfig(BaseModel):
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
3. Otherwise, check for `extensions_config.json` in the current directory, then in the parent directory.
4. For backward compatibility, also check for `mcp_config.json` if `extensions_config.json` is not found.
3. Otherwise, search the caller project root for `extensions_config.json`, then `mcp_config.json`.
4. For backward compatibility, also search legacy backend/repository-root defaults.
5. If not found, return None (extensions are optional).
Args:
@@ -83,8 +85,9 @@ class ExtensionsConfig(BaseModel):
Resolution order:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
3. Otherwise, search backend/repository-root defaults for
3. Otherwise, search the caller project root for
`extensions_config.json`, then legacy `mcp_config.json`.
4. Finally, search backend/repository-root defaults for monorepo compatibility.
Returns:
Path to the extensions config file if found, otherwise None.
@@ -100,6 +103,10 @@ class ExtensionsConfig(BaseModel):
raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
return path
else:
project_config = existing_project_file(("extensions_config.json", "mcp_config.json"))
if project_config is not None:
return project_config
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
for path in (
@@ -3,6 +3,8 @@ import re
import shutil
from pathlib import Path, PureWindowsPath
from deerflow.config.runtime_paths import runtime_home
# Virtual path prefix seen by agents inside the sandbox
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
@@ -11,9 +13,8 @@ _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
def _default_local_base_dir() -> Path:
"""Return the repo-local DeerFlow state directory without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
return backend_dir / ".deer-flow"
"""Return the caller project's writable DeerFlow state directory."""
return runtime_home()
def _validate_thread_id(thread_id: str) -> str:
@@ -81,7 +82,7 @@ class Paths:
BaseDir resolution (in priority order):
1. Constructor argument `base_dir`
2. DEER_FLOW_HOME environment variable
3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
3. Caller project fallback: `{project_root}/.deer-flow`
"""
def __init__(self, base_dir: str | Path | None = None) -> None:
@@ -0,0 +1,41 @@
"""Runtime path resolution for standalone harness usage."""
import os
from pathlib import Path
def project_root() -> Path:
"""Return the caller project root for runtime-owned files."""
if env_root := os.getenv("DEER_FLOW_PROJECT_ROOT"):
root = Path(env_root).resolve()
if not root.exists():
raise ValueError(f"DEER_FLOW_PROJECT_ROOT is set to '{env_root}', but the resolved path '{root}' does not exist.")
if not root.is_dir():
raise ValueError(f"DEER_FLOW_PROJECT_ROOT is set to '{env_root}', but the resolved path '{root}' is not a directory.")
return root
return Path.cwd().resolve()
def runtime_home() -> Path:
"""Return the writable DeerFlow state directory."""
if env_home := os.getenv("DEER_FLOW_HOME"):
return Path(env_home).resolve()
return project_root() / ".deer-flow"
def resolve_path(value: str | os.PathLike[str], *, base: Path | None = None) -> Path:
"""Resolve absolute paths as-is and relative paths against the project root."""
path = Path(value)
if not path.is_absolute():
path = (base or project_root()) / path
return path.resolve()
def existing_project_file(names: tuple[str, ...]) -> Path | None:
"""Return the first existing named file under the project root."""
root = project_root()
for name in names:
candidate = root / name
if candidate.is_file():
return candidate
return None
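Review note: a minimal sketch of how the lookup order in this new module plays out, assuming the helpers behave as shown in the hunk above. The functions are re-declared locally (with the two `DEER_FLOW_PROJECT_ROOT` checks collapsed into one) so the snippet runs on its own; the temporary directory and the file it creates are illustrative only.

```python
import os
import tempfile
from pathlib import Path


def project_root() -> Path:
    """Simplified mirror of the hunk: env override first, then the cwd."""
    if env_root := os.getenv("DEER_FLOW_PROJECT_ROOT"):
        root = Path(env_root).resolve()
        if not root.is_dir():
            raise ValueError(f"DEER_FLOW_PROJECT_ROOT '{env_root}' is not a directory")
        return root
    return Path.cwd().resolve()


def resolve_path(value, *, base=None) -> Path:
    """Absolute paths pass through; relative ones hang off the project root."""
    path = Path(value)
    if not path.is_absolute():
        path = (base or project_root()) / path
    return path.resolve()


def existing_project_file(names):
    """Return the first name that exists as a file under the project root."""
    root = project_root()
    for name in names:
        candidate = root / name
        if candidate.is_file():
            return candidate
    return None


previous = os.environ.get("DEER_FLOW_PROJECT_ROOT")
with tempfile.TemporaryDirectory() as tmp:
    os.environ["DEER_FLOW_PROJECT_ROOT"] = tmp
    try:
        root = project_root()
        # Only the legacy file exists, so the second name in the tuple wins.
        (root / "mcp_config.json").write_text("{}")
        found = existing_project_file(("extensions_config.json", "mcp_config.json"))
        skills_dir = resolve_path("skills")
    finally:
        # Restore the environment so the sketch has no lasting side effects.
        if previous is None:
            os.environ.pop("DEER_FLOW_PROJECT_ROOT", None)
        else:
            os.environ["DEER_FLOW_PROJECT_ROOT"] = previous
```

Because every default now hangs off `project_root()`, a standalone caller can relocate config, skills, and state by setting a single environment variable instead of depending on where the package is installed.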
@@ -1,19 +1,21 @@
import os
from pathlib import Path
from pydantic import BaseModel, Field
def _default_repo_root() -> Path:
"""Resolve the repo root without relying on the current working directory."""
return Path(__file__).resolve().parents[5]
from deerflow.config.runtime_paths import project_root, resolve_path
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
use: str = Field(
default="deerflow.skills.storage.local_skill_storage:LocalSkillStorage",
description="Class path of the SkillStorage implementation.",
)
path: str | None = Field(
default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
description="Path to skills directory. If not specified, defaults to skills under the caller project root.",
)
container_path: str = Field(
default="/mnt/skills",
@@ -28,17 +30,11 @@ class SkillsConfig(BaseModel):
Path to the skills directory
"""
if self.path:
# Use configured path (can be absolute or relative)
path = Path(self.path)
if not path.is_absolute():
# If relative, resolve from the repo root for deterministic behavior.
path = _default_repo_root() / path
return path.resolve()
else:
# Default: ../skills relative to backend directory
from deerflow.skills.loader import get_skills_root_path
return get_skills_root_path()
# Use configured path (can be absolute or relative to project root)
return resolve_path(self.path)
if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"):
return resolve_path(env_path)
return project_root() / "skills"
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
"""
@@ -3,6 +3,7 @@ import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
@@ -46,7 +47,7 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config.
Args:
@@ -55,7 +56,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns:
A chat model instance.
"""
config = get_app_config()
config = app_config or get_app_config()
if name is None:
name = config.models[0].name
model_config = config.get_model_config(name)
@@ -1,4 +1,5 @@
import ast
import html
import json
import re
import uuid
@@ -36,8 +37,8 @@ def _fix_messages(messages: list) -> list:
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", []):
xml_parts = []
for tool in msg.tool_calls:
args_xml = " ".join(f"<parameter={k}>{json.dumps(v, ensure_ascii=False)}</parameter>" for k, v in tool.get("args", {}).items())
xml_parts.append(f"<tool_call> <function={tool['name']}> {args_xml} </function> </tool_call>")
args_xml = " ".join(f"<parameter={html.escape(str(k), quote=False)}>{html.escape(v if isinstance(v, str) else json.dumps(v, ensure_ascii=False), quote=False)}</parameter>" for k, v in tool.get("args", {}).items())
xml_parts.append(f"<tool_call> <function={html.escape(str(tool['name']), quote=False)}> {args_xml} </function> </tool_call>")
full_text = f"{text}\n" + "\n".join(xml_parts) if text else "\n".join(xml_parts)
fixed.append(AIMessage(content=full_text.strip() or " "))
continue
@@ -80,13 +81,24 @@ def _parse_xml_tool_call_to_dict(content: str) -> tuple[str, list[dict]]:
func_match = re.search(r"<function=([^>]+)>", inner_content)
if not func_match:
continue
function_name = func_match.group(1).strip()
function_name = html.unescape(func_match.group(1).strip())
# Ignore nested tool blocks when extracting parameters for this call.
# Nested `<tool_call>` sections represent separate invocations and
# their `<parameter>` tags must not leak into the current call args.
param_source_parts: list[str] = []
nested_cursor = 0
for nested_start, nested_end, _ in _iter_tool_call_blocks(inner_content):
param_source_parts.append(inner_content[nested_cursor:nested_start])
nested_cursor = nested_end
param_source_parts.append(inner_content[nested_cursor:])
param_source = "".join(param_source_parts)
args = {}
param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
for param_match in param_pattern.finditer(inner_content):
key = param_match.group(1).strip()
raw_value = param_match.group(2).strip()
for param_match in param_pattern.finditer(param_source):
key = html.unescape(param_match.group(1).strip())
raw_value = html.unescape(param_match.group(2).strip())
# Attempt to deserialize string values into native Python types
# to satisfy downstream Pydantic validation.
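Review note: a small round-trip sketch of why the escaping in `_fix_messages` and the unescaping in the parser need to be paired. `render_parameter` and `parse_parameters` are hypothetical stand-ins written for this illustration, not functions from the patch; only the `html.escape(..., quote=False)` / `html.unescape` pairing and the parameter regex are taken from the hunks above.

```python
import html
import json
import re

PARAM_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)


def render_parameter(key, value) -> str:
    # Strings are escaped directly; other values are JSON-encoded first.
    text = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
    return f"<parameter={html.escape(str(key), quote=False)}>{html.escape(text, quote=False)}</parameter>"


def parse_parameters(xml: str) -> dict[str, str]:
    # Unescape after matching so escaped markup can never close a tag early.
    return {
        html.unescape(m.group(1).strip()): html.unescape(m.group(2).strip())
        for m in PARAM_RE.finditer(xml)
    }


# A value that would inject a second parameter if it were left unescaped.
tricky = "</parameter><parameter=path>/etc/passwd"
rendered = render_parameter("query", tricky)
parsed = parse_parameters(rendered)
```

Without the escape step, the same value would parse as two parameters (`query` and `path`), which is the kind of leakage the patch guards against.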
@@ -24,7 +24,7 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
@@ -123,11 +123,11 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
async def make_checkpointer(app_config: AppConfig | None = None) -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer() as checkpointer:
async with make_checkpointer(app_config) as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -138,16 +138,17 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
3. Default InMemorySaver
"""
config = get_app_config()
if app_config is None:
app_config = get_app_config()
# Legacy: standalone checkpointer config takes precedence
if config.checkpointer is not None:
async with _async_checkpointer(config.checkpointer) as saver:
if app_config.checkpointer is not None:
async with _async_checkpointer(app_config.checkpointer) as saver:
yield saver
return
# Unified database config
db_config = getattr(config, "database", None)
db_config = getattr(app_config, "database", None)
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
@@ -98,9 +98,78 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
_explicit_checkpointers: dict[int, Checkpointer] = {}
_explicit_checkpointer_contexts: dict[int, object] = {}
def get_checkpointer() -> Checkpointer:
def _default_in_memory_checkpointer() -> Checkpointer:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
return InMemorySaver()
def _persistent_database_backend(db_config) -> str | None:
backend = getattr(db_config, "backend", None)
if backend in {"sqlite", "postgres"}:
return backend
return None
@contextlib.contextmanager
def _sync_checkpointer_from_database_cm(db_config) -> Iterator[Checkpointer]:
"""Context manager that creates a sync checkpointer from unified DatabaseConfig."""
backend = _persistent_database_backend(db_config)
if backend is None:
yield _default_in_memory_checkpointer()
return
if backend == "sqlite":
try:
from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
yield saver
return
if backend == "postgres":
try:
from langgraph.checkpoint.postgres import PostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
with PostgresSaver.from_conn_string(db_config.postgres_url) as saver:
saver.setup()
logger.info("Checkpointer: using PostgresSaver")
yield saver
return
raise ValueError(f"Unknown database backend: {backend!r}")
def _build_checkpointer_from_app_config(app_config: AppConfig) -> tuple[Checkpointer, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_checkpointer_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
db_config = getattr(app_config, "database", None)
if _persistent_database_backend(db_config) is not None:
ctx = _sync_checkpointer_from_database_cm(db_config)
return ctx.__enter__(), ctx
return _default_in_memory_checkpointer(), None
def get_checkpointer(app_config: AppConfig | None = None) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -111,6 +180,18 @@ def get_checkpointer() -> Checkpointer:
"""
global _checkpointer, _checkpointer_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_checkpointers.get(cache_key)
if cached is not None:
return cached
explicit_checkpointer, explicit_ctx = _build_checkpointer_from_app_config(app_config)
_explicit_checkpointers[cache_key] = explicit_checkpointer
if explicit_ctx is not None:
_explicit_checkpointer_contexts[cache_key] = explicit_ctx
return explicit_checkpointer
if _checkpointer is not None:
return _checkpointer
@@ -121,28 +202,30 @@ def get_checkpointer() -> Checkpointer:
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
global_app_config = _app_config
if config is None and _app_config is None:
if config is None and global_app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
global_app_config = get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
if config is not None:
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
if global_app_config is not None:
_checkpointer, _checkpointer_ctx = _build_checkpointer_from_app_config(global_app_config)
return _checkpointer
_checkpointer = _default_in_memory_checkpointer()
return _checkpointer
@@ -161,6 +244,18 @@ def reset_checkpointer() -> None:
_checkpointer_ctx = None
_checkpointer = None
for cache_key, ctx in list(_explicit_checkpointer_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit checkpointer cleanup", exc_info=True)
finally:
_explicit_checkpointer_contexts.pop(cache_key, None)
_explicit_checkpointers.pop(cache_key, None)
_explicit_checkpointers.clear()
_explicit_checkpointer_contexts.clear()
# ---------------------------------------------------------------------------
# Sync context manager
@@ -168,7 +263,7 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
def checkpointer_context(app_config: AppConfig | None = None) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance
@@ -181,12 +276,16 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
resolved_app_config = app_config or get_app_config()
if resolved_app_config.checkpointer is not None:
with _sync_checkpointer_cm(resolved_app_config.checkpointer) as saver:
yield saver
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
yield saver
db_config = getattr(resolved_app_config, "database", None)
if _persistent_database_backend(db_config) is not None:
with _sync_checkpointer_from_database_cm(db_config) as saver:
yield saver
return
yield _default_in_memory_checkpointer()
@@ -1,6 +1,8 @@
"""Pure functions to convert LangChain message objects to OpenAI Chat Completions format.
Used by RunJournal to build content dicts for event storage.
Utility for translating LangChain message types to OpenAI-compatible dicts.
Not currently wired into RunJournal (which uses message.model_dump() directly),
but available for consumers that need the OpenAI wire format.
"""
from __future__ import annotations
@@ -62,9 +62,6 @@ class RunJournal(BaseCallbackHandler):
self._total_output_tokens = 0
self._total_tokens = 0
self._llm_call_count = 0
self._lead_agent_tokens = 0
self._subagent_tokens = 0
self._middleware_tokens = 0
# Convenience fields
self._last_ai_msg: str | None = None
@@ -76,7 +73,7 @@ class RunJournal(BaseCallbackHandler):
# LLM request/response tracking
self._llm_call_index = 0
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
self._seen_llm_starts: set[str] = set() # langchain run_ids that fired on_chat_model_start
# -- Lifecycle callbacks --
@@ -135,15 +132,20 @@ class RunJournal(BaseCallbackHandler):
rid = str(run_id)
self._llm_start_times[rid] = time.monotonic()
self._llm_call_index += 1
# Mark this run_id as seen so on_llm_end knows not to increment again.
self._cached_prompts[rid] = []
self._seen_llm_starts.add(rid)
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
logger.debug(
"on_chat_model_start %s: tags=%s num_batches=%d message_counts=%s",
run_id,
tags,
len(messages),
[len(batch) for batch in messages],
)
# Capture the first human message sent to any LLM in this run.
if not self._first_human_msg and not messages:
for batch in messages.reversed():
for m in batch.reversed():
if not self._first_human_msg and messages:
for batch in reversed(messages):
for m in reversed(batch):
if isinstance(m, HumanMessage) and m.name != "summary":
caller = self._identify_caller(tags)
self.set_first_human_message(m.text)
@@ -161,9 +163,17 @@ class RunJournal(BaseCallbackHandler):
# Fallback: on_chat_model_start is preferred. This just tracks latency.
self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
def on_llm_end(
self,
response: Any,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
messages: list[AnyMessage] = []
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
logger.debug("on_llm_end %s: tags=%s", run_id, tags)
for generation in response.generations:
for gen in generation:
if hasattr(gen, "message"):
@@ -185,10 +195,11 @@ class RunJournal(BaseCallbackHandler):
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
if rid not in self._seen_llm_starts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
self._seen_llm_starts.add(rid)
# Trace event: llm_response (OpenAI completion format)
self._put(
@@ -223,7 +234,7 @@ class RunJournal(BaseCallbackHandler):
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
"""Handle tool start event, cache tool call ID for later correlation"""
tool_call_id = str(run_id)
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
logger.debug("Tool start for node %s, tool_call_id=%s, tags=%s", run_id, tool_call_id, tags)
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
"""Handle tool end event, append message and clear node data"""
@@ -242,7 +253,7 @@ class RunJournal(BaseCallbackHandler):
else:
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
finally:
logger.info(f"Tool end for node {run_id}")
logger.debug("Tool end for node %s", run_id)
# -- Internal methods --
@@ -307,8 +318,8 @@ class RunJournal(BaseCallbackHandler):
if exc:
logger.warning("Journal flush task failed: %s", exc)
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
_tags = tags or kwargs.get("tags", [])
def _identify_caller(self, tags: list[str] | None) -> str:
_tags = tags or []
for tag in _tags:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag
@@ -365,9 +376,6 @@ class RunJournal(BaseCallbackHandler):
"total_output_tokens": self._total_output_tokens,
"total_tokens": self._total_tokens,
"llm_call_count": self._llm_call_count,
"lead_agent_tokens": self._lead_agent_tokens,
"subagent_tokens": self._subagent_tokens,
"middleware_tokens": self._middleware_tokens,
"message_count": self._msg_count,
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,
@@ -20,11 +20,13 @@ import copy
import inspect
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@@ -37,6 +39,33 @@ logger = logging.getLogger(__name__)
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
def _build_runtime_context(
thread_id: str,
run_id: str,
caller_context: Any | None,
app_config: AppConfig | None = None,
) -> dict[str, Any]:
"""Build the dict that becomes ``ToolRuntime.context`` for the run.
Always includes ``thread_id`` and ``run_id``. Additional keys from the caller's
``config['context']`` (e.g. ``agent_name`` for the bootstrap flow, issue #2677)
are merged in but never override ``thread_id``/``run_id``. The resolved
``AppConfig`` is added by the worker so tools can consume it without ambient
global lookups.
langgraph 1.1+ surfaces this as ``runtime.context`` via the parent runtime stored
under ``config['configurable']['__pregel_runtime']``; see
``langgraph.pregel.main``, where ``parent_runtime.merge(...)`` is invoked.
"""
runtime_ctx: dict[str, Any] = {"thread_id": thread_id, "run_id": run_id}
if isinstance(caller_context, dict):
for key, value in caller_context.items():
runtime_ctx.setdefault(key, value)
if app_config is not None:
runtime_ctx["app_config"] = app_config
return runtime_ctx
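Review note: the merge rule in `_build_runtime_context` can be checked in isolation. The sketch below copies the `setdefault` loop from the hunk and drops the `app_config` branch for brevity; the sample keys and values are made up.

```python
from __future__ import annotations

from typing import Any


def build_runtime_context(thread_id: str, run_id: str, caller_context: Any | None) -> dict[str, Any]:
    # Worker-owned keys are seeded first so setdefault cannot replace them.
    ctx: dict[str, Any] = {"thread_id": thread_id, "run_id": run_id}
    if isinstance(caller_context, dict):
        for key, value in caller_context.items():
            ctx.setdefault(key, value)
    return ctx


merged = build_runtime_context(
    "t-1",
    "r-1",
    {"agent_name": "bootstrap", "thread_id": "spoofed"},
)
ignored = build_runtime_context("t-1", "r-1", "not-a-dict")
```

Caller-supplied keys such as `agent_name` pass through, while a caller-supplied `thread_id` is silently dropped in favour of the worker's value.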
@dataclass(frozen=True)
class RunContext:
"""Infrastructure dependencies for a single agent run.
@@ -51,6 +80,39 @@ class RunContext:
event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None)
app_config: AppConfig | None = field(default=None)
def _install_runtime_context(config: dict, runtime_context: dict[str, Any]) -> None:
existing_context = config.get("context")
if isinstance(existing_context, dict):
existing_context.setdefault("thread_id", runtime_context["thread_id"])
existing_context.setdefault("run_id", runtime_context["run_id"])
if "app_config" in runtime_context:
existing_context["app_config"] = runtime_context["app_config"]
return
config["context"] = dict(runtime_context)
def _compute_agent_factory_supports_app_config(agent_factory: Any) -> bool:
try:
return "app_config" in inspect.signature(agent_factory).parameters
except (TypeError, ValueError):
return False
@lru_cache(maxsize=128)
def _cached_agent_factory_supports_app_config(agent_factory: Any) -> bool:
return _compute_agent_factory_supports_app_config(agent_factory)
def _agent_factory_supports_app_config(agent_factory: Any) -> bool:
try:
return _cached_agent_factory_supports_app_config(agent_factory)
except TypeError:
# Some callable instances are unhashable; fall back to a direct check.
return _compute_agent_factory_supports_app_config(agent_factory)
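Review note: a self-contained sketch of the signature probe with its unhashable-callable fallback. The helper names are shortened and the three factories are hypothetical; the `inspect.signature` check and the `lru_cache` wrapper follow the hunk above.

```python
import inspect
from functools import lru_cache
from typing import Any


def _compute(factory: Any) -> bool:
    try:
        return "app_config" in inspect.signature(factory).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) expose no signature.
        return False


@lru_cache(maxsize=128)
def _cached(factory: Any) -> bool:
    return _compute(factory)


def supports_app_config(factory: Any) -> bool:
    try:
        return _cached(factory)
    except TypeError:
        # lru_cache hashes its arguments; unhashable callables skip the cache.
        return _compute(factory)


def legacy_factory(config):
    return "legacy"


def modern_factory(config, app_config=None):
    return "modern"


class UnhashableFactory:
    __hash__ = None  # makes instances unusable as lru_cache keys

    def __call__(self, config, app_config=None):
        return "unhashable"


results = (
    supports_app_config(legacy_factory),
    supports_app_config(modern_factory),
    supports_app_config(UnhashableFactory()),
)
```

One consequence of matching on the parameter name: a factory that only accepts `**kwargs` is reported as not supporting `app_config`, so it keeps receiving the old call shape.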
async def run_agent(
@@ -146,15 +208,13 @@ async def run_agent(
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config["context"].setdefault("run_id", run_id)
# Inject runtime context so middlewares and tools (via ToolRuntime.context) can
# access thread-level data. langgraph-cli does this automatically; we must do it
# manually here because we drive the graph through ``agent.astream(config=...)``
# without passing the official ``context=`` parameter.
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a LangChain callback handler.
@@ -163,7 +223,10 @@ async def run_agent(
config.setdefault("callbacks", []).append(journal)
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
if ctx.app_config is not None and _agent_factory_supports_app_config(agent_factory):
agent = agent_factory(config=runnable_config, app_config=ctx.app_config)
else:
agent = agent_factory(config=runnable_config)
# 4. Attach checkpointer and store
if checkpointer is not None:
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -86,7 +86,7 @@ async def _async_store(config) -> AsyncIterator[BaseStore]:
@contextlib.asynccontextmanager
async def make_store() -> AsyncIterator[BaseStore]:
async def make_store(app_config: AppConfig | None = None) -> AsyncIterator[BaseStore]:
"""Async context manager that yields a Store whose backend matches the
configured checkpointer.
@@ -94,20 +94,21 @@ async def make_store() -> AsyncIterator[BaseStore]:
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
that both singletons always use the same persistence technology::
async with make_store() as store:
async with make_store(app_config) as store:
app.state.store = store
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case).
"""
config = get_app_config()
if app_config is None:
app_config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
async with _async_store(config.checkpointer) as store:
async with _async_store(app_config.checkpointer) as store:
yield store
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -98,9 +98,26 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
_store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
_explicit_stores: dict[int, BaseStore] = {}
_explicit_store_contexts: dict[int, object] = {}
def get_store() -> BaseStore:
def _default_in_memory_store() -> BaseStore:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
return InMemoryStore()
def _build_store_from_app_config(app_config: AppConfig) -> tuple[BaseStore, object | None]:
if app_config.checkpointer is not None:
ctx = _sync_store_cm(app_config.checkpointer)
return ctx.__enter__(), ctx
return _default_in_memory_store(), None
def get_store(app_config: AppConfig | None = None) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -112,6 +129,18 @@ def get_store() -> BaseStore:
"""
global _store, _store_ctx
if app_config is not None:
cache_key = id(app_config)
cached = _explicit_stores.get(cache_key)
if cached is not None:
return cached
explicit_store, explicit_ctx = _build_store_from_app_config(app_config)
_explicit_stores[cache_key] = explicit_store
if explicit_ctx is not None:
_explicit_store_contexts[cache_key] = explicit_ctx
return explicit_store
if _store is not None:
return _store
@@ -130,10 +159,7 @@ def get_store() -> BaseStore:
config = get_checkpointer_config()
if config is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
_store = InMemoryStore()
_store = _default_in_memory_store()
return _store
_store_ctx = _sync_store_cm(config)
@@ -156,6 +182,18 @@ def reset_store() -> None:
_store_ctx = None
_store = None
for cache_key, ctx in list(_explicit_store_contexts.items()):
try:
ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during explicit store cleanup", exc_info=True)
finally:
_explicit_store_contexts.pop(cache_key, None)
_explicit_stores.pop(cache_key, None)
_explicit_stores.clear()
_explicit_store_contexts.clear()
# ---------------------------------------------------------------------------
# Sync context manager
@@ -163,7 +201,7 @@ def reset_store() -> None:
@contextlib.contextmanager
def store_context() -> Iterator[BaseStore]:
def store_context(app_config: AppConfig | None = None) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance each
@@ -176,13 +214,10 @@ def store_context() -> Iterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
resolved_app_config = app_config or get_app_config()
if resolved_app_config.checkpointer is None:
yield _default_in_memory_store()
return
with _sync_store_cm(config.checkpointer) as store:
with _sync_store_cm(resolved_app_config.checkpointer) as store:
yield store
@@ -17,6 +17,7 @@ import contextlib
import logging
from collections.abc import AsyncIterator
from deerflow.config.app_config import AppConfig
from deerflow.config.stream_bridge_config import get_stream_bridge_config
from .base import StreamBridge
@@ -25,14 +26,16 @@ logger = logging.getLogger(__name__)
@contextlib.asynccontextmanager
async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
async def make_stream_bridge(app_config: AppConfig | None = None) -> AsyncIterator[StreamBridge]:
"""Async context manager that yields a :class:`StreamBridge`.
Falls back to :class:`MemoryStreamBridge` when no configuration is
provided and nothing is set globally.
"""
if config is None:
if app_config is None:
config = get_stream_bridge_config()
else:
config = app_config.stream_bridge
if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
@@ -22,6 +22,13 @@ def list_dir(path: str, max_depth: int = 2) -> list[str]:
if not root_path.is_dir():
return result
def _is_within_root(candidate: Path) -> bool:
try:
candidate.relative_to(root_path)
return True
except ValueError:
return False
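Review note: a minimal sketch of the containment check, assuming (as in the hunk) that candidates are resolved before being compared against the root. The helper takes the root as an explicit argument here so it can run outside `list_dir`; the temporary files are illustrative.

```python
import tempfile
from pathlib import Path


def is_within_root(candidate: Path, root: Path) -> bool:
    # relative_to raises ValueError when candidate is not under root.
    try:
        candidate.relative_to(root)
        return True
    except ValueError:
        return False


with tempfile.TemporaryDirectory() as tmp:
    root = (Path(tmp) / "root").resolve()
    root.mkdir()
    inside = root / "notes.txt"
    inside.write_text("ok")
    outside = Path(tmp).resolve() / "outside.txt"
    outside.write_text("secret")

    inside_ok = is_within_root(inside.resolve(), root)
    outside_ok = is_within_root(outside.resolve(), root)
    # A path that climbs out with ".." is rejected once it is resolved.
    dotdot_ok = is_within_root((root / ".." / "outside.txt").resolve(), root)
```

Resolving first matters: `relative_to` is purely lexical, so an unresolved `root/../outside.txt` would still appear to sit under `root`.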
def _traverse(current_path: Path, current_depth: int) -> None:
"""Recursively traverse directories up to max_depth."""
if current_depth > max_depth:
@@ -32,8 +39,23 @@ def list_dir(path: str, max_depth: int = 2) -> list[str]:
if should_ignore_name(item.name):
continue
if item.is_symlink():
try:
item_resolved = item.resolve()
if not _is_within_root(item_resolved):
continue
except OSError:
continue
post_fix = "/" if item_resolved.is_dir() else ""
result.append(str(item_resolved) + post_fix)
continue
item_resolved = item.resolve()
if not _is_within_root(item_resolved):
continue
post_fix = "/" if item.is_dir() else ""
result.append(str(item.resolve()) + post_fix)
result.append(str(item_resolved) + post_fix)
# Recurse into subdirectories if not at max depth
if item.is_dir() and current_depth < max_depth:
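Reviewer note: the symlink handling in this hunk reduces to a resolve-then-containment check. The sketch below shows that idea in isolation. It is a minimal illustration, not the code from the diff: `is_within_root` is a hypothetical standalone helper that takes the root as a parameter instead of closing over `root_path`.

```python
from pathlib import Path


def is_within_root(candidate: Path, root: Path) -> bool:
    """Return True when candidate, after resolving symlinks and '..', is root or lies under it."""
    try:
        # relative_to raises ValueError when the resolved path is outside root.
        candidate.resolve().relative_to(root.resolve())
        return True
    except ValueError:
        return False
```

Resolving both sides before comparing matters: a purely textual prefix check would accept a symlink that lives under the root but points elsewhere.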
@@ -5,6 +5,7 @@ import shutil
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple
from deerflow.sandbox.local.list_dir import list_dir
from deerflow.sandbox.sandbox import Sandbox
@@ -20,6 +21,11 @@ class PathMapping:
read_only: bool = False
class ResolvedPath(NamedTuple):
path: str
mapping: PathMapping | None
class LocalSandbox(Sandbox):
@staticmethod
def _shell_name(shell: str) -> str:
@@ -91,7 +97,23 @@ class LocalSandbox(Sandbox):
return best_mapping.read_only
def _resolve_path(self, path: str) -> str:
def _find_path_mapping(self, path: str) -> tuple[PathMapping, str] | None:
path_str = str(path)
for mapping in sorted(self.path_mappings, key=lambda m: len(m.container_path.rstrip("/") or "/"), reverse=True):
container_path = mapping.container_path.rstrip("/") or "/"
if container_path == "/":
if path_str.startswith("/"):
return mapping, path_str.lstrip("/")
continue
if path_str == container_path or path_str.startswith(container_path + "/"):
relative = path_str[len(container_path) :].lstrip("/")
return mapping, relative
return None
def _resolve_path_with_mapping(self, path: str) -> ResolvedPath:
"""
Resolve container path to actual local path using mappings.
@@ -99,22 +121,30 @@ class LocalSandbox(Sandbox):
path: Path that might be a container path
Returns:
Resolved local path
Resolved local path and the matched mapping, if any
"""
path_str = str(path)
# Try each mapping (longest prefix first for more specific matches)
for mapping in sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True):
container_path = mapping.container_path
local_path = mapping.local_path
if path_str == container_path or path_str.startswith(container_path + "/"):
# Replace the container path prefix with local path
relative = path_str[len(container_path) :].lstrip("/")
resolved = str(Path(local_path) / relative) if relative else local_path
return resolved
mapping_match = self._find_path_mapping(path_str)
if mapping_match is None:
return ResolvedPath(path_str, None)
# No mapping found, return original path
return path_str
mapping, relative = mapping_match
local_root = Path(mapping.local_path).resolve()
resolved_path = (local_root / relative).resolve() if relative else local_root
try:
resolved_path.relative_to(local_root)
except ValueError as exc:
raise PermissionError(errno.EACCES, "Access denied: path escapes mounted directory", path_str) from exc
return ResolvedPath(str(resolved_path), mapping)
def _resolve_path(self, path: str) -> str:
return self._resolve_path_with_mapping(path).path
def _is_resolved_path_read_only(self, resolved: ResolvedPath) -> bool:
return bool(resolved.mapping and resolved.mapping.read_only) or self._is_read_only_path(resolved.path)
def _reverse_resolve_path(self, path: str) -> str:
"""
@@ -309,8 +339,14 @@ class LocalSandbox(Sandbox):
def list_dir(self, path: str, max_depth=2) -> list[str]:
resolved_path = self._resolve_path(path)
entries = list_dir(resolved_path, max_depth)
# Reverse resolve local paths back to container paths in output
return [self._reverse_resolve_paths_in_output(entry) for entry in entries]
# Reverse resolve local paths back to container paths and preserve
# list_dir's trailing "/" marker for directories.
result: list[str] = []
for entry in entries:
is_dir = entry.endswith(("/", "\\"))
reversed_entry = self._reverse_resolve_path(entry.rstrip("/\\")) if is_dir else self._reverse_resolve_path(entry)
result.append(f"{reversed_entry}/" if is_dir and not reversed_entry.endswith("/") else reversed_entry)
return result
def read_file(self, path: str) -> str:
resolved_path = self._resolve_path(path)
@@ -329,8 +365,9 @@ class LocalSandbox(Sandbox):
raise type(e)(e.errno, e.strerror, path) from None
def write_file(self, path: str, content: str, append: bool = False) -> None:
resolved_path = self._resolve_path(path)
if self._is_read_only_path(resolved_path):
resolved = self._resolve_path_with_mapping(path)
resolved_path = resolved.path
if self._is_resolved_path_read_only(resolved):
raise OSError(errno.EROFS, "Read-only file system", path)
try:
dir_path = os.path.dirname(resolved_path)
@@ -384,8 +421,9 @@ class LocalSandbox(Sandbox):
], truncated
def update_file(self, path: str, content: bytes) -> None:
resolved_path = self._resolve_path(path)
if self._is_read_only_path(resolved_path):
resolved = self._resolve_path_with_mapping(path)
resolved_path = resolved.path
if self._is_resolved_path_read_only(resolved):
raise OSError(errno.EROFS, "Read-only file system", path)
try:
dir_path = os.path.dirname(resolved_path)
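Reviewer note: the mapping change in this file combines prefix matching with an escape check on the resolved local path. The following is a simplified sketch of that flow for a single mount, under stated assumptions: `resolve_under_mount` and its parameters are hypothetical names, and the real code iterates over `self.path_mappings` and returns a `ResolvedPath` rather than a string.

```python
import errno
from pathlib import Path


def resolve_under_mount(path: str, container_root: str, local_root: str) -> str:
    """Map a container path onto local_root, refusing results that escape it."""
    prefix = container_root.rstrip("/") or "/"
    if prefix == "/":
        relative = path.lstrip("/")
    elif path == prefix or path.startswith(prefix + "/"):
        relative = path[len(prefix):].lstrip("/")
    else:
        raise ValueError(f"{path!r} is not under {container_root!r}")

    root = Path(local_root).resolve()
    resolved = (root / relative).resolve() if relative else root
    try:
        # A '..' segment or a symlink can move the result outside the mount.
        resolved.relative_to(root)
    except ValueError as exc:
        raise PermissionError(errno.EACCES, "Access denied: path escapes mounted directory", path) from exc
    return str(resolved)
```

Matching on `prefix + "/"` rather than on the bare prefix keeps `/mnt/database` from being treated as a child of `/mnt/data`.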
@@ -22,6 +22,9 @@ from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
_URL_WITH_SCHEME_PATTERN = re.compile(r"^[a-z][a-z0-9+.-]*://", re.IGNORECASE)
_URL_IN_COMMAND_PATTERN = re.compile(r"\b[a-z][a-z0-9+.-]*://[^\s\"'`;&|<>()]+", re.IGNORECASE)
_DOTDOT_PATH_SEGMENT_PATTERN = re.compile(r"(?:^|[/\\=])\.\.(?:$|[/\\])")
_LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
"/bin/",
"/usr/bin/",
@@ -37,6 +40,42 @@ _DEFAULT_GLOB_MAX_RESULTS = 200
_MAX_GLOB_MAX_RESULTS = 1000
_DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500
_LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"}
_LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"}
_LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"}
_LOCAL_BASH_COMMAND_END_KEYWORDS = {"}", "done", "esac", "fi"}
_LOCAL_BASH_ROOT_PATH_COMMANDS = {
"awk",
"cat",
"cp",
"du",
"find",
"grep",
"head",
"less",
"ln",
"ls",
"more",
"mv",
"rm",
"sed",
"tail",
"tar",
}
_SHELL_COMMAND_SEPARATORS = {";", "&&", "||", "|", "|&", "&", "(", ")"}
_SHELL_REDIRECTION_OPERATORS = {
"<",
">",
"<<",
">>",
"<<<",
"<>",
">&",
"<&",
"&>",
"&>>",
">|",
}
def _get_skills_container_path() -> str:
@@ -549,7 +588,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
This function is a security gate: it checks whether *path* may be
accessed and raises on violation. It does **not** resolve the virtual
path to a host path; callers are responsible for resolution via
``_resolve_and_validate_user_data_path`` or ``_resolve_skills_path``.
``resolve_and_validate_user_data_path`` or ``_resolve_skills_path``.
Allowed virtual-path families:
- ``/mnt/user-data/*`` always allowed (read + write)
@@ -636,6 +675,219 @@ def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState
return str(resolved)
def _is_non_file_url_token(token: str) -> bool:
"""Return True for URL tokens that should not be interpreted as paths."""
values = [token]
if "=" in token:
values.append(token.split("=", 1)[1])
for value in values:
match = _URL_WITH_SCHEME_PATTERN.match(value)
if match and not value.lower().startswith("file://"):
return True
return False
def _non_file_url_spans(command: str) -> list[tuple[int, int]]:
spans = []
for match in _URL_IN_COMMAND_PATTERN.finditer(command):
if not match.group().lower().startswith("file://"):
spans.append(match.span())
return spans
def _is_in_spans(position: int, spans: list[tuple[int, int]]) -> bool:
return any(start <= position < end for start, end in spans)
def _has_dotdot_path_segment(token: str) -> bool:
if _is_non_file_url_token(token):
return False
return bool(_DOTDOT_PATH_SEGMENT_PATTERN.search(token))
def _split_shell_tokens(command: str) -> list[str]:
try:
normalized = command.replace("\r\n", "\n").replace("\r", "\n").replace("\n", " ; ")
lexer = shlex.shlex(normalized, posix=True, punctuation_chars=True)
lexer.whitespace_split = True
lexer.commenters = ""
return list(lexer)
except ValueError:
# The shell will reject malformed quoting later; keep validation
# best-effort instead of turning syntax errors into security messages.
return command.split()
def _is_shell_command_separator(token: str) -> bool:
return token in _SHELL_COMMAND_SEPARATORS
def _is_shell_redirection_operator(token: str) -> bool:
return token in _SHELL_REDIRECTION_OPERATORS
def _is_shell_assignment(token: str) -> bool:
name, separator, _ = token.partition("=")
if not separator or not name:
return False
return bool(re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name))
def _is_allowed_local_bash_absolute_path(path: str, allowed_paths: list[str], *, allow_system_paths: bool) -> bool:
# Check for MCP filesystem server allowed paths
if any(path.startswith(allowed_path) or path == allowed_path.rstrip("/") for allowed_path in allowed_paths):
_reject_path_traversal(path)
return True
if path == VIRTUAL_PATH_PREFIX or path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
_reject_path_traversal(path)
return True
# Allow skills container path (resolved by tools.py before passing to sandbox)
if _is_skills_path(path):
_reject_path_traversal(path)
return True
# Allow ACP workspace path (path-traversal check only)
if _is_acp_workspace_path(path):
_reject_path_traversal(path)
return True
# Allow custom mount container paths
if _is_custom_mount_path(path):
_reject_path_traversal(path)
return True
if allow_system_paths and any(path == prefix.rstrip("/") or path.startswith(prefix) for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES):
return True
return False
def _next_cd_target(tokens: list[str], start_index: int) -> tuple[str | None, int]:
index = start_index
while index < len(tokens):
token = tokens[index]
if _is_shell_command_separator(token):
return None, index
if _is_shell_redirection_operator(token):
index += 2
continue
if token == "--":
index += 1
continue
if token in {"-L", "-P", "-e", "-@"}:
index += 1
continue
if token.startswith("-") and token != "-":
index += 1
continue
return token, index + 1
return None, index
def _validate_local_bash_cwd_target(command_name: str, target: str | None, allowed_paths: list[str]) -> None:
if target is None or target == "-":
raise PermissionError(f"Unsafe working directory change in command: {command_name}. Use paths under {VIRTUAL_PATH_PREFIX}")
if target.startswith(("$", "`")):
raise PermissionError(f"Unsafe working directory change in command: {command_name} {target}. Use paths under {VIRTUAL_PATH_PREFIX}")
if target.startswith("~"):
raise PermissionError(f"Unsafe working directory change in command: {command_name} {target}. Use paths under {VIRTUAL_PATH_PREFIX}")
if target.startswith("/"):
_reject_path_traversal(target)
if not _is_allowed_local_bash_absolute_path(target, allowed_paths, allow_system_paths=False):
raise PermissionError(f"Unsafe working directory change in command: {command_name} {target}. Use paths under {VIRTUAL_PATH_PREFIX}")
def _looks_like_unsafe_cwd_target(target: str | None) -> bool:
if target is None:
return False
return target == "-" or target.startswith(("$", "`", "~", "/", "..")) or _has_dotdot_path_segment(target)
def _validate_local_bash_root_path_args(command_name: str, tokens: list[str], start_index: int) -> None:
if command_name not in _LOCAL_BASH_ROOT_PATH_COMMANDS:
return
index = start_index
while index < len(tokens):
token = tokens[index]
if _is_shell_command_separator(token):
return
if _is_shell_redirection_operator(token):
index += 2
continue
if token == "/" and not _is_non_file_url_token(token):
raise PermissionError(f"Unsafe absolute paths in command: /. Use paths under {VIRTUAL_PATH_PREFIX}")
index += 1
def _validate_local_bash_shell_tokens(command: str, allowed_paths: list[str]) -> None:
"""Conservatively reject relative path escapes missed by absolute-path scanning."""
if re.search(r"\$\([^)]*\b(?:cd|pushd)\b", command):
raise PermissionError(f"Unsafe working directory change in command substitution. Use paths under {VIRTUAL_PATH_PREFIX}")
tokens = _split_shell_tokens(command)
for token in tokens:
if _is_shell_command_separator(token) or _is_shell_redirection_operator(token):
continue
if _has_dotdot_path_segment(token):
raise PermissionError("Access denied: path traversal detected")
at_command_start = True
index = 0
while index < len(tokens):
token = tokens[index]
if _is_shell_command_separator(token):
at_command_start = True
index += 1
continue
if _is_shell_redirection_operator(token):
index += 1
continue
if at_command_start and _is_shell_assignment(token):
index += 1
continue
command_name = token.rsplit("/", 1)[-1]
if at_command_start and command_name in _LOCAL_BASH_COMMAND_PREFIX_KEYWORDS | _LOCAL_BASH_COMMAND_END_KEYWORDS:
index += 1
continue
if not at_command_start:
index += 1
continue
at_command_start = False
if command_name in _LOCAL_BASH_COMMAND_WRAPPERS and index + 1 < len(tokens):
wrapped_name = tokens[index + 1].rsplit("/", 1)[-1]
if wrapped_name in _LOCAL_BASH_CWD_COMMANDS:
target, next_index = _next_cd_target(tokens, index + 2)
_validate_local_bash_cwd_target(wrapped_name, target, allowed_paths)
index = next_index
continue
_validate_local_bash_root_path_args(wrapped_name, tokens, index + 2)
if command_name not in _LOCAL_BASH_CWD_COMMANDS:
_validate_local_bash_root_path_args(command_name, tokens, index + 1)
index += 1
continue
target, next_index = _next_cd_target(tokens, index + 1)
_validate_local_bash_cwd_target(command_name, target, allowed_paths)
index = next_index
def resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState) -> str:
"""Resolve a /mnt/user-data virtual path and validate it stays in bounds."""
return _resolve_and_validate_user_data_path(path, thread_data)
def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None) -> None:
"""Validate absolute paths in local-sandbox bash commands.
@@ -661,33 +913,14 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
unsafe_paths: list[str] = []
allowed_paths = _get_mcp_allowed_paths()
_validate_local_bash_shell_tokens(command, allowed_paths)
url_spans = _non_file_url_spans(command)
for absolute_path in _ABSOLUTE_PATH_PATTERN.findall(command):
# Check for MCP filesystem server allowed paths
if any(absolute_path.startswith(path) or absolute_path == path.rstrip("/") for path in allowed_paths):
_reject_path_traversal(absolute_path)
for match in _ABSOLUTE_PATH_PATTERN.finditer(command):
if _is_in_spans(match.start(), url_spans):
continue
if absolute_path == VIRTUAL_PATH_PREFIX or absolute_path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
_reject_path_traversal(absolute_path)
continue
# Allow skills container path (resolved by tools.py before passing to sandbox)
if _is_skills_path(absolute_path):
_reject_path_traversal(absolute_path)
continue
# Allow ACP workspace path (path-traversal check only)
if _is_acp_workspace_path(absolute_path):
_reject_path_traversal(absolute_path)
continue
# Allow custom mount container paths
if _is_custom_mount_path(absolute_path):
_reject_path_traversal(absolute_path)
continue
if any(absolute_path == prefix.rstrip("/") or absolute_path.startswith(prefix) for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES):
absolute_path = match.group()
if _is_allowed_local_bash_absolute_path(absolute_path, allowed_paths, allow_system_paths=True):
continue
unsafe_paths.append(absolute_path)
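Reviewer note: the new token-level validation relies on `shlex` with `punctuation_chars=True` plus the `..` segment pattern. The sketch below isolates those two pieces so the tokenization is easier to reason about. `find_dotdot_tokens` is a hypothetical helper written for illustration; unlike the diff, it does not skip URL tokens and does not fall back to `str.split` on malformed quoting.

```python
import re
import shlex

# Same shape as _DOTDOT_PATH_SEGMENT_PATTERN in the diff: '..' as a whole path segment.
DOTDOT_SEGMENT = re.compile(r"(?:^|[/\\=])\.\.(?:$|[/\\])")


def find_dotdot_tokens(command: str) -> list[str]:
    """Return the shell tokens that contain a '..' path segment."""
    lexer = shlex.shlex(command.replace("\n", " ; "), posix=True, punctuation_chars=True)
    lexer.whitespace_split = True  # split on whitespace and shell punctuation only
    lexer.commenters = ""  # treat '#' as ordinary text instead of a comment start
    return [token for token in lexer if DOTDOT_SEGMENT.search(token)]
```

Because the pattern anchors on a separator or the token boundary, a name such as `a..b` is not flagged, while `--dir=../secret` is.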
@@ -1,14 +1,17 @@
from .installer import SkillAlreadyExistsError, install_skill_from_archive
from .loader import get_skills_root_path, load_skills
from __future__ import annotations
from .installer import SkillAlreadyExistsError, SkillSecurityScanError
from .storage import LocalSkillStorage, SkillStorage, get_or_new_skill_storage
from .types import Skill
from .validation import ALLOWED_FRONTMATTER_PROPERTIES, _validate_skill_frontmatter
__all__ = [
"load_skills",
"get_skills_root_path",
"Skill",
"ALLOWED_FRONTMATTER_PROPERTIES",
"_validate_skill_frontmatter",
"install_skill_from_archive",
"SkillAlreadyExistsError",
"SkillSecurityScanError",
"SkillStorage",
"LocalSkillStorage",
"get_or_new_skill_storage",
]
@@ -4,24 +4,31 @@ Pure business logic — no FastAPI/HTTP dependencies.
Both Gateway and Client delegate to these functions.
"""
import asyncio
import concurrent.futures
import logging
import posixpath
import shutil
import stat
import tempfile
import zipfile
from pathlib import Path, PurePosixPath, PureWindowsPath
from deerflow.skills.loader import get_skills_root_path
from deerflow.skills.validation import _validate_skill_frontmatter
from deerflow.skills.security_scanner import scan_skill_content
logger = logging.getLogger(__name__)
_PROMPT_INPUT_DIRS = {"references", "templates"}
_PROMPT_INPUT_SUFFIXES = frozenset({".json", ".markdown", ".md", ".rst", ".txt", ".yaml", ".yml"})
class SkillAlreadyExistsError(ValueError):
"""Raised when a skill with the same name is already installed."""
class SkillSecurityScanError(ValueError):
"""Raised when a skill archive fails security scanning."""
def is_unsafe_zip_member(info: zipfile.ZipInfo) -> bool:
"""Return True if the zip member path is absolute or attempts directory traversal."""
name = info.filename
@@ -114,70 +121,84 @@ def safe_extract_skill_archive(
dst.write(chunk)
def install_skill_from_archive(
zip_path: str | Path,
*,
skills_root: Path | None = None,
) -> dict:
"""Install a skill from a .skill archive (ZIP).
def _is_script_support_file(rel_path: Path) -> bool:
return bool(rel_path.parts) and rel_path.parts[0] == "scripts"
Args:
zip_path: Path to the .skill file.
skills_root: Override the skills root directory. If None, uses
the default from config.
Returns:
Dict with success, skill_name, message.
def _should_scan_support_file(rel_path: Path) -> bool:
if _is_script_support_file(rel_path):
return True
return bool(rel_path.parts) and rel_path.parts[0] in _PROMPT_INPUT_DIRS and rel_path.suffix.lower() in _PROMPT_INPUT_SUFFIXES
Raises:
FileNotFoundError: If the file does not exist.
ValueError: If the file is invalid (wrong extension, bad ZIP,
invalid frontmatter, duplicate name).
"""
logger.info("Installing skill from %s", zip_path)
path = Path(zip_path)
if not path.is_file():
if not path.exists():
raise FileNotFoundError(f"Skill file not found: {zip_path}")
raise ValueError(f"Path is not a file: {zip_path}")
if path.suffix != ".skill":
raise ValueError("File must have .skill extension")
if skills_root is None:
skills_root = get_skills_root_path()
custom_dir = skills_root / "custom"
custom_dir.mkdir(parents=True, exist_ok=True)
def _move_staged_skill_into_reserved_target(staging_target: Path, target: Path) -> None:
installed = False
reserved = False
try:
target.mkdir(mode=0o700)
reserved = True
for child in staging_target.iterdir():
shutil.move(str(child), target / child.name)
installed = True
except FileExistsError as e:
raise SkillAlreadyExistsError(f"Skill '{target.name}' already exists") from e
finally:
if reserved and not installed and target.exists():
shutil.rmtree(target)
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
try:
zf = zipfile.ZipFile(path, "r")
except FileNotFoundError:
raise FileNotFoundError(f"Skill file not found: {zip_path}") from None
except (zipfile.BadZipFile, IsADirectoryError):
raise ValueError("File is not a valid ZIP archive") from None
async def _scan_skill_file_or_raise(skill_dir: Path, path: Path, skill_name: str, *, executable: bool) -> None:
rel_path = path.relative_to(skill_dir).as_posix()
location = f"{skill_name}/{rel_path}"
try:
content = path.read_text(encoding="utf-8")
except UnicodeDecodeError as e:
raise SkillSecurityScanError(f"Security scan failed for skill '{skill_name}': {location} must be valid UTF-8") from e
with zf:
safe_extract_skill_archive(zf, tmp_path)
try:
result = await scan_skill_content(content, executable=executable, location=location)
except Exception as e:
raise SkillSecurityScanError(f"Security scan failed for {location}: {e}") from e
skill_dir = resolve_skill_dir_from_archive(tmp_path)
decision = getattr(result, "decision", None)
reason = str(getattr(result, "reason", "") or "No reason provided.")
if decision == "block":
if rel_path == "SKILL.md":
raise SkillSecurityScanError(f"Security scan blocked skill '{skill_name}': {reason}")
raise SkillSecurityScanError(f"Security scan blocked {location}: {reason}")
if executable and decision != "allow":
raise SkillSecurityScanError(f"Security scan rejected executable {location}: {reason}")
if decision not in {"allow", "warn"}:
raise SkillSecurityScanError(f"Security scan failed for {location}: invalid scanner decision {decision!r}")
is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir)
if not is_valid:
raise ValueError(f"Invalid skill: {message}")
if not skill_name or "/" in skill_name or "\\" in skill_name or ".." in skill_name:
raise ValueError(f"Invalid skill name: {skill_name}")
target = custom_dir / skill_name
if target.exists():
raise SkillAlreadyExistsError(f"Skill '{skill_name}' already exists")
async def _scan_skill_archive_contents_or_raise(skill_dir: Path, skill_name: str) -> None:
"""Run the skill security scanner against all installable text and script files."""
skill_md = skill_dir / "SKILL.md"
await _scan_skill_file_or_raise(skill_dir, skill_md, skill_name, executable=False)
shutil.copytree(skill_dir, target)
logger.info("Skill %r installed to %s", skill_name, target)
for path in sorted(skill_dir.rglob("*")):
if not path.is_file():
continue
return {
"success": True,
"skill_name": skill_name,
"message": f"Skill '{skill_name}' installed successfully",
}
rel_path = path.relative_to(skill_dir)
if rel_path == Path("SKILL.md"):
continue
if path.name == "SKILL.md":
raise SkillSecurityScanError(f"Security scan failed for skill '{skill_name}': nested SKILL.md is not allowed at {skill_name}/{rel_path.as_posix()}")
if not _should_scan_support_file(rel_path):
continue
await _scan_skill_file_or_raise(skill_dir, path, skill_name, executable=_is_script_support_file(rel_path))
def _run_async_install(coro):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
return executor.submit(asyncio.run, coro).result()
return asyncio.run(coro)
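Reviewer note: `_run_async_install` bridges synchronous callers to the async scanner whether or not an event loop is already running. The pattern can be sketched on its own as follows; `run_coro_blocking` is a hypothetical name, and the sketch assumes the coroutine does not depend on objects bound to the caller's loop.

```python
import asyncio
import concurrent.futures


def run_coro_blocking(coro):
    """Run a coroutine to completion from synchronous code and return its result."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: asyncio.run can own a fresh one.
        return asyncio.run(coro)
    # A loop is already running here, so asyncio.run would raise.
    # Run the coroutine on a new loop in a worker thread and wait for it.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()
```

Note the trade-off: when called from inside a running loop, this blocks that loop's thread until the worker finishes, which is acceptable for a short install step but not for long-running work.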
@@ -1,103 +0,0 @@
import logging
import os
from pathlib import Path
from .parser import parse_skill_file
from .types import Skill
logger = logging.getLogger(__name__)
def get_skills_root_path() -> Path:
"""
Get the root path of the skills directory.
Returns:
Path to the skills directory (deer-flow/skills)
"""
# loader.py lives at packages/harness/deerflow/skills/loader.py — 5 parents up reaches backend/
backend_dir = Path(__file__).resolve().parent.parent.parent.parent.parent
# skills directory is sibling to backend directory
skills_dir = backend_dir.parent / "skills"
return skills_dir
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
"""
Load all skills from the skills directory.
Scans both public and custom skill directories, parsing SKILL.md files
to extract metadata. The enabled state is determined by the skills_state_config.json file.
Args:
skills_path: Optional custom path to skills directory.
If not provided and use_config is True, uses path from config.
Otherwise defaults to deer-flow/skills
use_config: Whether to load skills path from config (default: True)
enabled_only: If True, only return enabled skills (default: False)
Returns:
List of Skill objects, sorted by name
"""
if skills_path is None:
if use_config:
try:
from deerflow.config import get_app_config
config = get_app_config()
skills_path = config.skills.get_skills_path()
except Exception:
# Fallback to default if config fails
skills_path = get_skills_root_path()
else:
skills_path = get_skills_root_path()
if not skills_path.exists():
return []
skills_by_name: dict[str, Skill] = {}
# Scan public and custom directories
for category in ["public", "custom"]:
category_path = skills_path / category
if not category_path.exists() or not category_path.is_dir():
continue
for current_root, dir_names, file_names in os.walk(category_path, followlinks=True):
# Keep traversal deterministic and skip hidden directories.
dir_names[:] = sorted(name for name in dir_names if not name.startswith("."))
if "SKILL.md" not in file_names:
continue
skill_file = Path(current_root) / "SKILL.md"
relative_path = skill_file.parent.relative_to(category_path)
skill = parse_skill_file(skill_file, category=category, relative_path=relative_path)
if skill:
skills_by_name[skill.name] = skill
skills = list(skills_by_name.values())
# Load skills state configuration and update enabled status
# NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
# to always read the latest configuration from disk. This ensures that changes
# made through the Gateway API (which runs in a separate process) are immediately
# reflected in the LangGraph Server when loading skills.
try:
from deerflow.config.extensions_config import ExtensionsConfig
extensions_config = ExtensionsConfig.from_file()
for skill in skills:
skill.enabled = extensions_config.is_skill_enabled(skill.name, skill.category)
except Exception as e:
# If config loading fails, default to all enabled
logger.warning("Failed to load extensions config: %s", e)
# Filter by enabled status if requested
if enabled_only:
skills = [skill for skill in skills if skill.enabled]
# Sort by name for consistent ordering
skills.sort(key=lambda s: s.name)
return skills
@@ -1,159 +0,0 @@
"""Utilities for managing custom skills and their history."""
from __future__ import annotations
import json
import re
import tempfile
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.config import get_app_config
from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter
SKILL_FILE_NAME = "SKILL.md"
HISTORY_FILE_NAME = "HISTORY.jsonl"
HISTORY_DIR_NAME = ".history"
ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path:
return get_app_config().skills.get_skills_path()
def get_public_skills_dir() -> Path:
return get_skills_root_dir() / "public"
def get_custom_skills_dir() -> Path:
path = get_skills_root_dir() / "custom"
path.mkdir(parents=True, exist_ok=True)
return path
def validate_skill_name(name: str) -> str:
normalized = name.strip()
if not _SKILL_NAME_PATTERN.fullmatch(normalized):
raise ValueError("Skill name must be hyphen-case using lowercase letters, digits, and hyphens only.")
if len(normalized) > 64:
raise ValueError("Skill name must be 64 characters or fewer.")
return normalized
def get_custom_skill_dir(name: str) -> Path:
return get_custom_skills_dir() / validate_skill_name(name)
def get_custom_skill_file(name: str) -> Path:
return get_custom_skill_dir(name) / SKILL_FILE_NAME
def get_custom_skill_history_dir() -> Path:
path = get_custom_skills_dir() / HISTORY_DIR_NAME
path.mkdir(parents=True, exist_ok=True)
return path
def get_skill_history_file(name: str) -> Path:
return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl"
def get_public_skill_dir(name: str) -> Path:
return get_public_skills_dir() / validate_skill_name(name)
def custom_skill_exists(name: str) -> bool:
return get_custom_skill_file(name).exists()
def public_skill_exists(name: str) -> bool:
return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists()
def ensure_custom_skill_is_editable(name: str) -> None:
if custom_skill_exists(name):
return
if public_skill_exists(name):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise FileNotFoundError(f"Custom skill '{name}' not found.")
def ensure_safe_support_path(name: str, relative_path: str) -> Path:
skill_dir = get_custom_skill_dir(name).resolve()
if not relative_path or relative_path.endswith("/"):
raise ValueError("Supporting file path must include a filename.")
relative = Path(relative_path)
if relative.is_absolute():
raise ValueError("Supporting file path must be relative.")
if any(part in {"..", ""} for part in relative.parts):
raise ValueError("Supporting file path must not contain parent-directory traversal.")
top_level = relative.parts[0] if relative.parts else ""
if top_level not in ALLOWED_SUPPORT_SUBDIRS:
raise ValueError(f"Supporting files must live under one of: {', '.join(sorted(ALLOWED_SUPPORT_SUBDIRS))}.")
target = (skill_dir / relative).resolve()
allowed_root = (skill_dir / top_level).resolve()
try:
target.relative_to(allowed_root)
except ValueError as exc:
raise ValueError("Supporting file path must stay within the selected support directory.") from exc
return target
def validate_skill_markdown_content(name: str, content: str) -> None:
with tempfile.TemporaryDirectory() as tmp_dir:
temp_skill_dir = Path(tmp_dir) / validate_skill_name(name)
temp_skill_dir.mkdir(parents=True, exist_ok=True)
(temp_skill_dir / SKILL_FILE_NAME).write_text(content, encoding="utf-8")
is_valid, message, parsed_name = _validate_skill_frontmatter(temp_skill_dir)
if not is_valid:
raise ValueError(message)
if parsed_name != name:
raise ValueError(f"Frontmatter name '{parsed_name}' must match requested skill name '{name}'.")
def atomic_write(path: Path, content: str) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile("w", encoding="utf-8", delete=False, dir=str(path.parent)) as tmp_file:
tmp_file.write(content)
tmp_path = Path(tmp_file.name)
tmp_path.replace(path)
def append_history(name: str, record: dict[str, Any]) -> None:
history_path = get_skill_history_file(name)
history_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"ts": datetime.now(UTC).isoformat(),
**record,
}
with history_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False))
f.write("\n")
def read_history(name: str) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name)
if not history_path.exists():
return []
records: list[dict[str, Any]] = []
for line in history_path.read_text(encoding="utf-8").splitlines():
if not line.strip():
continue
records.append(json.loads(line))
return records
def list_custom_skills() -> list:
return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
def read_custom_skill_content(name: str) -> str:
skill_file = get_custom_skill_file(name)
if not skill_file.exists():
raise FileNotFoundError(f"Custom skill '{name}' not found.")
return skill_file.read_text(encoding="utf-8")
@@ -4,24 +4,24 @@ from pathlib import Path
import yaml
from .types import Skill
from .types import SKILL_MD_FILE, Skill, SkillCategory
logger = logging.getLogger(__name__)
def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None = None) -> Skill | None:
def parse_skill_file(skill_file: Path, category: SkillCategory, relative_path: Path | None = None) -> Skill | None:
"""Parse a SKILL.md file and extract metadata.
Args:
skill_file: Path to the SKILL.md file.
category: Category of the skill ('public' or 'custom').
category: Category of the skill.
relative_path: Relative path from the category root to the skill
directory. Defaults to the skill directory name when omitted.
Returns:
Skill object if parsing succeeds, None otherwise.
"""
if not skill_file.exists() or skill_file.name != "SKILL.md":
if not skill_file.exists() or skill_file.name != SKILL_MD_FILE:
return None
try:
@@ -8,7 +8,9 @@ import re
from dataclasses import dataclass
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
from deerflow.skills.types import SKILL_MD_FILE
logger = logging.getLogger(__name__)
@@ -35,7 +37,7 @@ def _extract_json_object(raw: str) -> dict | None:
return None
async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
async def scan_skill_content(content: str, *, executable: bool = False, location: str = SKILL_MD_FILE, app_config: AppConfig | None = None) -> ScanResult:
"""Screen skill content before it is written to disk."""
rubric = (
"You are a security reviewer for AI agent skills. "
@@ -47,9 +49,9 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try:
config = get_app_config()
config = app_config or get_app_config()
model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
model = create_chat_model(name=model_name, thinking_enabled=False, app_config=config) if model_name else create_chat_model(thinking_enabled=False, app_config=config)
response = await model.ainvoke(
[
{"role": "system", "content": rubric},
@@ -0,0 +1,83 @@
"""SkillStorage singleton + reflection-based factory.
Mirrors the pattern used by ``deerflow/sandbox/sandbox_provider.py``.
"""
from __future__ import annotations
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
from deerflow.skills.storage.skill_storage import SkillStorage
_default_skill_storage: SkillStorage | None = None
_default_skill_storage_config: object | None = None # AppConfig identity the singleton was built from
def get_or_new_skill_storage(**kwargs) -> SkillStorage:
"""Return a ``SkillStorage`` instance — either a new one or the process singleton.
**New instance** is created (never cached) when:
- ``skills_path`` is provided: it is used as the ``host_path`` override (class still resolved via config).
- ``app_config`` is provided: a storage is constructed from ``app_config.skills``
so that per-request config (e.g. Gateway ``Depends(get_config)``) is respected
without polluting the process-level singleton.
**Singleton** is returned (created on first call, then reused) when neither
``skills_path`` nor ``app_config`` is given; ``get_app_config()`` is then used to
resolve the active configuration.
"""
global _default_skill_storage, _default_skill_storage_config
from deerflow.config import get_app_config
from deerflow.config.skills_config import SkillsConfig
def _make_storage(skills_config: SkillsConfig, *, host_path: str | None = None, **kwargs) -> SkillStorage:
from deerflow.reflection import resolve_class
cls = resolve_class(skills_config.use, SkillStorage)
return cls(
host_path=host_path if host_path is not None else str(skills_config.get_skills_path()),
container_path=skills_config.container_path,
**kwargs,
)
skills_path = kwargs.pop("skills_path", None)
app_config = kwargs.pop("app_config", None)
if skills_path is not None:
if app_config is not None:
return _make_storage(app_config.skills, host_path=str(skills_path), **kwargs)
# No app_config: use a default SkillsConfig so we never need to read config.yaml
# when the caller has already supplied an explicit host path.
from deerflow.config.skills_config import SkillsConfig
return _make_storage(SkillsConfig(), host_path=str(skills_path), **kwargs)
if app_config is not None:
return _make_storage(app_config.skills, **kwargs)
# If the singleton was manually injected (e.g. in tests) without a config
# identity (_default_skill_storage_config is None), skip get_app_config()
# entirely to avoid requiring a config.yaml on disk.
if _default_skill_storage is not None and _default_skill_storage_config is None:
return _default_skill_storage
app_config_now = get_app_config()
if _default_skill_storage is None or _default_skill_storage_config is not app_config_now:
_default_skill_storage = _make_storage(app_config_now.skills, **kwargs)
_default_skill_storage_config = app_config_now
return _default_skill_storage
def reset_skill_storage() -> None:
"""Clear the cached singleton (used in tests and hot-reload scenarios)."""
global _default_skill_storage, _default_skill_storage_config
_default_skill_storage = None
_default_skill_storage_config = None
__all__ = [
"LocalSkillStorage",
"SkillStorage",
"get_or_new_skill_storage",
"reset_skill_storage",
]
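The caching rule in ``get_or_new_skill_storage`` can be sketched in isolation. This is a minimal illustration, not the module's code: ``FakeConfig``, ``FakeStorage``, ``set_ambient_config`` and ``get_storage`` are hypothetical stand-ins for ``AppConfig``, a ``SkillStorage`` backend, the ambient ``get_app_config()`` state, and the factory.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeConfig:
    """Hypothetical stand-in for AppConfig; only the skills path matters here."""
    skills_path: str


@dataclass
class FakeStorage:
    """Hypothetical stand-in for a SkillStorage backend."""
    host_path: str


_current_config = FakeConfig("/srv/skills")  # plays the role of get_app_config()
_cached_storage: Optional[FakeStorage] = None
_cached_for: Optional[FakeConfig] = None


def set_ambient_config(config: FakeConfig) -> None:
    """Simulate a config reload that swaps the ambient config object."""
    global _current_config
    _current_config = config


def get_storage(app_config: Optional[FakeConfig] = None) -> FakeStorage:
    """Return a fresh storage for an explicit config, else the cached singleton."""
    global _cached_storage, _cached_for
    if app_config is not None:
        # Explicit config: build a new instance and leave the cache untouched.
        return FakeStorage(app_config.skills_path)
    ambient = _current_config
    # Rebuild when the ambient config object was replaced (identity, not equality).
    if _cached_storage is None or _cached_for is not ambient:
        _cached_storage = FakeStorage(ambient.skills_path)
        _cached_for = ambient
    return _cached_storage
```

Comparing with ``is not`` means a reloaded config that happens to be equal to the old one still triggers a rebuild, which appears to be the intent of tracking ``_default_skill_storage_config`` by identity.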
@@ -0,0 +1,195 @@
"""Local-filesystem implementation of ``SkillStorage``."""
from __future__ import annotations
import errno
import json
import logging
import os
import shutil
import tempfile
from collections.abc import Iterable
from datetime import UTC, datetime
from pathlib import Path
from deerflow.config.runtime_paths import resolve_path
from deerflow.skills.storage.skill_storage import SKILL_MD_FILE, SkillStorage
from deerflow.skills.types import SkillCategory
logger = logging.getLogger(__name__)
DEFAULT_SKILLS_CONTAINER_PATH = "/mnt/skills"
class LocalSkillStorage(SkillStorage):
"""Skill storage backed by the local filesystem.
Layout::
<root>/public/<name>/SKILL.md
<root>/custom/<name>/SKILL.md
<root>/custom/.history/<name>.jsonl
"""
def __init__(
self,
host_path: str | None = None,
container_path: str = DEFAULT_SKILLS_CONTAINER_PATH,
app_config=None,
) -> None:
super().__init__(container_path=container_path)
if host_path is None:
from deerflow.config import get_app_config
config = app_config or get_app_config()
self._host_root: Path = config.skills.get_skills_path()
else:
self._host_root = resolve_path(host_path)
# ------------------------------------------------------------------
# Abstract operation implementations
# ------------------------------------------------------------------
def get_skills_root_path(self) -> Path:
return self._host_root
def custom_skill_exists(self, name: str) -> bool:
return self.get_custom_skill_file(name).exists()
def public_skill_exists(self, name: str) -> bool:
normalized_name = self.validate_skill_name(name)
return (self._host_root / SkillCategory.PUBLIC.value / normalized_name / SKILL_MD_FILE).exists()
def _iter_skill_files(self) -> Iterable[tuple[SkillCategory, Path, Path]]:
if not self._host_root.exists():
return
for category in SkillCategory:
category_path = self._host_root / category.value
if not category_path.exists() or not category_path.is_dir():
continue
for current_root, dir_names, file_names in os.walk(category_path, followlinks=True):
dir_names[:] = sorted(name for name in dir_names if not name.startswith("."))
if SKILL_MD_FILE not in file_names:
continue
yield category, category_path, Path(current_root) / SKILL_MD_FILE
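The in-place pruning used by ``_iter_skill_files`` can be tried on a throwaway tree. A minimal sketch, assuming a single category root; ``find_skill_files`` and ``demo`` are hypothetical names, and ``followlinks=True`` from the diff is omitted to keep the example small.

```python
import os
import tempfile
from pathlib import Path

SKILL_MD_FILE = "SKILL.md"


def find_skill_files(category_root: Path) -> list[Path]:
    """Return every SKILL.md under category_root, skipping hidden directories."""
    found: list[Path] = []
    for current_root, dir_names, file_names in os.walk(category_root):
        # Slice assignment mutates the list os.walk holds, so hidden
        # directories are never descended into; sorting fixes the visit order.
        dir_names[:] = sorted(d for d in dir_names if not d.startswith("."))
        if SKILL_MD_FILE in file_names:
            found.append(Path(current_root) / SKILL_MD_FILE)
    return found


def demo() -> list[str]:
    """Build a small tree with one hidden directory and list the skills found."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for rel in ("b-skill", "a-skill", ".history/stale"):
            (root / rel).mkdir(parents=True)
            (root / rel / SKILL_MD_FILE).write_text("---\nname: x\n---\n", encoding="utf-8")
        return [p.parent.name for p in find_skill_files(root)]
```

Rebinding ``dir_names`` to a new list instead of assigning to the slice would not prune anything, because ``os.walk`` keeps iterating over its own list object.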
def read_custom_skill(self, name: str) -> str:
if not self.custom_skill_exists(name):
raise FileNotFoundError(f"Custom skill '{name}' not found.")
return (self.get_custom_skill_dir(name) / SKILL_MD_FILE).read_text(encoding="utf-8")
def write_custom_skill(self, name: str, relative_path: str, content: str) -> None:
target = self.validate_relative_path(relative_path, self.get_custom_skill_dir(name))
target.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(
"w",
encoding="utf-8",
delete=False,
dir=str(target.parent),
) as tmp_file:
tmp_file.write(content)
tmp_path = Path(tmp_file.name)
tmp_path.replace(target)
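The write-then-rename pattern in ``write_custom_skill`` can be sketched as a standalone helper. ``atomic_write_text`` is a hypothetical name, and the cleanup of the temporary file on a failed rename is an added assumption that the diff does not show.

```python
import tempfile
from pathlib import Path


def atomic_write_text(target: Path, content: str) -> None:
    """Write content to target so readers never observe a half-written file."""
    target.parent.mkdir(parents=True, exist_ok=True)
    # The temp file is created in the destination directory so the final
    # rename stays on one filesystem.
    with tempfile.NamedTemporaryFile(
        "w",
        encoding="utf-8",
        delete=False,
        dir=str(target.parent),
    ) as tmp_file:
        tmp_file.write(content)
        tmp_path = Path(tmp_file.name)
    try:
        # The file is closed at this point; replace() swaps it into place.
        tmp_path.replace(target)
    except OSError:
        # Assumption: remove the stray temp file rather than leave it behind.
        tmp_path.unlink(missing_ok=True)
        raise
```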
async def ainstall_skill_from_archive(self, archive_path: str | Path) -> dict:
import zipfile
from deerflow.skills.installer import (
SkillAlreadyExistsError,
_move_staged_skill_into_reserved_target,
_scan_skill_archive_contents_or_raise,
resolve_skill_dir_from_archive,
safe_extract_skill_archive,
)
from deerflow.skills.validation import _validate_skill_frontmatter
logger.info("Installing skill from %s", archive_path)
path = Path(archive_path)
if not path.is_file():
if not path.exists():
raise FileNotFoundError(f"Skill file not found: {archive_path}")
raise ValueError(f"Path is not a file: {archive_path}")
if path.suffix != ".skill":
raise ValueError("File must have .skill extension")
custom_dir = self._host_root / "custom"
custom_dir.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
try:
zf = zipfile.ZipFile(path, "r")
except FileNotFoundError:
raise FileNotFoundError(f"Skill file not found: {archive_path}") from None
except (zipfile.BadZipFile, IsADirectoryError):
raise ValueError("File is not a valid ZIP archive") from None
with zf:
safe_extract_skill_archive(zf, tmp_path)
skill_dir = resolve_skill_dir_from_archive(tmp_path)
is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir)
if not is_valid:
raise ValueError(f"Invalid skill: {message}")
if not skill_name or "/" in skill_name or "\\" in skill_name or ".." in skill_name:
raise ValueError(f"Invalid skill name: {skill_name}")
target = custom_dir / skill_name
if target.exists():
raise SkillAlreadyExistsError(f"Skill '{skill_name}' already exists")
await _scan_skill_archive_contents_or_raise(skill_dir, skill_name)
with tempfile.TemporaryDirectory(prefix=f".installing-{skill_name}-", dir=custom_dir) as staging_root:
staging_target = Path(staging_root) / skill_name
shutil.copytree(skill_dir, staging_target)
_move_staged_skill_into_reserved_target(staging_target, target)
logger.info("Skill %r installed to %s", skill_name, target)
return {
"success": True,
"skill_name": skill_name,
"message": f"Skill '{skill_name}' installed successfully",
}
def delete_custom_skill(self, name: str, *, history_meta: dict | None = None) -> None:
self.validate_skill_name(name)
self.ensure_custom_skill_is_editable(name)
target = self.get_custom_skill_dir(name)
if history_meta is not None:
prev_content = self.read_custom_skill(name)
try:
self.append_history(name, {**history_meta, "prev_content": prev_content})
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning(
"Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s",
name,
e,
)
if target.exists():
shutil.rmtree(target)
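The error filter in ``delete_custom_skill`` tolerates only read-only and permission failures while writing history. A minimal sketch of that classification; ``is_readonly_or_permission_failure`` and ``best_effort`` are hypothetical helper names, not part of the diff.

```python
import errno

_READONLY_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}


def is_readonly_or_permission_failure(exc: OSError) -> bool:
    """True for failures that mean 'cannot write here', false for anything else."""
    return isinstance(exc, PermissionError) or exc.errno in _READONLY_ERRNOS


def best_effort(action, on_skip) -> bool:
    """Run action(); swallow only read-only/permission failures, re-raise the rest."""
    try:
        action()
        return True
    except OSError as exc:
        if not is_readonly_or_permission_failure(exc):
            raise
        on_skip(exc)
        return False
```

Other ``OSError`` values, such as a full disk, still propagate, so the deletion is not silently carried out after an unexpected history failure.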
def append_history(self, name: str, record: dict) -> None:
self.validate_skill_name(name)
payload = {"ts": datetime.now(UTC).isoformat(), **record}
history_path = self.get_skill_history_file(name)
history_path.parent.mkdir(parents=True, exist_ok=True)
with history_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False))
f.write("\n")
def read_history(self, name: str) -> list[dict]:
self.validate_skill_name(name)
history_path = self.get_skill_history_file(name)
if not history_path.exists():
return []
records: list[dict] = []
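# split("\n") rather than splitlines(): with ensure_ascii=False, characters such as U+2028 are written raw and splitlines() would break a record there.
for line in history_path.read_text(encoding="utf-8").split("\n"):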
if not line.strip():
continue
records.append(json.loads(line))
return records
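The JSONL history format can be exercised outside the storage class. A minimal sketch, assuming one JSON object per line; ``append_record`` and ``read_records`` are hypothetical names, and ``timezone.utc`` stands in for ``datetime.UTC`` so the sketch also runs on interpreters older than 3.11.

```python
import json
import tempfile
from datetime import datetime, timezone
from pathlib import Path


def append_record(history_path: Path, record: dict) -> None:
    """Append one timestamped JSON object as a single line."""
    payload = {"ts": datetime.now(timezone.utc).isoformat(), **record}
    history_path.parent.mkdir(parents=True, exist_ok=True)
    with history_path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=False) + "\n")


def read_records(history_path: Path) -> list[dict]:
    """Return all records, oldest first; a missing file means no history."""
    if not history_path.exists():
        return []
    # Split on "\n" only: json.dumps(ensure_ascii=False) leaves characters
    # such as U+2028 unescaped, and str.splitlines() treats them as breaks.
    lines = history_path.read_text(encoding="utf-8").split("\n")
    return [json.loads(line) for line in lines if line.strip()]
```

Because ``**record`` is unpacked after ``"ts"``, a record that carries its own ``ts`` key overrides the generated timestamp.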
@@ -0,0 +1,254 @@
"""Abstract SkillStorage base class with template-method flows."""
from __future__ import annotations
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Iterable
from pathlib import Path
from deerflow.skills.types import SKILL_MD_FILE, Skill, SkillCategory # noqa: F401
logger = logging.getLogger(__name__)
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
class SkillStorage(ABC):
"""Abstract base for skill storage backends.
Subclasses implement a small set of storage-medium-specific atomic
operations; this base class provides final template-method flows
(load_skills, history serialisation, path helpers, validation) that
compose them with protocol-level helpers.
"""
def __init__(self, container_path: str = "/mnt/skills") -> None:
self._container_root = container_path
# ------------------------------------------------------------------
# Static protocol helpers (not storage-specific)
# ------------------------------------------------------------------
@staticmethod
def validate_skill_name(name: str) -> str:
"""Validate and normalise a skill name; return the normalised form."""
normalized = name.strip()
if not _SKILL_NAME_PATTERN.fullmatch(normalized):
raise ValueError("Skill name must be hyphen-case using lowercase letters, digits, and hyphens only.")
if len(normalized) > 64:
raise ValueError("Skill name must be 64 characters or fewer.")
return normalized
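The name rule can be checked against a few inputs. The pattern below is copied from the diff; ``is_valid_skill_name`` and ``MAX_NAME_LENGTH`` are hypothetical names, and the helper returns a boolean where the original raises.

```python
import re

# Lowercase alphanumeric runs joined by single hyphens, as in the diff.
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
MAX_NAME_LENGTH = 64


def is_valid_skill_name(name: str) -> bool:
    """Mirror validate_skill_name: strip, match the pattern, cap the length."""
    normalized = name.strip()
    if not _SKILL_NAME_PATTERN.fullmatch(normalized):
        return False
    return len(normalized) <= MAX_NAME_LENGTH
```

Since the pattern admits neither ``/``, ``\`` nor ``.``, a name that passes cannot contain a path separator or a ``..`` segment.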
@staticmethod
def validate_relative_path(relative_path: str, base_dir: Path) -> Path:
"""Validate *relative_path* against *base_dir* and return the resolved target.
Checks that *relative_path* is non-empty, then joins it with *base_dir*
and resolves the result (following symlinks). Raises ``ValueError`` if
the resolved target does not lie within *base_dir*.
"""
if not relative_path:
raise ValueError("relative_path must not be empty.")
resolved_base = base_dir.resolve()
target = (resolved_base / relative_path).resolve()
try:
target.relative_to(resolved_base)
except ValueError as exc:
raise ValueError("relative_path must resolve within the skill directory.") from exc
return target
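The containment check in ``validate_relative_path`` follows a resolve-then-``relative_to`` pattern. A minimal sketch with the same logic; ``resolve_within`` is a hypothetical name.

```python
import tempfile
from pathlib import Path


def resolve_within(base_dir: Path, relative_path: str) -> Path:
    """Resolve relative_path under base_dir, rejecting anything that escapes it."""
    if not relative_path:
        raise ValueError("relative_path must not be empty.")
    resolved_base = base_dir.resolve()
    # resolve() collapses ".." segments and follows symlinks before the check.
    target = (resolved_base / relative_path).resolve()
    try:
        target.relative_to(resolved_base)
    except ValueError as exc:
        raise ValueError("relative_path must resolve within the base directory.") from exc
    return target
```

An absolute ``relative_path`` is rejected as well: joining a base with an absolute path discards the base, so the ``relative_to`` check fails.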
@staticmethod
def validate_skill_markdown_content(name: str, content: str) -> None:
"""Validate SKILL.md content: parse frontmatter and check name matches."""
import tempfile
from deerflow.skills.validation import _validate_skill_frontmatter
with tempfile.TemporaryDirectory() as tmp_dir:
temp_skill_dir = Path(tmp_dir) / SkillStorage.validate_skill_name(name)
temp_skill_dir.mkdir(parents=True, exist_ok=True)
(temp_skill_dir / SKILL_MD_FILE).write_text(content, encoding="utf-8")
is_valid, message, parsed_name = _validate_skill_frontmatter(temp_skill_dir)
if not is_valid:
raise ValueError(message)
if parsed_name != name:
raise ValueError(f"Frontmatter name '{parsed_name}' must match requested skill name '{name}'.")
def ensure_safe_support_path(self, name: str, relative_path: str) -> Path:
"""Validate and return the resolved absolute path for a support file."""
_ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
skill_dir = self.get_custom_skill_dir(self.validate_skill_name(name)).resolve()
if not relative_path or relative_path.endswith("/"):
raise ValueError("Supporting file path must include a filename.")
relative = Path(relative_path)
if relative.is_absolute():
raise ValueError("Supporting file path must be relative.")
if any(part in {"..", ""} for part in relative.parts):
raise ValueError("Supporting file path must not contain parent-directory traversal.")
top_level = relative.parts[0] if relative.parts else ""
if top_level not in _ALLOWED_SUPPORT_SUBDIRS:
raise ValueError(f"Supporting files must live under one of: {', '.join(sorted(_ALLOWED_SUPPORT_SUBDIRS))}.")
target = (skill_dir / relative).resolve()
allowed_root = (skill_dir / top_level).resolve()
try:
target.relative_to(allowed_root)
except ValueError as exc:
raise ValueError("Supporting file path must stay within the selected support directory.") from exc
return target
# ------------------------------------------------------------------
# Abstract atomic operations (storage-medium specific)
# ------------------------------------------------------------------
@abstractmethod
def get_skills_root_path(self) -> Path:
"""Absolute host path to the skills root, used for sandbox mounts.
Origin: ``deerflow.skills.loader.get_skills_root_path``.
"""
@abstractmethod
def _iter_skill_files(self) -> Iterable[tuple[SkillCategory, Path, Path]]:
"""Yield ``(category, category_root, skill_md_path)`` for every SKILL.md.
Origin: extracted from directory-walk logic inside
``deerflow.skills.loader.load_skills``.
"""
@abstractmethod
def read_custom_skill(self, name: str) -> str:
"""Read SKILL.md content for a custom skill.
Origin: ``deerflow.skills.manager.read_custom_skill_content``.
"""
@abstractmethod
def write_custom_skill(self, name: str, relative_path: str, content: str) -> None:
"""Atomically write a text file under ``custom/<name>/<relative_path>``.
Origin: ``deerflow.skills.manager.atomic_write``.
"""
@abstractmethod
async def ainstall_skill_from_archive(self, archive_path: str | Path) -> dict:
"""Async install of a skill from a ``.skill`` ZIP archive.
Origin: ``deerflow.skills.installer.ainstall_skill_from_archive``.
"""
def install_skill_from_archive(self, archive_path: str | Path) -> dict:
"""Sync wrapper — delegates to :meth:`ainstall_skill_from_archive`."""
from deerflow.skills.installer import _run_async_install
return _run_async_install(self.ainstall_skill_from_archive(archive_path))
@abstractmethod
def delete_custom_skill(self, name: str, *, history_meta: dict | None = None) -> None:
"""Delete a custom skill (validation + optional history + directory removal).
Origin: ``app.gateway.routers.skills.delete_custom_skill`` + ``skill_manage_tool``.
"""
@abstractmethod
def custom_skill_exists(self, name: str) -> bool:
"""Origin: ``deerflow.skills.manager.custom_skill_exists``."""
@abstractmethod
def public_skill_exists(self, name: str) -> bool:
"""Origin: ``deerflow.skills.manager.public_skill_exists``."""
@abstractmethod
def append_history(self, name: str, record: dict) -> None:
"""Append a JSONL history entry for ``name``.
Origin: ``deerflow.skills.manager.append_history``.
"""
@abstractmethod
def read_history(self, name: str) -> list[dict]:
"""Return all history records for ``name``, oldest first.
Origin: ``deerflow.skills.manager.read_history``.
"""
# ------------------------------------------------------------------
# Concrete path helpers (layout is part of the SKILL.md protocol)
# ------------------------------------------------------------------
def get_container_root(self) -> str:
"""Origin: ``deerflow.config.skills_config.SkillsConfig.container_path`` accessor."""
return self._container_root
def get_custom_skill_dir(self, name: str) -> Path:
"""Path to ``custom/<name>``. Does not create the directory.
Origin: ``deerflow.skills.manager.get_custom_skill_dir``.
"""
normalized_name = self.validate_skill_name(name)
return self.get_skills_root_path() / SkillCategory.CUSTOM.value / normalized_name
def get_custom_skill_file(self, name: str) -> Path:
"""Path to ``custom/<name>/SKILL.md``.
Origin: ``deerflow.skills.manager.get_custom_skill_file``.
"""
normalized_name = self.validate_skill_name(name)
return self.get_custom_skill_dir(normalized_name) / SKILL_MD_FILE
def get_skill_history_file(self, name: str) -> Path:
"""Path to ``custom/.history/<name>.jsonl``. Does not create parents.
Origin: ``deerflow.skills.manager.get_skill_history_file``.
"""
normalized_name = self.validate_skill_name(name)
return self.get_skills_root_path() / SkillCategory.CUSTOM.value / ".history" / f"{normalized_name}.jsonl"
# ------------------------------------------------------------------
# Final template-method flows
# ------------------------------------------------------------------
def load_skills(self, *, enabled_only: bool = False) -> list[Skill]:
"""Discover all skills, merge enabled state, sort and optionally filter.
Origin: ``deerflow.skills.loader.load_skills``.
"""
from deerflow.skills.parser import parse_skill_file
skills_by_name: dict[str, Skill] = {}
for category, category_root, md_path in self._iter_skill_files():
skill = parse_skill_file(
md_path,
category=category,
relative_path=md_path.parent.relative_to(category_root),
)
if skill:
skills_by_name[skill.name] = skill
skills = list(skills_by_name.values())
# Merge enabled state from extensions config (re-read every call so
# changes made by another process are picked up immediately).
try:
from deerflow.config.extensions_config import ExtensionsConfig
extensions_config = ExtensionsConfig.from_file()
for skill in skills:
skill.enabled = extensions_config.is_skill_enabled(skill.name, skill.category)
except Exception as e:
logger.warning("Failed to load extensions config: %s", e)
if enabled_only:
skills = [s for s in skills if s.enabled]
skills.sort(key=lambda s: s.name)
return skills
def ensure_custom_skill_is_editable(self, name: str) -> None:
"""Origin: ``deerflow.skills.manager.ensure_custom_skill_is_editable``."""
if self.custom_skill_exists(name):
return
if self.public_skill_exists(name):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise FileNotFoundError(f"Custom skill '{name}' not found.")
@@ -1,6 +1,20 @@
from dataclasses import dataclass
from enum import StrEnum
from pathlib import Path
SKILL_MD_FILE = "SKILL.md"
class SkillCategory(StrEnum):
"""Source category for a skill.
- ``PUBLIC``: built-in skill bundled with the platform, read-only.
- ``CUSTOM``: user-authored skill that can be edited or deleted.
"""
PUBLIC = "public"
CUSTOM = "custom"
@dataclass
class Skill:
@@ -12,7 +26,7 @@ class Skill:
skill_dir: Path
skill_file: Path
relative_path: Path # Relative path from category root to skill directory
category: str # 'public' or 'custom'
category: SkillCategory # 'public' or 'custom'
enabled: bool = False # Whether this skill is enabled
@property
@@ -8,6 +8,8 @@ from pathlib import Path
import yaml
from deerflow.skills.types import SKILL_MD_FILE
# Allowed properties in SKILL.md frontmatter
ALLOWED_FRONTMATTER_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata", "compatibility", "version", "author"}
@@ -21,9 +23,9 @@ def _validate_skill_frontmatter(skill_dir: Path) -> tuple[bool, str, str | None]
Returns:
Tuple of (is_valid, message, skill_name).
"""
skill_md = skill_dir / "SKILL.md"
skill_md = skill_dir / SKILL_MD_FILE
if not skill_md.exists():
return False, "SKILL.md not found", None
return False, f"{SKILL_MD_FILE} not found", None
content = skill_md.read_text(encoding="utf-8")
if not content.startswith("---"):
@@ -1,6 +1,10 @@
"""Subagent configuration definitions."""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
@dataclass
@@ -29,3 +33,24 @@ class SubagentConfig:
model: str = "inherit"
max_turns: int = 50
timeout_seconds: int = 900
def _default_model_name(app_config: "AppConfig") -> str:
if not app_config.models:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
return app_config.models[0].name
def resolve_subagent_model_name(config: SubagentConfig, parent_model: str | None, *, app_config: "AppConfig | None" = None) -> str:
"""Resolve the effective model name a subagent should use."""
if config.model != "inherit":
return config.model
if parent_model is not None:
return parent_model
if app_config is None:
from deerflow.config import get_app_config
app_config = get_app_config()
return _default_model_name(app_config)
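The resolution order in ``resolve_subagent_model_name`` is: an explicit model, then the parent's model, then the first configured model. A minimal sketch using hypothetical stand-in types (``ModelEntry``, ``DemoAppConfig``, ``DemoSubagentConfig``); the ambient ``get_app_config()`` fallback is left out, so ``app_config`` is required here.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelEntry:
    name: str


@dataclass
class DemoAppConfig:
    """Hypothetical stand-in for AppConfig; only ``models`` is used."""
    models: list = field(default_factory=list)


@dataclass
class DemoSubagentConfig:
    """Hypothetical stand-in for SubagentConfig; only ``model`` is used."""
    model: str = "inherit"


def resolve_model_name(config: DemoSubagentConfig, parent_model: Optional[str], app_config: DemoAppConfig) -> str:
    """Explicit model wins, then the parent's model, then the first configured one."""
    if config.model != "inherit":
        return config.model
    if parent_model is not None:
        return parent_model
    if not app_config.models:
        raise ValueError("No chat models are configured.")
    return app_config.models[0].name
```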
@@ -1,11 +1,14 @@
"""Subagent execution engine."""
import asyncio
import atexit
import logging
import threading
import uuid
from collections.abc import Callable, Coroutine
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from contextvars import Context, copy_context
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -17,12 +20,20 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
from deerflow.subagents.config import SubagentConfig
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
logger = logging.getLogger(__name__)
_previous_shutdown_isolated_subagent_loop = globals().get("_shutdown_isolated_subagent_loop")
if callable(_previous_shutdown_isolated_subagent_loop):
atexit.unregister(_previous_shutdown_isolated_subagent_loop)
_previous_shutdown_isolated_subagent_loop()
class SubagentStatus(Enum):
"""Status of a subagent execution."""
@@ -72,12 +83,107 @@ _background_tasks_lock = threading.Lock()
# Thread pool for background task scheduling and orchestration
_scheduler_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-scheduler-")
# Thread pool for actual subagent execution (with timeout support)
# Separate pool (same size as the scheduler pool) so execution tasks do not occupy scheduler threads
_execution_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-exec-")
# Persistent event loop for isolated subagent executions triggered from an
# already-running parent loop. Reusing one long-lived loop avoids creating a
# fresh loop per execution and then closing async resources bound to it.
_isolated_subagent_loop: asyncio.AbstractEventLoop | None = None
_isolated_subagent_loop_thread: threading.Thread | None = None
_isolated_subagent_loop_started: threading.Event | None = None
_isolated_subagent_loop_lock = threading.Lock()
# Dedicated pool for sync execute() calls made from an already-running event loop.
_isolated_loop_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-isolated-")
def _run_isolated_subagent_loop(
loop: asyncio.AbstractEventLoop,
started_event: threading.Event,
) -> None:
"""Run the persistent isolated subagent loop in a dedicated daemon thread."""
asyncio.set_event_loop(loop)
loop.call_soon(started_event.set)
try:
loop.run_forever()
finally:
started_event.clear()
def _shutdown_isolated_subagent_loop() -> None:
"""Stop and close the persistent isolated subagent loop."""
global _isolated_subagent_loop, _isolated_subagent_loop_thread, _isolated_subagent_loop_started
with _isolated_subagent_loop_lock:
loop = _isolated_subagent_loop
thread = _isolated_subagent_loop_thread
_isolated_subagent_loop = None
_isolated_subagent_loop_thread = None
_isolated_subagent_loop_started = None
if loop is None:
return
if loop.is_running():
loop.call_soon_threadsafe(loop.stop)
if thread is not None and thread.is_alive() and thread is not threading.current_thread():
thread.join(timeout=1)
thread_stopped = thread is None or not thread.is_alive()
loop_stopped = not loop.is_running()
if not loop.is_closed():
if thread_stopped and loop_stopped:
loop.close()
else:
logger.warning(
"Skipping close of isolated subagent loop because shutdown did not complete within timeout (thread_alive=%s, loop_running=%s)",
thread is not None and thread.is_alive(),
loop.is_running(),
)
atexit.register(_shutdown_isolated_subagent_loop)
def _get_isolated_subagent_loop() -> asyncio.AbstractEventLoop:
"""Return the persistent event loop used by isolated subagent executions."""
global _isolated_subagent_loop, _isolated_subagent_loop_thread, _isolated_subagent_loop_started
with _isolated_subagent_loop_lock:
thread_is_alive = _isolated_subagent_loop_thread is not None and _isolated_subagent_loop_thread.is_alive()
loop_is_usable = _isolated_subagent_loop is not None and not _isolated_subagent_loop.is_closed() and _isolated_subagent_loop.is_running() and thread_is_alive
if not loop_is_usable:
loop = asyncio.new_event_loop()
started_event = threading.Event()
thread = threading.Thread(
target=_run_isolated_subagent_loop,
args=(loop, started_event),
name="subagent-persistent-loop",
daemon=True,
)
thread.start()
if not started_event.wait(timeout=5):
loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=1)
loop.close()
raise RuntimeError("Timed out starting isolated subagent event loop")
_isolated_subagent_loop = loop
_isolated_subagent_loop_thread = thread
_isolated_subagent_loop_started = started_event
if _isolated_subagent_loop is None:
raise RuntimeError("Isolated subagent event loop is not initialized")
return _isolated_subagent_loop
def _submit_to_isolated_loop_in_context(
context: Context,
coro_factory: Callable[[], Coroutine[Any, Any, SubagentResult]],
) -> Future[SubagentResult]:
"""Submit a coroutine to the isolated loop while preserving ContextVar state."""
return context.run(
lambda: asyncio.run_coroutine_threadsafe(
coro_factory(),
_get_isolated_subagent_loop(),
)
)
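The combination of a long-lived loop thread and context-preserving submission can be reproduced in a small program. A minimal sketch under stated assumptions: ``trace_id``, ``start_background_loop``, ``submit_in_context``, ``read_trace_id`` and ``demo`` are hypothetical names, and the lock and reuse logic of ``_get_isolated_subagent_loop`` is omitted.

```python
import asyncio
import contextvars
import threading
from concurrent.futures import Future, ThreadPoolExecutor

trace_id = contextvars.ContextVar("trace_id", default="unset")


def start_background_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread]:
    """Start an event loop in a daemon thread and wait until it is running."""
    loop = asyncio.new_event_loop()
    started = threading.Event()

    def run() -> None:
        asyncio.set_event_loop(loop)
        loop.call_soon(started.set)
        loop.run_forever()

    thread = threading.Thread(target=run, name="demo-persistent-loop", daemon=True)
    thread.start()
    if not started.wait(timeout=5):
        raise RuntimeError("Timed out starting background event loop")
    return loop, thread


def submit_in_context(context: contextvars.Context, coro_factory, loop: asyncio.AbstractEventLoop) -> Future:
    # Scheduling inside context.run makes the loop snapshot that context, so
    # the task sees the ContextVar values captured by the original caller.
    return context.run(lambda: asyncio.run_coroutine_threadsafe(coro_factory(), loop))


async def read_trace_id() -> str:
    await asyncio.sleep(0)
    return trace_id.get()


def demo() -> str:
    loop, thread = start_background_loop()
    try:
        trace_id.set("req-42")
        parent_context = contextvars.copy_context()
        # Submit from a worker thread whose own context never saw "req-42".
        with ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(submit_in_context, parent_context, read_trace_id, loop).result(timeout=5)
        return future.result(timeout=5)
    finally:
        loop.call_soon_threadsafe(loop.stop)
        thread.join(timeout=1)
        if not loop.is_running():
            loop.close()
```

Passing a factory instead of a coroutine object means the coroutine is only created inside ``context.run``, at the point where it is scheduled.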
def _filter_tools(
@@ -110,21 +216,6 @@ def _filter_tools(
return filtered
def _get_model_name(config: SubagentConfig, parent_model: str | None) -> str | None:
"""Resolve the model name for a subagent.
Args:
config: Subagent configuration.
parent_model: The parent agent's model name.
Returns:
Model name to use, or None to use default.
"""
if config.model == "inherit":
return parent_model
return config.model
class SubagentExecutor:
"""Executor for running subagents."""
@@ -132,6 +223,7 @@ class SubagentExecutor:
self,
config: SubagentConfig,
tools: list[BaseTool],
app_config: AppConfig | None = None,
parent_model: str | None = None,
sandbox_state: SandboxState | None = None,
thread_data: ThreadDataState | None = None,
@@ -143,6 +235,9 @@ class SubagentExecutor:
Args:
config: Subagent configuration.
tools: List of all available tools (will be filtered).
app_config: Resolved AppConfig. When None, ``_create_agent`` falls
back to ``get_app_config()`` (matches the lead-agent factory's
pattern).
parent_model: The parent agent's model name for inheritance.
sandbox_state: Sandbox state from parent agent.
thread_data: Thread data from parent agent.
@@ -150,7 +245,15 @@ class SubagentExecutor:
trace_id: Trace ID from parent for distributed tracing.
"""
self.config = config
self.app_config = app_config
self.parent_model = parent_model
# Resolve eagerly only when it does not require loading config.yaml; otherwise defer
# to _create_agent (which already loads app_config) so unit tests can construct
# executors without a config file present.
if config.model != "inherit" or parent_model is not None or app_config is not None:
self.model_name: str | None = resolve_subagent_model_name(config, parent_model, app_config=app_config)
else:
self.model_name = None
self.sandbox_state = sandbox_state
self.thread_data = thread_data
self.thread_id = thread_id
@@ -168,13 +271,15 @@ class SubagentExecutor:
def _create_agent(self):
"""Create the agent instance."""
model_name = _get_model_name(self.config, self.parent_model)
model = create_chat_model(name=model_name, thinking_enabled=False)
app_config = self.app_config or get_app_config()
if self.model_name is None:
self.model_name = resolve_subagent_model_name(self.config, self.parent_model, app_config=app_config)
model = create_chat_model(name=self.model_name, thinking_enabled=False, app_config=app_config)
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
# Reuse shared middleware composition with lead agent.
middlewares = build_subagent_runtime_middlewares(lazy_init=True)
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True)
return create_agent(
model=model,
@@ -203,10 +308,12 @@ class SubagentExecutor:
return []
try:
from deerflow.skills.loader import load_skills
from deerflow.skills.storage import get_or_new_skill_storage
storage_kwargs = {"app_config": self.app_config} if self.app_config is not None else {}
storage = await asyncio.to_thread(get_or_new_skill_storage, **storage_kwargs)
# Use asyncio.to_thread to avoid blocking the event loop (LangGraph ASGI requirement)
all_skills = await asyncio.to_thread(load_skills, enabled_only=True)
all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
except Exception:
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
@@ -292,6 +399,10 @@ class SubagentExecutor:
status=SubagentStatus.RUNNING,
started_at=datetime.now(),
)
ai_messages = result.ai_messages
if ai_messages is None:
ai_messages = []
result.ai_messages = ai_messages
try:
agent = self._create_agent()
@@ -301,10 +412,12 @@ class SubagentExecutor:
run_config: RunnableConfig = {
"recursion_limit": self.config.max_turns,
}
context = {}
context: dict[str, Any] = {}
if self.thread_id:
run_config["configurable"] = {"thread_id": self.thread_id}
context["thread_id"] = self.thread_id
if self.app_config is not None:
context["app_config"] = self.app_config
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}")
@@ -351,13 +464,13 @@ class SubagentExecutor:
message_id = message_dict.get("id")
is_duplicate = False
if message_id:
is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages)
is_duplicate = any(msg.get("id") == message_id for msg in ai_messages)
else:
is_duplicate = message_dict in result.ai_messages
is_duplicate = message_dict in ai_messages
if not is_duplicate:
result.ai_messages.append(message_dict)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")
ai_messages.append(message_dict)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
@@ -443,42 +556,40 @@ class SubagentExecutor:
return result
def _execute_in_isolated_loop(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute the subagent in a completely fresh event loop.
"""Execute the subagent on the persistent isolated event loop.
This method is designed to run in a separate thread to ensure complete
isolation from any parent event loop, preventing conflicts with asyncio
primitives that may be bound to the parent loop (e.g., httpx clients).
This method is used by the sync ``execute()`` path when the caller is
already running inside an event loop. Because ``execute()`` is a sync
API, this path blocks the caller while the actual coroutine runs on the
long-lived isolated loop. Reusing that loop keeps shared async clients
from being tied to a short-lived loop that gets closed per execution.
"""
future: Future[SubagentResult] | None = None
parent_context = copy_context()
try:
previous_loop = asyncio.get_event_loop()
except RuntimeError:
previous_loop = None
# Create and set a new event loop for this thread
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return loop.run_until_complete(self._aexecute(task, result_holder))
finally:
try:
pending = asyncio.all_tasks(loop)
if pending:
for task_obj in pending:
task_obj.cancel()
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
except Exception:
future = _submit_to_isolated_loop_in_context(
parent_context,
lambda: self._aexecute(task, result_holder),
)
return future.result(timeout=self.config.timeout_seconds)
except FuturesTimeoutError:
if result_holder is not None:
result_holder.cancel_event.set()
if future is not None:
future.cancel()
raise
except Exception:
if future is None:
logger.debug(
f"[trace={self.trace_id}] Failed while cleaning up isolated event loop for subagent {self.config.name}",
f"[trace={self.trace_id}] Failed to submit subagent {self.config.name} to the isolated event loop",
exc_info=True,
)
finally:
try:
loop.close()
finally:
asyncio.set_event_loop(previous_loop)
else:
logger.debug(
f"[trace={self.trace_id}] Subagent {self.config.name} failed while executing on the isolated event loop",
exc_info=True,
)
raise
def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute a task synchronously (wrapper around async execution).
@@ -487,9 +598,9 @@ class SubagentExecutor:
asynchronous tools (like MCP tools) to be used within the thread pool.
When called from within an already-running event loop (e.g., when the
parent agent is async), this method isolates the subagent execution in
a separate thread to avoid event loop conflicts with shared async
primitives like httpx clients.
parent agent is async), this method synchronously waits on the
persistent isolated loop to avoid event loop conflicts with shared
async primitives like httpx clients.
Args:
task: The task description for the subagent.
@@ -505,9 +616,8 @@ class SubagentExecutor:
loop = None
if loop is not None and loop.is_running():
logger.debug(f"[trace={self.trace_id}] Subagent {self.config.name} detected running event loop, using isolated thread")
future = _isolated_loop_pool.submit(self._execute_in_isolated_loop, task, result_holder)
return future.result()
logger.debug(f"[trace={self.trace_id}] Subagent {self.config.name} detected running event loop, using isolated loop")
return self._execute_in_isolated_loop(task, result_holder)
# Standard path: no running event loop, use asyncio.run
return asyncio.run(self._aexecute(task, result_holder))
@@ -553,6 +663,8 @@ class SubagentExecutor:
with _background_tasks_lock:
_background_tasks[task_id] = result
parent_context = copy_context()
# Submit to scheduler pool
def run_task():
with _background_tasks_lock:
@@ -561,9 +673,12 @@ class SubagentExecutor:
result_holder = _background_tasks[task_id]
try:
# Submit execution to execution pool with timeout
# Pass result_holder so execute() can update it in real-time
execution_future: Future = _execution_pool.submit(self.execute, task, result_holder)
# Submit execution directly to the persistent isolated loop so the
# background path does not create a temporary loop via execute().
execution_future = _submit_to_isolated_loop_in_context(
parent_context,
lambda: self._aexecute(task, result_holder),
)
try:
# Wait for execution with timeout
exec_result = execution_future.result(timeout=self.config.timeout_seconds)
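Review note: the diff calls `_submit_to_isolated_loop_in_context` but does not show its body. The sketch below is one way such a helper could work, not the project's implementation. It assumes a single long-lived loop running on a daemon thread; the names `submit_in_context`, `request_id`, and `read_request_id` are hypothetical. The caller's `contextvars` snapshot is applied when the task is created, so the coroutine sees the caller's context variables even though it runs on another thread.

```python
import asyncio
import contextvars
import threading
from concurrent.futures import Future
from typing import Any, Awaitable, Callable

# Hypothetical long-lived loop, started once and reused for every submission.
_loop = asyncio.new_event_loop()
threading.Thread(target=_loop.run_forever, name="isolated-loop", daemon=True).start()

# Hypothetical context variable standing in for whatever the caller propagates.
request_id: contextvars.ContextVar = contextvars.ContextVar("request_id", default="unset")


def submit_in_context(ctx: contextvars.Context, coro_factory: Callable[[], Awaitable[Any]]) -> Future:
    """Schedule coro_factory() on the shared loop with ctx as the task's context."""
    result: Future = Future()

    def _copy_outcome(task: "asyncio.Task") -> None:
        # Mirror the asyncio task's outcome onto the thread-safe future.
        if task.cancelled():
            result.cancel()
        elif task.exception() is not None:
            result.set_exception(task.exception())
        else:
            result.set_result(task.result())

    def _start() -> None:
        try:
            # A Task copies the *current* context at creation time, so creating
            # it inside ctx.run() makes the coroutine inherit the caller's values.
            task = ctx.run(_loop.create_task, coro_factory())
        except Exception as exc:  # e.g. the factory itself raised
            result.set_exception(exc)
            return
        task.add_done_callback(_copy_outcome)

    _loop.call_soon_threadsafe(_start)
    return result


async def read_request_id() -> str:
    return request_id.get()
```

A caller would snapshot its context with `contextvars.copy_context()` and block on `submit_in_context(ctx, factory).result(timeout=...)`, which raises `concurrent.futures.TimeoutError` if the deadline passes, matching the timeout handling in the hunk above.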
@@ -2,6 +2,7 @@
import logging
from dataclasses import replace
from typing import Any
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
@@ -10,19 +11,26 @@ from deerflow.subagents.config import SubagentConfig
logger = logging.getLogger(__name__)
def _build_custom_subagent_config(name: str) -> SubagentConfig | None:
def _resolve_subagents_app_config(app_config: Any | None = None):
if app_config is None:
from deerflow.config.subagents_config import get_subagents_app_config
return get_subagents_app_config()
return getattr(app_config, "subagents", app_config)
def _build_custom_subagent_config(name: str, *, app_config: Any | None = None) -> SubagentConfig | None:
"""Build a SubagentConfig from config.yaml custom_agents section.
Args:
name: The name of the custom subagent.
app_config: Optional AppConfig or SubagentsAppConfig to resolve from.
Returns:
SubagentConfig if found in custom_agents, None otherwise.
"""
from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config()
custom = app_config.custom_agents.get(name)
subagents_config = _resolve_subagents_app_config(app_config)
custom = subagents_config.custom_agents.get(name)
if custom is None:
return None
@@ -39,7 +47,7 @@ def _build_custom_subagent_config(name: str) -> SubagentConfig | None:
)
def get_subagent_config(name: str) -> SubagentConfig | None:
def get_subagent_config(name: str, *, app_config: Any | None = None) -> SubagentConfig | None:
"""Get a subagent configuration by name, with config.yaml overrides applied.
Resolution order (mirrors Codex's config layering):
@@ -49,6 +57,7 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
Args:
name: The name of the subagent.
app_config: Optional AppConfig or SubagentsAppConfig to resolve overrides from.
Returns:
SubagentConfig if found (with any config.yaml overrides applied), None otherwise.
@@ -56,7 +65,7 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
# Step 1: Look up built-in, then fall back to custom_agents
config = BUILTIN_SUBAGENTS.get(name)
if config is None:
config = _build_custom_subagent_config(name)
config = _build_custom_subagent_config(name, app_config=app_config)
if config is None:
return None
@@ -65,12 +74,9 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
# (timeout_seconds, max_turns at the top level) apply to built-in agents
# but must NOT override custom agents' own values — custom agents define
# their own defaults in the custom_agents section.
# Lazy import to avoid circular deps.
from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config()
subagents_config = _resolve_subagents_app_config(app_config)
is_builtin = name in BUILTIN_SUBAGENTS
agent_override = app_config.agents.get(name)
agent_override = subagents_config.agents.get(name)
overrides = {}
@@ -79,27 +85,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
if agent_override.timeout_seconds != config.timeout_seconds:
logger.debug("Subagent '%s': timeout overridden (%ss -> %ss)", name, config.timeout_seconds, agent_override.timeout_seconds)
overrides["timeout_seconds"] = agent_override.timeout_seconds
elif is_builtin and app_config.timeout_seconds != config.timeout_seconds:
logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, app_config.timeout_seconds)
overrides["timeout_seconds"] = app_config.timeout_seconds
elif is_builtin and subagents_config.timeout_seconds != config.timeout_seconds:
logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, subagents_config.timeout_seconds)
overrides["timeout_seconds"] = subagents_config.timeout_seconds
# Max turns: per-agent override > global default (builtins only) > config's own value
if agent_override is not None and agent_override.max_turns is not None:
if agent_override.max_turns != config.max_turns:
logger.debug("Subagent '%s': max_turns overridden (%s -> %s)", name, config.max_turns, agent_override.max_turns)
overrides["max_turns"] = agent_override.max_turns
elif is_builtin and app_config.max_turns is not None and app_config.max_turns != config.max_turns:
logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, app_config.max_turns)
overrides["max_turns"] = app_config.max_turns
elif is_builtin and subagents_config.max_turns is not None and subagents_config.max_turns != config.max_turns:
logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, subagents_config.max_turns)
overrides["max_turns"] = subagents_config.max_turns
# Model: per-agent override only (no global default for model)
effective_model = app_config.get_model_for(name)
effective_model = subagents_config.get_model_for(name)
if effective_model is not None and effective_model != config.model:
logger.debug("Subagent '%s': model overridden (%s -> %s)", name, config.model, effective_model)
overrides["model"] = effective_model
# Skills: per-agent override only (no global default for skills)
effective_skills = app_config.get_skills_for(name)
effective_skills = subagents_config.get_skills_for(name)
if effective_skills is not None and effective_skills != config.skills:
logger.debug("Subagent '%s': skills overridden (%s -> %s)", name, config.skills, effective_skills)
overrides["skills"] = effective_skills
@@ -110,21 +116,21 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
return config
def list_subagents() -> list[SubagentConfig]:
def list_subagents(*, app_config: Any | None = None) -> list[SubagentConfig]:
"""List all available subagent configurations (with config.yaml overrides applied).
Returns:
List of all registered SubagentConfig instances (built-in + custom).
"""
configs = []
for name in get_subagent_names():
config = get_subagent_config(name)
for name in get_subagent_names(app_config=app_config):
config = get_subagent_config(name, app_config=app_config)
if config is not None:
configs.append(config)
return configs
def get_subagent_names() -> list[str]:
def get_subagent_names(*, app_config: Any | None = None) -> list[str]:
"""Get all available subagent names (built-in + custom).
Returns:
@@ -133,25 +139,23 @@ def get_subagent_names() -> list[str]:
names = list(BUILTIN_SUBAGENTS.keys())
# Merge custom_agents from config.yaml
from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config()
for custom_name in app_config.custom_agents:
subagents_config = _resolve_subagents_app_config(app_config)
for custom_name in subagents_config.custom_agents:
if custom_name not in names:
names.append(custom_name)
return names
def get_available_subagent_names() -> list[str]:
def get_available_subagent_names(*, app_config: Any | None = None) -> list[str]:
"""Get subagent names that should be exposed to the active runtime.
Returns:
List of subagent names visible to the current sandbox configuration.
"""
names = get_subagent_names()
names = get_subagent_names(app_config=app_config)
try:
host_bash_allowed = is_host_bash_allowed()
host_bash_allowed = is_host_bash_allowed(app_config) if hasattr(app_config, "sandbox") else is_host_bash_allowed()
except Exception:
logger.debug("Could not determine host bash availability; exposing all subagents")
return names
@@ -4,20 +4,39 @@ import asyncio
import logging
import uuid
from dataclasses import replace
from typing import Annotated
from typing import TYPE_CHECKING, Annotated, Any, cast
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langgraph.config import get_stream_writer
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.config import get_app_config
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
from deerflow.subagents.config import resolve_subagent_model_name
from deerflow.subagents.executor import (
SubagentStatus,
cleanup_background_task,
get_background_task_result,
request_cancel_background_task,
)
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
context = getattr(runtime, "context", None)
if isinstance(context, dict):
app_config = context.get("app_config")
if app_config is not None:
return cast("AppConfig", app_config)
return None
def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -> list[str] | None:
"""Return the effective subagent skill allowlist under the parent policy."""
if parent is None:
@@ -74,15 +93,18 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
"""
available_subagent_names = get_available_subagent_names()
runtime_app_config = _get_runtime_app_config(runtime)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
# Get subagent configuration
config = get_subagent_config(subagent_type)
config = get_subagent_config(subagent_type, app_config=runtime_app_config) if runtime_app_config is not None else get_subagent_config(subagent_type)
if config is None:
available = ", ".join(available_subagent_names)
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
if subagent_type == "bash" and not is_host_bash_allowed():
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
if subagent_type == "bash":
host_bash_allowed = is_host_bash_allowed(runtime_app_config) if runtime_app_config is not None else is_host_bash_allowed()
if not host_bash_allowed:
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
# Build config overrides
overrides: dict = {}
@@ -129,20 +151,34 @@ async def task_tool(
# Inherit parent agent's tool_groups so subagents respect the same restrictions
parent_tool_groups = metadata.get("tool_groups")
resolved_app_config = runtime_app_config
if config.model == "inherit" and parent_model is None and resolved_app_config is None:
resolved_app_config = get_app_config()
effective_model = resolve_subagent_model_name(config, parent_model, app_config=resolved_app_config)
# Subagents should not have subagent tools enabled (prevent recursive nesting)
tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False)
available_tools_kwargs = {
"model_name": effective_model,
"groups": parent_tool_groups,
"subagent_enabled": False,
}
if resolved_app_config is not None:
available_tools_kwargs["app_config"] = resolved_app_config
tools = get_available_tools(**available_tools_kwargs)
# Create executor
executor = SubagentExecutor(
config=config,
tools=tools,
parent_model=parent_model,
sandbox_state=sandbox_state,
thread_data=thread_data,
thread_id=thread_id,
trace_id=trace_id,
)
executor_kwargs = {
"config": config,
"tools": tools,
"parent_model": parent_model,
"sandbox_state": sandbox_state,
"thread_data": thread_data,
"thread_id": thread_id,
"trace_id": trace_id,
}
if resolved_app_config is not None:
executor_kwargs["app_config"] = resolved_app_config
executor = SubagentExecutor(**executor_kwargs)
# Start background execution (always async to prevent blocking)
# Use tool_call_id as task_id for better traceability
@@ -177,11 +213,12 @@ async def task_tool(
last_status = result.status
# Check for new AI messages and send task_running events
current_message_count = len(result.ai_messages)
ai_messages = result.ai_messages or []
current_message_count = len(ai_messages)
if current_message_count > last_message_count:
# Send task_running event for each new message
for i in range(last_message_count, current_message_count):
message = result.ai_messages[i]
message = ai_messages[i]
writer(
{
"type": "task_running",
@@ -8,7 +8,42 @@ from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
_ALLOWED_IMAGE_VIRTUAL_ROOTS = (
f"{VIRTUAL_PATH_PREFIX}/workspace",
f"{VIRTUAL_PATH_PREFIX}/uploads",
f"{VIRTUAL_PATH_PREFIX}/outputs",
)
_ALLOWED_IMAGE_VIRTUAL_ROOTS_TEXT = ", ".join(_ALLOWED_IMAGE_VIRTUAL_ROOTS)
_MAX_IMAGE_BYTES = 20 * 1024 * 1024
_EXTENSION_TO_MIME = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
}
def _is_allowed_image_virtual_path(image_path: str) -> bool:
return any(image_path == root or image_path.startswith(f"{root}/") for root in _ALLOWED_IMAGE_VIRTUAL_ROOTS)
def _detect_image_mime(image_data: bytes) -> str | None:
if image_data.startswith(b"\xff\xd8\xff"):
return "image/jpeg"
if image_data.startswith(b"\x89PNG\r\n\x1a\n"):
return "image/png"
if len(image_data) >= 12 and image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
return "image/webp"
return None
def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None) -> str:
from deerflow.sandbox.tools import mask_local_paths_in_output
return mask_local_paths_in_output(f"{type(error).__name__}: {error}", thread_data)
@tool("view_image", parse_docstring=True)
@@ -29,22 +64,39 @@ def view_image_tool(
- For multiple files at once (use present_files instead)
Args:
image_path: Absolute path to the image file. Common formats supported: jpg, jpeg, png, webp.
image_path: Absolute /mnt/user-data virtual path to the image file. Common formats supported: jpg, jpeg, png, webp.
"""
from deerflow.sandbox.tools import get_thread_data, replace_virtual_path
from deerflow.sandbox.exceptions import SandboxRuntimeError
from deerflow.sandbox.tools import (
get_thread_data,
resolve_and_validate_user_data_path,
validate_local_tool_path,
)
# Replace virtual path with actual path
# /mnt/user-data/* paths are mapped to thread-specific directories
thread_data = get_thread_data(runtime)
actual_path = replace_virtual_path(image_path, thread_data)
# Validate that the path is absolute
path = Path(actual_path)
if not path.is_absolute():
if not _is_allowed_image_virtual_path(image_path):
return Command(
update={"messages": [ToolMessage(f"Error: Path must be absolute, got: {image_path}", tool_call_id=tool_call_id)]},
update={
"messages": [
ToolMessage(
f"Error: Only image paths under {_ALLOWED_IMAGE_VIRTUAL_ROOTS_TEXT} are allowed",
tool_call_id=tool_call_id,
)
]
},
)
try:
validate_local_tool_path(image_path, thread_data, read_only=True)
actual_path = resolve_and_validate_user_data_path(image_path, thread_data)
except (PermissionError, SandboxRuntimeError) as e:
return Command(
update={"messages": [ToolMessage(f"Error: {str(e)}", tool_call_id=tool_call_id)]},
)
path = Path(actual_path)
# Validate that the file exists
if not path.exists():
return Command(
@@ -58,34 +110,49 @@ def view_image_tool(
)
# Validate image extension
valid_extensions = {".jpg", ".jpeg", ".png", ".webp"}
if path.suffix.lower() not in valid_extensions:
expected_mime_type = _EXTENSION_TO_MIME.get(path.suffix.lower())
if expected_mime_type is None:
return Command(
update={"messages": [ToolMessage(f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(valid_extensions)}", tool_call_id=tool_call_id)]},
update={"messages": [ToolMessage(f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(_EXTENSION_TO_MIME)}", tool_call_id=tool_call_id)]},
)
# Detect MIME type from file extension
mime_type, _ = mimetypes.guess_type(actual_path)
if mime_type is None:
# Fallback to default MIME types for common image formats
extension_to_mime = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
}
mime_type = extension_to_mime.get(path.suffix.lower(), "application/octet-stream")
mime_type = expected_mime_type
try:
image_size = path.stat().st_size
except OSError as e:
return Command(
update={"messages": [ToolMessage(f"Error reading image metadata: {_sanitize_image_error(e, thread_data)}", tool_call_id=tool_call_id)]},
)
if image_size > _MAX_IMAGE_BYTES:
return Command(
update={"messages": [ToolMessage(f"Error: Image file is too large: {image_size} bytes. Maximum supported size is {_MAX_IMAGE_BYTES} bytes", tool_call_id=tool_call_id)]},
)
# Read image file and convert to base64
try:
with open(actual_path, "rb") as f:
image_data = f.read()
image_base64 = base64.b64encode(image_data).decode("utf-8")
except Exception as e:
return Command(
update={"messages": [ToolMessage(f"Error reading image file: {str(e)}", tool_call_id=tool_call_id)]},
update={"messages": [ToolMessage(f"Error reading image file: {_sanitize_image_error(e, thread_data)}", tool_call_id=tool_call_id)]},
)
detected_mime_type = _detect_image_mime(image_data)
if detected_mime_type is None:
return Command(
update={"messages": [ToolMessage("Error: File contents do not match a supported image format", tool_call_id=tool_call_id)]},
)
if detected_mime_type != expected_mime_type:
return Command(
update={"messages": [ToolMessage(f"Error: Image contents are {detected_mime_type}, but file extension indicates {expected_mime_type}", tool_call_id=tool_call_id)]},
)
mime_type = detected_mime_type
image_base64 = base64.b64encode(image_data).decode("utf-8")
# Update viewed_images in state
# The merge_viewed_images reducer will handle merging with existing images
new_viewed_images = {image_path: {"base64": image_base64, "mime_type": mime_type}}
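Review note: the extension-versus-content check added to `view_image_tool` can be exercised in isolation. The sketch below reuses the magic-byte signatures from `_detect_image_mime` in the diff. The wrapper `check_image` and its return strings are hypothetical and exist only to make the three outcomes (unsupported extension, unrecognized contents, mismatch) easy to see.

```python
from pathlib import Path
from typing import Optional

EXTENSION_TO_MIME = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".webp": "image/webp",
}


def detect_image_mime(image_data: bytes) -> Optional[str]:
    """Identify the format from leading magic bytes, ignoring the file name."""
    if image_data.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if image_data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    # WebP is a RIFF container: "RIFF", 4 size bytes, then "WEBP".
    if len(image_data) >= 12 and image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
        return "image/webp"
    return None


def check_image(filename: str, image_data: bytes) -> str:
    """Hypothetical helper: compare the extension's MIME type with the detected one."""
    expected = EXTENSION_TO_MIME.get(Path(filename).suffix.lower())
    if expected is None:
        return "unsupported extension"
    detected = detect_image_mime(image_data)
    if detected is None:
        return "unrecognized contents"
    if detected != expected:
        return f"mismatch: {detected} != {expected}"
    return "ok"
```

Checking the content and not just the suffix means a PNG renamed to `.jpg` is rejected instead of being forwarded with a wrong `mime_type`.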
@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
import logging
import shutil
from typing import Any
from weakref import WeakValueDictionary
@@ -14,20 +13,10 @@ from langgraph.typing import ContextT
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.agents.thread_state import ThreadState
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.manager import (
append_history,
atomic_write,
custom_skill_exists,
ensure_custom_skill_is_editable,
ensure_safe_support_path,
get_custom_skill_dir,
get_custom_skill_file,
public_skill_exists,
read_custom_skill_content,
validate_skill_markdown_content,
validate_skill_name,
)
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.skills.types import SKILL_MD_FILE
logger = logging.getLogger(__name__)
@@ -96,50 +85,50 @@ async def _skill_manage_impl(
replace: Replacement text for patch.
expected_count: Optional expected number of replacements for patch.
"""
name = validate_skill_name(name)
name = SkillStorage.validate_skill_name(name)
lock = _get_lock(name)
thread_id = _get_thread_id(runtime)
skill_storage = get_or_new_skill_storage()
async with lock:
if action == "create":
if await _to_thread(custom_skill_exists, name):
if await _to_thread(skill_storage.custom_skill_exists, name):
raise ValueError(f"Custom skill '{name}' already exists.")
if content is None:
raise ValueError("content is required for create.")
await _to_thread(validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name)
await _to_thread(atomic_write, skill_file, content)
await _to_thread(skill_storage.validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/{SKILL_MD_FILE}")
await _to_thread(skill_storage.write_custom_skill, name, SKILL_MD_FILE, content)
await _to_thread(
append_history,
skill_storage.append_history,
name,
_history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
_history_record(action="create", file_path=SKILL_MD_FILE, prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
)
await refresh_skills_system_prompt_cache_async()
return f"Created custom skill '{name}'."
if action == "edit":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(skill_storage.ensure_custom_skill_is_editable, name)
if content is None:
raise ValueError("content is required for edit.")
await _to_thread(validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name)
await _to_thread(skill_storage.validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/{SKILL_MD_FILE}")
skill_file = skill_storage.get_custom_skill_file(name)
prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
await _to_thread(atomic_write, skill_file, content)
await _to_thread(skill_storage.write_custom_skill, name, SKILL_MD_FILE, content)
await _to_thread(
append_history,
skill_storage.append_history,
name,
_history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
_history_record(action="edit", file_path=SKILL_MD_FILE, prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
)
await refresh_skills_system_prompt_cache_async()
return f"Updated custom skill '{name}'."
if action == "patch":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(skill_storage.ensure_custom_skill_is_editable, name)
if find is None or replace is None:
raise ValueError("find and replace are required for patch.")
skill_file = await _to_thread(get_custom_skill_file, name)
skill_file = skill_storage.get_custom_skill_file(name)
prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
occurrences = prev_content.count(find)
if occurrences == 0:
@@ -148,64 +137,67 @@ async def _skill_manage_impl(
raise ValueError(f"Expected {expected_count} replacements but found {occurrences}.")
replacement_count = expected_count if expected_count is not None else 1
new_content = prev_content.replace(find, replace, replacement_count)
await _to_thread(validate_skill_markdown_content, name, new_content)
scan = await _scan_or_raise(new_content, executable=False, location=f"{name}/SKILL.md")
await _to_thread(atomic_write, skill_file, new_content)
await _to_thread(skill_storage.validate_skill_markdown_content, name, new_content)
scan = await _scan_or_raise(new_content, executable=False, location=f"{name}/{SKILL_MD_FILE}")
await _to_thread(skill_storage.write_custom_skill, name, SKILL_MD_FILE, new_content)
await _to_thread(
append_history,
skill_storage.append_history,
name,
_history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
_history_record(action="patch", file_path=SKILL_MD_FILE, prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
)
await refresh_skills_system_prompt_cache_async()
return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)."
if action == "delete":
await _to_thread(ensure_custom_skill_is_editable, name)
skill_dir = await _to_thread(get_custom_skill_dir, name)
prev_content = await _to_thread(read_custom_skill_content, name)
await _to_thread(
append_history,
skill_storage.delete_custom_skill,
name,
_history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
history_meta=_history_record(
action="delete",
file_path=SKILL_MD_FILE,
prev_content=None,
new_content=None,
thread_id=thread_id,
scanner={"decision": "allow", "reason": "Deletion requested."},
),
)
await _to_thread(shutil.rmtree, skill_dir)
await refresh_skills_system_prompt_cache_async()
return f"Deleted custom skill '{name}'."
if action == "write_file":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(skill_storage.ensure_custom_skill_is_editable, name)
if path is None or content is None:
raise ValueError("path and content are required for write_file.")
target = await _to_thread(ensure_safe_support_path, name, path)
target = await _to_thread(skill_storage.ensure_safe_support_path, name, path)
exists = await _to_thread(target.exists)
prev_content = await _to_thread(target.read_text, encoding="utf-8") if exists else None
executable = "scripts/" in path or path.startswith("scripts/")
scan = await _scan_or_raise(content, executable=executable, location=f"{name}/{path}")
await _to_thread(atomic_write, target, content)
await _to_thread(skill_storage.write_custom_skill, name, path, content)
await _to_thread(
append_history,
skill_storage.append_history,
name,
_history_record(action="write_file", file_path=path, prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
)
return f"Wrote '{path}' for custom skill '{name}'."
if action == "remove_file":
await _to_thread(ensure_custom_skill_is_editable, name)
await _to_thread(skill_storage.ensure_custom_skill_is_editable, name)
if path is None:
raise ValueError("path is required for remove_file.")
target = await _to_thread(ensure_safe_support_path, name, path)
target = await _to_thread(skill_storage.ensure_safe_support_path, name, path)
if not await _to_thread(target.exists):
raise FileNotFoundError(f"Supporting file '{path}' not found for skill '{name}'.")
prev_content = await _to_thread(target.read_text, encoding="utf-8")
await _to_thread(target.unlink)
await _to_thread(
append_history,
skill_storage.append_history,
name,
_history_record(action="remove_file", file_path=path, prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
)
return f"Removed '{path}' from custom skill '{name}'."
if await _to_thread(public_skill_exists, name):
if await _to_thread(skill_storage.public_skill_exists, name):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise ValueError(f"Unsupported action '{action}'.")
@@ -3,6 +3,7 @@ import logging
from langchain.tools import BaseTool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
@@ -37,6 +38,8 @@ def get_available_tools(
include_mcp: bool = True,
model_name: str | None = None,
subagent_enabled: bool = False,
*,
app_config: AppConfig | None = None,
) -> list[BaseTool]:
"""Get all available tools from config.
@@ -52,7 +55,7 @@ def get_available_tools(
Returns:
List of available tools.
"""
config = get_app_config()
config = app_config or get_app_config()
tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]
# Do not expose host bash by default when LocalSandboxProvider is active.
@@ -138,10 +141,14 @@ def get_available_tools(
# Add invoke_acp_agent tool if any ACP agents are configured
acp_tools: list[BaseTool] = []
try:
from deerflow.config.acp_config import get_acp_agents
from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool
acp_agents = get_acp_agents()
if app_config is None:
from deerflow.config.acp_config import get_acp_agents
acp_agents = get_acp_agents()
else:
acp_agents = getattr(config, "acp_agents", {}) or {}
if acp_agents:
acp_tools.append(build_invoke_acp_agent_tool(acp_agents))
logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})")
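The hunks above add a keyword-only `app_config` parameter and fall back to the ambient lookup only when the caller passes nothing. A minimal sketch of that pattern follows. The `AppConfig` dataclass, the `_AMBIENT` value, and the string tool names are hypothetical stand-ins for illustration, not the project's real types.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class AppConfig:
    """Hypothetical stand-in for deerflow's AppConfig (illustration only)."""
    tools: List[str] = field(default_factory=list)


# Hypothetical process-wide config, playing the role of the ambient lookup.
_AMBIENT = AppConfig(tools=["ambient-tool"])


def get_app_config() -> AppConfig:
    """Return the ambient (global) configuration."""
    return _AMBIENT


def get_available_tools(*, app_config: Optional[AppConfig] = None) -> List[str]:
    """Prefer the explicitly passed config; otherwise use the ambient one."""
    config = app_config or get_app_config()
    return list(config.tools)


explicit = AppConfig(tools=["explicit-tool"])
from_explicit = get_available_tools(app_config=explicit)
from_ambient = get_available_tools()
```

Making the parameter keyword-only keeps existing positional call sites working while letting new callers thread a resolved config through without touching global state.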
+1
@@ -17,6 +17,7 @@ dependencies = [
"langgraph-sdk>=0.1.51",
"markdown-to-mrkdwn>=0.3.1",
"wecom-aibot-python-sdk>=0.1.6",
"dingtalk-stream>=0.24.3",
"bcrypt>=4.0.0",
"pyjwt>=2.9.0",
"email-validator>=2.0.0",
+15
@@ -68,6 +68,21 @@ def provisioner_module():
# context should mark themselves ``@pytest.mark.no_auto_user``.
@pytest.fixture(autouse=True)
def _reset_skill_storage_singleton():
"""Reset the SkillStorage singleton between tests to prevent cross-test contamination."""
try:
from deerflow.skills.storage import reset_skill_storage
except ImportError:
yield
return
reset_skill_storage()
try:
yield
finally:
reset_skill_storage()
@pytest.fixture(autouse=True)
def _auto_user_context(request):
"""Inject a default ``test-user-autouse`` into the contextvar.
+52
@@ -133,6 +133,58 @@ class TestListDirSerialization:
assert lock_was_held == [True], "list_dir must hold the lock during exec_command"
class TestNoChangeTimeout:
"""Verify that no_change_timeout is forwarded to every exec_command call."""
def test_execute_command_passes_no_change_timeout(self, sandbox):
"""execute_command should pass no_change_timeout to exec_command."""
calls = []
def mock_exec(command, **kwargs):
calls.append(kwargs)
return SimpleNamespace(data=SimpleNamespace(output="ok"))
sandbox._client.shell.exec_command = mock_exec
sandbox.execute_command("echo hello")
assert len(calls) == 1
assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT
def test_retry_passes_no_change_timeout(self, sandbox):
"""The ErrorObservation retry path should also pass no_change_timeout."""
calls = []
def mock_exec(command, **kwargs):
calls.append(kwargs)
if len(calls) == 1:
return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
return SimpleNamespace(data=SimpleNamespace(output="ok"))
sandbox._client.shell.exec_command = mock_exec
sandbox.execute_command("echo hello")
assert len(calls) == 2
assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT
assert calls[1].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT
def test_list_dir_passes_no_change_timeout(self, sandbox):
"""list_dir should pass no_change_timeout to exec_command."""
calls = []
def mock_exec(command, **kwargs):
calls.append(kwargs)
return SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))
sandbox._client.shell.exec_command = mock_exec
sandbox.list_dir("/test")
assert len(calls) == 1
assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT
class TestConcurrentFileWrites:
"""Verify file write paths do not lose concurrent updates."""
+209 -1
@@ -1,4 +1,14 @@
from deerflow.community.aio_sandbox.local_backend import _format_container_mount
import logging
import os
from types import SimpleNamespace
from deerflow.community.aio_sandbox.local_backend import (
LocalContainerBackend,
_format_container_command_for_log,
_format_container_mount,
_redact_container_command_for_log,
_resolve_docker_bind_host,
)
def test_format_container_mount_uses_mount_syntax_for_docker_windows_paths():
@@ -26,3 +36,201 @@ def test_format_container_mount_keeps_volume_syntax_for_apple_container():
"-v",
"/host/path:/mnt/path:ro",
]
def test_redact_container_command_for_log_redacts_env_values():
redacted = _redact_container_command_for_log(
[
"docker",
"run",
"-e",
"API_KEY=secret-value",
"--env=TOKEN=token-value",
"--name",
"sandbox",
"image",
]
)
assert "API_KEY=<redacted>" in redacted
assert "--env=TOKEN=<redacted>" in redacted
assert "secret-value" not in " ".join(redacted)
assert "token-value" not in " ".join(redacted)
def test_redact_container_command_for_log_keeps_inherited_env_names():
redacted = _redact_container_command_for_log(
[
"docker",
"run",
"-e",
"API_KEY",
"--env=TOKEN",
"--name",
"sandbox",
"image",
]
)
assert redacted == [
"docker",
"run",
"-e",
"API_KEY",
"--env=TOKEN",
"--name",
"sandbox",
"image",
]
def test_format_container_command_for_log_uses_windows_quoting(monkeypatch):
monkeypatch.setattr(os, "name", "nt")
command = _format_container_command_for_log(["docker", "run", "--name", "sandbox one", "image"])
assert command == 'docker run --name "sandbox one" image'
def test_start_container_logs_redacted_env_values(monkeypatch, caplog):
backend = LocalContainerBackend(
image="sandbox:latest",
base_port=8080,
container_prefix="sandbox",
config_mounts=[],
environment={"API_KEY": "secret-value", "NORMAL": "visible-value"},
)
monkeypatch.setattr(backend, "_runtime", "docker")
captured_cmd: list[str] = []
def fake_run(cmd, **kwargs):
captured_cmd.extend(cmd)
return SimpleNamespace(stdout="container-id\n", stderr="", returncode=0)
monkeypatch.setattr("subprocess.run", fake_run)
with caplog.at_level(logging.INFO, logger="deerflow.community.aio_sandbox.local_backend"):
backend._start_container("sandbox-test", 18080)
joined_cmd = " ".join(captured_cmd)
assert "API_KEY=secret-value" in joined_cmd
assert "NORMAL=visible-value" in joined_cmd
log_output = "\n".join(record.getMessage() for record in caplog.records)
assert "API_KEY=<redacted>" in log_output
assert "NORMAL=<redacted>" in log_output
assert "secret-value" not in log_output
assert "visible-value" not in log_output
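The redaction tests above pin down the expected behavior: values after `-e` and inside `--env=` are replaced with `<redacted>`, while bare variable names (inherited from the host environment) pass through unchanged. One way this could be implemented is sketched below. `redact_env_args` is a hypothetical helper written to satisfy those assertions; it is not the actual `_redact_container_command_for_log`, whose body does not appear in this diff.

```python
from typing import List

REDACTED = "<redacted>"


def _redact_assignment(value: str) -> str:
    """Turn NAME=value into NAME=<redacted>; leave a bare NAME untouched."""
    name, sep, _ = value.partition("=")
    return f"{name}={REDACTED}" if sep else value


def redact_env_args(cmd: List[str]) -> List[str]:
    """Return a copy of cmd that is safe to log (hypothetical sketch)."""
    redacted: List[str] = []
    redact_next = False
    for arg in cmd:
        if redact_next:
            # This argument is the operand of a preceding -e / --env flag.
            redacted.append(_redact_assignment(arg))
            redact_next = False
        elif arg in ("-e", "--env"):
            redacted.append(arg)
            redact_next = True
        elif arg.startswith("--env="):
            # Joined form: --env=NAME=value
            redacted.append("--env=" + _redact_assignment(arg[len("--env="):]))
        else:
            redacted.append(arg)
    return redacted


logged = redact_env_args(
    ["docker", "run", "-e", "API_KEY=secret-value", "--env=TOKEN=token-value", "image"]
)
```

The original list is left untouched, so the real command can still be handed to `subprocess.run` with the secrets intact while only the redacted copy reaches the log.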
def _capture_start_container_command(monkeypatch, backend: LocalContainerBackend, runtime: str = "docker") -> list[str]:
monkeypatch.setattr(backend, "_runtime", runtime)
captured_cmd: list[str] = []
def fake_run(cmd, **kwargs):
captured_cmd.extend(cmd)
return SimpleNamespace(stdout="container-id\n", stderr="", returncode=0)
monkeypatch.setattr("subprocess.run", fake_run)
backend._start_container("sandbox-test", 18080)
return captured_cmd
def test_resolve_docker_bind_host_defaults_loopback_for_localhost(monkeypatch):
monkeypatch.delenv("DEER_FLOW_SANDBOX_BIND_HOST", raising=False)
monkeypatch.delenv("DEER_FLOW_SANDBOX_HOST", raising=False)
assert _resolve_docker_bind_host() == "127.0.0.1"
def test_resolve_docker_bind_host_keeps_dood_compatibility(monkeypatch):
monkeypatch.delenv("DEER_FLOW_SANDBOX_BIND_HOST", raising=False)
monkeypatch.setenv("DEER_FLOW_SANDBOX_HOST", "host.docker.internal")
assert _resolve_docker_bind_host() == "0.0.0.0"
def test_resolve_docker_bind_host_uses_ipv6_loopback_for_ipv6_sandbox_host(monkeypatch):
monkeypatch.delenv("DEER_FLOW_SANDBOX_BIND_HOST", raising=False)
monkeypatch.setenv("DEER_FLOW_SANDBOX_HOST", "[::1]")
assert _resolve_docker_bind_host() == "[::1]"
def test_resolve_docker_bind_host_logs_selected_bind_reason(caplog):
with caplog.at_level(logging.DEBUG, logger="deerflow.community.aio_sandbox.local_backend"):
assert _resolve_docker_bind_host(sandbox_host="localhost", bind_host="") == "127.0.0.1"
messages = "\n".join(record.getMessage() for record in caplog.records)
assert "Docker sandbox bind: 127.0.0.1 (loopback default)" in messages
def test_resolve_docker_bind_host_allows_explicit_override(monkeypatch):
monkeypatch.setenv("DEER_FLOW_SANDBOX_HOST", "localhost")
monkeypatch.setenv("DEER_FLOW_SANDBOX_BIND_HOST", "192.0.2.10")
assert _resolve_docker_bind_host() == "192.0.2.10"
def test_start_container_binds_local_docker_port_to_loopback_by_default(monkeypatch):
backend = LocalContainerBackend(
image="sandbox:latest",
base_port=8080,
container_prefix="sandbox",
config_mounts=[],
environment={},
)
monkeypatch.delenv("DEER_FLOW_SANDBOX_HOST", raising=False)
monkeypatch.delenv("DEER_FLOW_SANDBOX_BIND_HOST", raising=False)
captured_cmd = _capture_start_container_command(monkeypatch, backend)
assert captured_cmd[captured_cmd.index("-p") + 1] == "127.0.0.1:18080:8080"
def test_start_container_keeps_broad_bind_for_dood_sandbox_host(monkeypatch):
backend = LocalContainerBackend(
image="sandbox:latest",
base_port=8080,
container_prefix="sandbox",
config_mounts=[],
environment={},
)
monkeypatch.setenv("DEER_FLOW_SANDBOX_HOST", "host.docker.internal")
monkeypatch.delenv("DEER_FLOW_SANDBOX_BIND_HOST", raising=False)
captured_cmd = _capture_start_container_command(monkeypatch, backend)
assert captured_cmd[captured_cmd.index("-p") + 1] == "0.0.0.0:18080:8080"
def test_start_container_binds_ipv6_sandbox_host_to_ipv6_loopback(monkeypatch):
backend = LocalContainerBackend(
image="sandbox:latest",
base_port=8080,
container_prefix="sandbox",
config_mounts=[],
environment={},
)
monkeypatch.setenv("DEER_FLOW_SANDBOX_HOST", "[::1]")
monkeypatch.delenv("DEER_FLOW_SANDBOX_BIND_HOST", raising=False)
captured_cmd = _capture_start_container_command(monkeypatch, backend)
assert captured_cmd[captured_cmd.index("-p") + 1] == "[::1]:18080:8080"
def test_start_container_keeps_apple_container_port_format(monkeypatch):
backend = LocalContainerBackend(
image="sandbox:latest",
base_port=8080,
container_prefix="sandbox",
config_mounts=[],
environment={},
)
monkeypatch.setenv("DEER_FLOW_SANDBOX_BIND_HOST", "127.0.0.1")
captured_cmd = _capture_start_container_command(monkeypatch, backend, runtime="container")
assert captured_cmd[captured_cmd.index("-p") + 1] == "18080:8080"
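Taken together, the bind-host tests above describe a small decision table: an explicit bind host wins, a local sandbox host maps to loopback, an IPv6 loopback host maps to `[::1]`, and anything else (such as `host.docker.internal`) keeps the broad `0.0.0.0` bind. The sketch below encodes that table. `resolve_bind_host` and `port_mapping` are hypothetical names, and the sketch takes plain arguments instead of reading `DEER_FLOW_SANDBOX_HOST` and `DEER_FLOW_SANDBOX_BIND_HOST` from the environment as the real helper appears to do.

```python
from typing import Optional


def resolve_bind_host(sandbox_host: str = "", bind_host: str = "") -> str:
    """Pick the host side of a Docker -p mapping (hypothetical sketch)."""
    explicit = bind_host.strip()
    if explicit:
        # An explicit override always wins.
        return explicit
    host = sandbox_host.strip().lower()
    if host in ("", "localhost", "127.0.0.1"):
        return "127.0.0.1"
    if host in ("::1", "[::1]"):
        return "[::1]"
    # Non-local sandbox host, e.g. Docker-outside-of-Docker setups.
    return "0.0.0.0"


def port_mapping(host_port: int, container_port: int, bind_host: Optional[str] = None) -> str:
    """Format the -p argument; omit the bind host when none is given."""
    if bind_host:
        return f"{bind_host}:{host_port}:{container_port}"
    return f"{host_port}:{container_port}"


default_mapping = port_mapping(18080, 8080, resolve_bind_host())
```

Defaulting to loopback means a sandbox started on a developer machine is not reachable from the network unless the operator opts in.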
+101 -4
@@ -4,12 +4,14 @@ from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import bcrypt
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
from app.gateway.auth.models import User
from app.gateway.auth.password import needs_rehash
from app.gateway.authz import (
AuthContext,
Permissions,
@@ -26,6 +28,7 @@ def test_hash_password_and_verify():
password = "s3cr3tP@ssw0rd!"
hashed = hash_password(password)
assert hashed != password
assert hashed.startswith("$dfv2$")
assert verify_password(password, hashed) is True
assert verify_password("wrongpassword", hashed) is False
@@ -47,6 +50,47 @@ def test_verify_password_rejects_empty():
assert verify_password("", hashed) is False
def test_hash_produces_v2_prefix():
"""hash_password output starts with $dfv2$."""
hashed = hash_password("anypassword123")
assert hashed.startswith("$dfv2$")
def test_verify_v1_prefixed_hash():
"""verify_password handles $dfv1$ prefixed hashes (plain bcrypt)."""
password = "legacyP@ssw0rd"
raw_bcrypt = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
v1_hash = f"$dfv1${raw_bcrypt}"
assert verify_password(password, v1_hash) is True
assert verify_password("wrong", v1_hash) is False
def test_verify_bare_bcrypt_hash():
"""verify_password handles bare bcrypt hashes (no prefix) as v1."""
password = "oldstyleP@ss"
raw_bcrypt = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
assert verify_password(password, raw_bcrypt) is True
assert verify_password("wrong", raw_bcrypt) is False
def test_needs_rehash_returns_false_for_v2():
"""v2 hashes do not need rehashing."""
hashed = hash_password("something")
assert needs_rehash(hashed) is False
def test_needs_rehash_returns_true_for_v1():
"""v1-prefixed hashes need rehashing."""
raw = bcrypt.hashpw(b"pw", bcrypt.gensalt()).decode("utf-8")
assert needs_rehash(f"$dfv1${raw}") is True
def test_needs_rehash_returns_true_for_bare_bcrypt():
"""Bare bcrypt hashes (no prefix) need rehashing."""
raw = bcrypt.hashpw(b"pw", bcrypt.gensalt()).decode("utf-8")
assert needs_rehash(raw) is True
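The password tests above imply a versioned hash format: new hashes carry a `$dfv2$` prefix, `$dfv1$` marks plain bcrypt, and a hash with no prefix is treated as v1. The version check alone can be sketched with string handling, as below. The function bodies are assumptions inferred from the assertions, and the sketch makes no claim about what the v2 payload contains, since that is not visible in this diff.

```python
V1_PREFIX = "$dfv1$"
V2_PREFIX = "$dfv2$"


def hash_version(stored: str) -> int:
    """Return 2 for $dfv2$ hashes; treat $dfv1$ and bare hashes as version 1."""
    return 2 if stored.startswith(V2_PREFIX) else 1


def needs_rehash(stored: str) -> bool:
    """A stored hash should be upgraded if it is older than the current version."""
    return hash_version(stored) < 2


def strip_version_prefix(stored: str) -> str:
    """Remove a known version prefix, leaving the underlying hash string."""
    for prefix in (V1_PREFIX, V2_PREFIX):
        if stored.startswith(prefix):
            return stored[len(prefix):]
    return stored


# Placeholder string shaped like a bcrypt hash; not a real digest.
legacy = "$2b$12$abcdefghijklmnopqrstuv"
```

Treating unprefixed hashes as v1 lets accounts created before the prefix existed keep working until their next successful login triggers an upgrade.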
# ── JWT ─────────────────────────────────────────────────────────────────────
@@ -166,7 +210,7 @@ def test_get_auth_context_set():
def test_require_auth_sets_auth_context():
"""require_auth sets auth context on request from cookie."""
"""require_auth rejects unauthenticated requests with 401."""
from fastapi import Request
app = FastAPI()
@@ -178,10 +222,9 @@ def test_require_auth_sets_auth_context():
return {"authenticated": ctx.is_authenticated}
with TestClient(app) as client:
# No cookie → anonymous
# No cookie → 401 (require_auth independently enforces authentication)
response = client.get("/test")
assert response.status_code == 200
assert response.json()["authenticated"] is False
assert response.status_code == 401
def test_require_auth_requires_request_param():
@@ -652,3 +695,57 @@ def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
# Cleanup
config_module._auth_config = None
# ── Auto-rehash on login ──────────────────────────────────────────────────
def test_authenticate_auto_rehashes_legacy_hash():
"""authenticate() upgrades a bare bcrypt hash to v2 on successful login."""
import asyncio
from app.gateway.auth.local_provider import LocalAuthProvider
password = "rehashTest123"
user = User(
id=uuid4(),
email="rehash@test.com",
password_hash=bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8"),
)
mock_repo = MagicMock()
mock_repo.get_user_by_email = AsyncMock(return_value=user)
mock_repo.update_user = AsyncMock(return_value=user)
provider = LocalAuthProvider(mock_repo)
result = asyncio.run(provider.authenticate({"email": "rehash@test.com", "password": password}))
assert result is not None
assert result.password_hash.startswith("$dfv2$")
mock_repo.update_user.assert_called_once()
def test_authenticate_skips_rehash_for_v2_hash():
"""authenticate() does NOT rehash when the stored hash is already v2."""
import asyncio
from app.gateway.auth.local_provider import LocalAuthProvider
password = "alreadyv2Pass!"
user = User(
id=uuid4(),
email="v2@test.com",
password_hash=hash_password(password),
)
mock_repo = MagicMock()
mock_repo.get_user_by_email = AsyncMock(return_value=user)
mock_repo.update_user = AsyncMock(return_value=user)
provider = LocalAuthProvider(mock_repo)
result = asyncio.run(provider.authenticate({"email": "v2@test.com", "password": password}))
assert result is not None
mock_repo.update_user.assert_not_called()
+100
@@ -462,6 +462,7 @@ class TestChannelManager:
)
mock_channel = MagicMock()
mock_channel.receive_file = AsyncMock(return_value=modified_msg)
mock_channel.supports_streaming = False
mock_service = MagicMock()
mock_service.get_channel.return_value = mock_channel
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: mock_service)
@@ -535,6 +536,89 @@ class TestChannelManager:
_run(go())
def test_handle_chat_outbound_preserves_inbound_metadata(self):
"""DingTalk (and similar) need inbound metadata on outbound sends (e.g. sender_staff_id)."""
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received: list[OutboundMessage] = []
async def capture_outbound(msg: OutboundMessage) -> None:
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
mock_client = _make_mock_langgraph_client()
manager._client = mock_client
await manager.start()
meta = {
"sender_staff_id": "staff_001",
"conversation_type": "1",
"conversation_id": "conv_001",
}
inbound = InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="hi",
metadata=meta,
)
await bus.publish_inbound(inbound)
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
assert len(outbound_received) == 1
assert outbound_received[0].metadata == meta
_run(go())
def test_handle_chat_outbound_drops_large_metadata_keys(self):
"""Large metadata keys like raw_message should be stripped from outbound messages."""
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
outbound_received: list[OutboundMessage] = []
async def capture_outbound(msg: OutboundMessage) -> None:
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
mock_client = _make_mock_langgraph_client()
manager._client = mock_client
await manager.start()
meta = {
"sender_staff_id": "staff_001",
"conversation_type": "1",
"raw_message": {"huge": "payload" * 1000},
"ref_msg": {"also": "large"},
}
inbound = InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="hi",
metadata=meta,
)
await bus.publish_inbound(inbound)
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
assert len(outbound_received) == 1
out_meta = outbound_received[0].metadata
assert "sender_staff_id" in out_meta
assert "conversation_type" in out_meta
assert "raw_message" not in out_meta
assert "ref_msg" not in out_meta
_run(go())
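The two outbound-metadata tests above expect small routing fields such as `sender_staff_id` to survive on the outbound message while bulky keys such as `raw_message` and `ref_msg` are dropped. A minimal sketch of such a filter follows. The name `slim_metadata` and the fixed deny-list are assumptions for illustration; the real `ChannelManager` may select keys differently.

```python
from typing import Any, Dict, FrozenSet

# Keys the tests expect to be absent from outbound metadata.
DROPPED_KEYS: FrozenSet[str] = frozenset({"raw_message", "ref_msg"})


def slim_metadata(
    metadata: Dict[str, Any],
    dropped: FrozenSet[str] = DROPPED_KEYS,
) -> Dict[str, Any]:
    """Return a shallow copy of metadata without the bulky keys."""
    return {key: value for key, value in metadata.items() if key not in dropped}


inbound_meta = {
    "sender_staff_id": "staff_001",
    "conversation_type": "1",
    "raw_message": {"huge": "payload" * 1000},
    "ref_msg": {"also": "large"},
}
outbound_meta = slim_metadata(inbound_meta)
```

Because the filter builds a new dictionary, the inbound message keeps its full metadata for any handler that still needs the raw payload.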
def test_handle_chat_uses_channel_session_overrides(self):
from app.channels.manager import ChannelManager
@@ -2032,6 +2116,22 @@ class TestChannelService:
assert service.manager._langgraph_url == "http://custom-gateway:8001/api"
assert service.manager._gateway_url == "http://custom-gateway:8001"
def test_from_app_config_uses_explicit_config(self):
from app.channels.service import ChannelService
app_config = SimpleNamespace(
model_extra={
"channels": {
"telegram": {"enabled": False},
}
}
)
with patch("deerflow.config.app_config.get_app_config", side_effect=AssertionError("should not read global config")):
service = ChannelService.from_app_config(app_config)
assert service._config == {"telegram": {"enabled": False}}
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
"""Warning is emitted when a channel has string credentials but enabled=false."""
import logging
+48
@@ -1,6 +1,7 @@
"""Unit tests for checkpointer config and singleton factory."""
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -103,6 +104,53 @@ class TestGetCheckpointer:
cp2 = get_checkpointer()
assert cp1 is not cp2
def test_explicit_app_config_bypasses_global_config_lookup(self):
from langgraph.checkpoint.memory import InMemorySaver
explicit_config = SimpleNamespace(
checkpointer=CheckpointerConfig(type="memory"),
database=SimpleNamespace(backend="memory"),
)
with patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
):
cp = get_checkpointer(app_config=explicit_config)
assert isinstance(cp, InMemorySaver)
def test_explicit_app_config_uses_unified_database_sqlite_backend(self):
explicit_config = SimpleNamespace(
checkpointer=None,
database=SimpleNamespace(backend="sqlite", checkpointer_sqlite_path="/tmp/explicit/deerflow.db"),
)
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch(
"deerflow.runtime.checkpointer.provider.get_app_config",
side_effect=AssertionError("ambient get_app_config() must not be used when app_config is explicit"),
),
patch("deerflow.runtime.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
):
cp = get_checkpointer(app_config=explicit_config)
assert cp is mock_saver_instance
mock_ensure.assert_called_once_with("/tmp/explicit/deerflow.db")
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/explicit/deerflow.db")
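Both checkpointer tests above use the same guard: the ambient `get_app_config` is patched with `side_effect=AssertionError(...)`, so the test fails loudly if the code under test consults global config despite receiving an explicit one. The self-contained sketch below shows the technique with the standard library only. The `provider` namespace and `get_backend` function are hypothetical stand-ins for the real module and for `get_checkpointer`.

```python
from types import SimpleNamespace
from unittest.mock import patch

# Hypothetical stand-in for a module that exposes an ambient config lookup.
provider = SimpleNamespace()
provider.get_app_config = lambda: SimpleNamespace(backend="sqlite")


def get_backend(app_config=None):
    """Use the explicit config when given; otherwise fall back to the ambient one."""
    config = app_config if app_config is not None else provider.get_app_config()
    return config.backend


explicit = SimpleNamespace(backend="memory")

# While patched, any call to the ambient lookup raises AssertionError.
with patch.object(
    provider,
    "get_app_config",
    side_effect=AssertionError("ambient lookup must not run"),
):
    backend = get_backend(app_config=explicit)
    try:
        get_backend()
        ambient_blocked = False
    except AssertionError:
        ambient_blocked = True

# Outside the context manager the original lookup is restored.
restored_backend = get_backend()
```

Raising from the patched lookup is stricter than asserting on a call count afterwards, because the failure points at the exact call that reached for global state.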
def test_sqlite_raises_when_package_missing(self):
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
+94 -35
@@ -43,12 +43,27 @@ def mock_app_config():
@pytest.fixture
def client(mock_app_config):
def client(mock_app_config, tmp_path):
"""Create a DeerFlowClient with mocked config loading."""
import deerflow.skills.storage as _storage_mod
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
_storage_mod._default_skill_storage = LocalSkillStorage(host_path=str(tmp_path))
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
return DeerFlowClient()
@pytest.fixture
def allow_skill_security_scan():
async def _scan(*args, **kwargs):
from deerflow.skills.security_scanner import ScanResult
return ScanResult(decision="allow", reason="ok")
with patch("deerflow.skills.installer.scan_skill_content", _scan):
yield
# ---------------------------------------------------------------------------
# __init__
# ---------------------------------------------------------------------------
@@ -124,7 +139,7 @@ class TestConfigQueries:
skill.category = "public"
skill.enabled = True
with patch("deerflow.skills.loader.load_skills", return_value=[skill]) as mock_load:
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[skill]) as mock_load:
result = client.list_skills()
mock_load.assert_called_once_with(enabled_only=False)
@@ -139,7 +154,7 @@ class TestConfigQueries:
}
def test_list_skills_enabled_only(self, client):
with patch("deerflow.skills.loader.load_skills", return_value=[]) as mock_load:
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[]) as mock_load:
client.list_skills(enabled_only=True)
mock_load.assert_called_once_with(enabled_only=True)
@@ -833,6 +848,28 @@ class TestEnsureAgent:
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
def test_threads_explicit_app_config_to_dependencies(self, client):
"""Client-owned AppConfig must flow into model/tool/prompt/checkpointer composition."""
mock_agent = MagicMock()
mock_checkpointer = MagicMock()
config = client._get_runnable_config("t1")
with (
patch("deerflow.client.create_chat_model", return_value=MagicMock()) as mock_create_chat_model,
patch("deerflow.client.create_agent", return_value=mock_agent),
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch("deerflow.tools.get_available_tools", return_value=[]) as mock_get_available_tools,
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer) as mock_get_checkpointer,
):
client._ensure_agent(config)
assert mock_create_chat_model.call_args.kwargs["app_config"] is client._app_config
assert mock_build_middlewares.call_args.kwargs["app_config"] is client._app_config
assert mock_apply_prompt.call_args.kwargs["app_config"] is client._app_config
assert mock_get_available_tools.call_args.kwargs["app_config"] is client._app_config
assert mock_get_checkpointer.call_args.kwargs["app_config"] is client._app_config
def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock()
mock_checkpointer = MagicMock()
@@ -1152,13 +1189,13 @@ class TestSkillsManagement:
def test_get_skill_found(self, client):
skill = self._make_skill()
with patch("deerflow.skills.loader.load_skills", return_value=[skill]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[skill]):
result = client.get_skill("test-skill")
assert result is not None
assert result["name"] == "test-skill"
def test_get_skill_not_found(self, client):
with patch("deerflow.skills.loader.load_skills", return_value=[]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[]):
result = client.get_skill("nonexistent")
assert result is None
@@ -1179,7 +1216,7 @@ class TestSkillsManagement:
client._agent = MagicMock()
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", side_effect=[[skill], [updated_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
@@ -1191,11 +1228,11 @@ class TestSkillsManagement:
tmp_path.unlink()
def test_update_skill_not_found(self, client):
with patch("deerflow.skills.loader.load_skills", return_value=[]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[]):
with pytest.raises(ValueError, match="not found"):
client.update_skill("nonexistent", enabled=True)
def test_install_skill(self, client):
def test_install_skill(self, client, allow_skill_security_scan):
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
@@ -1211,7 +1248,9 @@ class TestSkillsManagement:
skills_root = tmp_path / "skills"
(skills_root / "custom").mkdir(parents=True)
with patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root):
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))):
result = client.install_skill(archive_path)
assert result["success"] is True
@@ -1774,12 +1813,12 @@ class TestScenarioConfigManagement:
skill.category = "public"
skill.enabled = True
with patch("deerflow.skills.loader.load_skills", return_value=[skill]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[skill]):
skills_result = client.list_skills()
assert len(skills_result["skills"]) == 1
# Get specific skill
with patch("deerflow.skills.loader.load_skills", return_value=[skill]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[skill]):
detail = client.get_skill("web-search")
assert detail is not None
assert detail["enabled"] is True
@@ -1830,7 +1869,7 @@ class TestScenarioConfigManagement:
client._agent = MagicMock() # Simulate re-created agent
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", side_effect=[[skill], [toggled]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
@@ -2033,7 +2072,7 @@ class TestScenarioMemoryWorkflow:
class TestScenarioSkillInstallAndUse:
"""Scenario: Install a skill → verify it appears → toggle it."""
def test_install_then_toggle(self, client):
def test_install_then_toggle(self, client, allow_skill_security_scan):
"""Install .skill archive → list to verify → disable → verify disabled."""
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
@@ -2050,7 +2089,9 @@ class TestScenarioSkillInstallAndUse:
(skills_root / "custom").mkdir(parents=True)
# Step 1: Install
with patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root):
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))):
result = client.install_skill(archive)
assert result["success"] is True
assert (skills_root / "custom" / "my-analyzer" / "SKILL.md").exists()
@@ -2063,7 +2104,7 @@ class TestScenarioSkillInstallAndUse:
installed_skill.category = "custom"
installed_skill.enabled = True
with patch("deerflow.skills.loader.load_skills", return_value=[installed_skill]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[installed_skill]):
skills_result = client.list_skills()
assert any(s["name"] == "my-analyzer" for s in skills_result["skills"])
@@ -2083,7 +2124,7 @@ class TestScenarioSkillInstallAndUse:
config_file.write_text("{}")
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
@@ -2257,7 +2298,7 @@ class TestGatewayConformance:
skill.category = "public"
skill.enabled = True
with patch("deerflow.skills.loader.load_skills", return_value=[skill]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[skill]):
result = client.list_skills()
parsed = SkillsListResponse(**result)
@@ -2272,14 +2313,14 @@ class TestGatewayConformance:
skill.category = "public"
skill.enabled = True
with patch("deerflow.skills.loader.load_skills", return_value=[skill]):
with patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[skill]):
result = client.get_skill("web-search")
assert result is not None
parsed = SkillResponse(**result)
assert parsed.name == "web-search"
def test_install_skill(self, client, tmp_path):
def test_install_skill(self, client, tmp_path, allow_skill_security_scan):
skill_dir = tmp_path / "my-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text("---\nname: my-skill\ndescription: A test skill\n---\nBody\n")
@@ -2288,7 +2329,9 @@ class TestGatewayConformance:
with zipfile.ZipFile(archive, "w") as zf:
zf.write(skill_dir / "SKILL.md", "my-skill/SKILL.md")
with patch("deerflow.skills.installer.get_skills_root_path", return_value=tmp_path):
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(tmp_path))):
result = client.install_skill(archive)
parsed = SkillInstallResponse(**result)
@@ -2442,8 +2485,10 @@ class TestInstallSkillSecurity:
def patched_extract(zf, dest, max_total_size=100):
return orig(zf, dest, max_total_size=100)
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with (
patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root),
patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))),
patch("deerflow.skills.installer.safe_extract_skill_archive", side_effect=patched_extract),
):
with pytest.raises(ValueError, match="too large"):
@@ -2459,7 +2504,9 @@ class TestInstallSkillSecurity:
skills_root = Path(tmp) / "skills"
(skills_root / "custom").mkdir(parents=True)
with patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root):
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))):
with pytest.raises(ValueError, match="unsafe"):
client.install_skill(archive)
@@ -2473,11 +2520,13 @@ class TestInstallSkillSecurity:
skills_root = Path(tmp) / "skills"
(skills_root / "custom").mkdir(parents=True)
with patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root):
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))):
with pytest.raises(ValueError, match="unsafe"):
client.install_skill(archive)
def test_symlinks_skipped_during_extraction(self, client):
def test_symlinks_skipped_during_extraction(self, client, allow_skill_security_scan):
"""Symlink entries in the archive are skipped (never written to disk)."""
import stat as stat_mod
@@ -2495,7 +2544,9 @@ class TestInstallSkillSecurity:
skills_root = tmp_path / "skills"
(skills_root / "custom").mkdir(parents=True)
with patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root):
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))):
result = client.install_skill(archive)
assert result["success"] is True
@@ -2519,9 +2570,11 @@ class TestInstallSkillSecurity:
skills_root = tmp_path / "skills"
(skills_root / "custom").mkdir(parents=True)
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with (
patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root),
patch("deerflow.skills.installer._validate_skill_frontmatter", return_value=(True, "OK", "../evil")),
patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))),
patch("deerflow.skills.validation._validate_skill_frontmatter", return_value=(True, "OK", "../evil")),
):
with pytest.raises(ValueError, match="Invalid skill name"):
client.install_skill(archive)
@@ -2542,9 +2595,11 @@ class TestInstallSkillSecurity:
skills_root = tmp_path / "skills"
(skills_root / "custom" / "dupe-skill").mkdir(parents=True)
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with (
patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root),
patch("deerflow.skills.installer._validate_skill_frontmatter", return_value=(True, "OK", "dupe-skill")),
patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))),
patch("deerflow.skills.validation._validate_skill_frontmatter", return_value=(True, "OK", "dupe-skill")),
):
with pytest.raises(ValueError, match="already exists"):
client.install_skill(archive)
@@ -2559,7 +2614,9 @@ class TestInstallSkillSecurity:
skills_root = Path(tmp) / "skills"
(skills_root / "custom").mkdir(parents=True)
with patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root):
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))):
with pytest.raises(ValueError, match="empty"):
client.install_skill(archive)
@@ -2578,9 +2635,11 @@ class TestInstallSkillSecurity:
skills_root = tmp_path / "skills"
(skills_root / "custom").mkdir(parents=True)
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
with (
patch("deerflow.skills.installer.get_skills_root_path", return_value=skills_root),
patch("deerflow.skills.installer._validate_skill_frontmatter", return_value=(False, "Missing name field", "")),
patch("deerflow.skills.storage._default_skill_storage", LocalSkillStorage(host_path=str(skills_root))),
patch("deerflow.skills.validation._validate_skill_frontmatter", return_value=(False, "Missing name field", "")),
):
with pytest.raises(ValueError, match="Invalid skill"):
client.install_skill(archive)
@@ -2672,7 +2731,7 @@ class TestConfigUpdateErrors:
skill.name = "some-skill"
with (
patch("deerflow.skills.loader.load_skills", return_value=[skill]),
patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", return_value=[skill]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=None),
):
with pytest.raises(FileNotFoundError, match="Cannot locate"):
@@ -2692,7 +2751,7 @@ class TestConfigUpdateErrors:
config_file.write_text("{}")
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]),
patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", side_effect=[[skill], []]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
@@ -3107,7 +3166,7 @@ class TestBugAgentInvalidationInconsistency:
config_file.write_text("{}")
with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]),
patch("deerflow.skills.storage.local_skill_storage.LocalSkillStorage.load_skills", side_effect=[[skill], [updated]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.reload_extensions_config"),
+59 -25
@@ -17,14 +17,13 @@ import json
import os
import uuid
import zipfile
from pathlib import Path
import pytest
from dotenv import load_dotenv
from deerflow.client import DeerFlowClient, StreamEvent
from deerflow.config.app_config import AppConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
# Load .env from project root (for OPENAI_API_KEY etc.)
load_dotenv(os.path.join(os.path.dirname(__file__), "../../.env"))
@@ -55,24 +54,34 @@ def _make_e2e_config() -> AppConfig:
- ``E2E_MODEL_ID`` (default: ``ep-20251211175242-llcmh``)
- ``E2E_BASE_URL`` (default: ``https://ark-cn-beijing.bytedance.net/api/v3``)
- ``OPENAI_API_KEY`` (required for LLM tests)
Note: We use model_validate with a raw dict (not AppConfig(models=[ModelConfig(...)]))
because passing already-validated Pydantic instances triggers a pydantic-core
shortcut that returns stale cached data when another AppConfig was previously
loaded from disk in the same process. Dict-based validation is always correct.
"""
return AppConfig(
models=[
ModelConfig(
name=os.getenv("E2E_MODEL_NAME", "volcengine-ark"),
display_name="E2E Test Model",
use=os.getenv("E2E_MODEL_USE", "langchain_openai:ChatOpenAI"),
model=os.getenv("E2E_MODEL_ID", "ep-20251211175242-llcmh"),
base_url=os.getenv("E2E_BASE_URL", "https://ark-cn-beijing.bytedance.net/api/v3"),
api_key=os.getenv("OPENAI_API_KEY", ""),
max_tokens=512,
temperature=0.7,
supports_thinking=False,
supports_reasoning_effort=False,
supports_vision=False,
)
],
sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", allow_host_bash=True),
return AppConfig.model_validate(
{
"models": [
{
"name": os.getenv("E2E_MODEL_NAME", "volcengine-ark"),
"display_name": "E2E Test Model",
"use": os.getenv("E2E_MODEL_USE", "langchain_openai:ChatOpenAI"),
"model": os.getenv("E2E_MODEL_ID", "ep-20251211175242-llcmh"),
"base_url": os.getenv("E2E_BASE_URL", "https://ark-cn-beijing.bytedance.net/api/v3"),
"api_key": os.getenv("OPENAI_API_KEY", ""),
"max_tokens": 512,
"temperature": 0.7,
"supports_thinking": False,
"supports_reasoning_effort": False,
"supports_vision": False,
}
],
"sandbox": {
"use": "deerflow.sandbox.local:LocalSandboxProvider",
"allow_host_bash": True,
},
}
)
@@ -86,19 +95,31 @@ def e2e_env(tmp_path, monkeypatch):
"""Isolated filesystem environment for E2E tests.
- DEER_FLOW_HOME → tmp_path (all thread data lands in a temp dir)
- DEER_FLOW_PROJECT_ROOT → repository root (shared skills/config assets
still resolve correctly when tests run from backend/)
- Singletons reset so they pick up the new env
- Title/memory/summarization disabled to avoid extra LLM calls
- AppConfig built programmatically (avoids config.yaml param-name issues)
"""
# 1. Filesystem isolation
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setenv(
"DEER_FLOW_PROJECT_ROOT",
str(Path(__file__).resolve().parents[2]),
)
monkeypatch.setattr("deerflow.config.paths._paths", None)
monkeypatch.setattr("deerflow.sandbox.sandbox_provider._default_sandbox_provider", None)
# 2. Inject a clean AppConfig via the global singleton.
# 2. Inject a clean AppConfig. We must reset _app_config to None BEFORE
# calling _make_e2e_config() because AppConfig() constructor misbehaves when
# a disk config is already cached: it returns the cached model list instead
# of the provided one. Clearing first ensures the test config is correct.
monkeypatch.setattr("deerflow.config.app_config._app_config", None)
monkeypatch.setattr("deerflow.config.app_config._app_config_is_custom", False)
config = _make_e2e_config()
monkeypatch.setattr("deerflow.config.app_config._app_config", config)
monkeypatch.setattr("deerflow.config.app_config._app_config_is_custom", True)
monkeypatch.setattr("deerflow.client.get_app_config", lambda: config)
# 3. Disable title generation (extra LLM call, non-deterministic)
from deerflow.config.title_config import TitleConfig
@@ -525,15 +546,26 @@ class TestArtifactAccess:
class TestSkillInstallation:
"""install_skill() with real ZIP handling and filesystem."""
@pytest.fixture(autouse=True)
def _allow_skill_security_scan(self, monkeypatch):
async def _scan(*args, **kwargs):
from deerflow.skills.security_scanner import ScanResult
return ScanResult(decision="allow", reason="ok")
monkeypatch.setattr("deerflow.skills.installer.scan_skill_content", _scan)
@pytest.fixture(autouse=True)
def _isolate_skills_dir(self, tmp_path, monkeypatch):
"""Redirect skill installation to a temp directory."""
skills_root = tmp_path / "skills"
(skills_root / "public").mkdir(parents=True)
(skills_root / "custom").mkdir(parents=True)
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
monkeypatch.setattr(
"deerflow.skills.installer.get_skills_root_path",
lambda: skills_root,
"deerflow.skills.storage._default_skill_storage",
LocalSkillStorage(host_path=str(skills_root)),
)
self._skills_root = skills_root
@@ -608,19 +640,21 @@ class TestConfigManagement:
def test_list_models_returns_injected_config(self, e2e_env):
"""list_models() returns the model from the injected AppConfig."""
expected_model_name = os.getenv("E2E_MODEL_NAME", "volcengine-ark")
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
result = c.list_models()
assert "models" in result
assert len(result["models"]) == 1
assert result["models"][0]["name"] == "volcengine-ark"
assert result["models"][0]["name"] == expected_model_name
assert result["models"][0]["display_name"] == "E2E Test Model"
def test_get_model_found(self, e2e_env):
"""get_model() returns the model when it exists."""
expected_model_name = os.getenv("E2E_MODEL_NAME", "volcengine-ark")
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
model = c.get_model("volcengine-ark")
model = c.get_model(expected_model_name)
assert model is not None
assert model["name"] == "volcengine-ark"
assert model["name"] == expected_model_name
assert model["supports_thinking"] is False
def test_get_model_not_found(self, e2e_env):
+18 -6
@@ -116,10 +116,22 @@ def test_middleware_and_features_conflict():
# ---------------------------------------------------------------------------
# 7. Vision feature auto-injects view_image_tool
# 7. Vision feature auto-injects view_image_tool when thread data is available
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_vision_injects_view_image_tool(mock_create_agent):
mock_create_agent.return_value = MagicMock()
feat = RuntimeFeatures(vision=True, sandbox=True)
create_deerflow_agent(_make_mock_model(), features=feat)
call_kwargs = mock_create_agent.call_args[1]
tool_names = [t.name for t in call_kwargs["tools"]]
assert "view_image" in tool_names
@patch("deerflow.agents.factory.create_agent")
def test_vision_without_sandbox_does_not_inject_view_image_tool(mock_create_agent):
mock_create_agent.return_value = MagicMock()
feat = RuntimeFeatures(vision=True, sandbox=False)
@@ -127,7 +139,7 @@ def test_vision_injects_view_image_tool(mock_create_agent):
call_kwargs = mock_create_agent.call_args[1]
tool_names = [t.name for t in call_kwargs["tools"]]
assert "view_image" in tool_names
assert "view_image" not in tool_names
def test_view_image_middleware_preserves_viewed_images_reducer():
@@ -301,11 +313,11 @@ def test_always_on_error_handling(mock_create_agent):
# ---------------------------------------------------------------------------
# 17. Vision with custom middleware still injects tool
# 17. Vision with custom middleware follows thread-data availability
# ---------------------------------------------------------------------------
@patch("deerflow.agents.factory.create_agent")
def test_vision_custom_middleware_still_injects_tool(mock_create_agent):
"""Custom vision middleware still gets the view_image_tool auto-injected."""
def test_vision_custom_middleware_without_sandbox_does_not_inject_tool(mock_create_agent):
"""Custom vision middleware without thread data does not get view_image_tool auto-injected."""
from langchain.agents.middleware import AgentMiddleware
mock_create_agent.return_value = MagicMock()
@@ -319,7 +331,7 @@ def test_vision_custom_middleware_still_injects_tool(mock_create_agent):
call_kwargs = mock_create_agent.call_args[1]
tool_names = [t.name for t in call_kwargs["tools"]]
assert "view_image" in tool_names
assert "view_image" not in tool_names
# ===========================================================================
File diff suppressed because it is too large
+41
@@ -0,0 +1,41 @@
from __future__ import annotations
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def test_get_config_returns_app_state_config():
"""get_config should return the exact AppConfig stored on app.state."""
app = FastAPI()
config = AppConfig(sandbox=SandboxConfig(use="test"))
app.state.config = config
@app.get("/probe")
def probe(cfg: AppConfig = Depends(get_config)):
return {"same_identity": cfg is config, "log_level": cfg.log_level}
client = TestClient(app)
response = client.get("/probe")
assert response.status_code == 200
assert response.json() == {"same_identity": True, "log_level": "info"}
def test_get_config_reads_updated_app_state():
"""Swapping app.state.config should be visible to the dependency."""
app = FastAPI()
app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
@app.get("/log-level")
def log_level(cfg: AppConfig = Depends(get_config)):
return {"level": cfg.log_level}
client = TestClient(app)
assert client.get("/log-level").json() == {"level": "info"}
app.state.config = app.state.config.model_copy(update={"log_level": "debug"})
assert client.get("/log-level").json() == {"level": "debug"}
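The two tests above pin down a dependency that reads the config from `app.state` on every request rather than caching it. A minimal, framework-free sketch of that pattern follows. It is an assumption about the shape of `get_config`, not the code in `app.gateway.deps`, and the `SimpleNamespace` objects are hypothetical stand-ins for FastAPI's `Request` and `app`.

```python
from types import SimpleNamespace


def get_config(request):
    """Return whatever config object is currently stored on app.state.

    Hypothetical stand-in: the real dependency presumably receives a
    fastapi.Request; here any object exposing .app.state.config works.
    """
    return request.app.state.config


# Fake app/request pair standing in for the FastAPI objects.
app = SimpleNamespace(state=SimpleNamespace(config={"log_level": "info"}))
request = SimpleNamespace(app=app)

first = get_config(request)

# Swapping the object on app.state is visible on the next lookup,
# because the dependency holds no cache of its own.
app.state.config = {"log_level": "debug"}
second = get_config(request)
```

Because the lookup happens per call, replacing `app.state.config` (as the second test does with `model_copy`) takes effect without restarting the app.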
+124
@@ -0,0 +1,124 @@
"""Tests for GATEWAY_ENABLE_DOCS configuration toggle.
Verifies that Swagger UI (/docs), ReDoc (/redoc), and the OpenAPI schema
(/openapi.json) can be disabled via the GATEWAY_ENABLE_DOCS environment
variable for production deployments.
"""
from __future__ import annotations
import os
from unittest.mock import patch
import pytest
from fastapi.testclient import TestClient
def _reset_gateway_config():
"""Reset the cached gateway config so env changes take effect."""
import app.gateway.config as cfg
cfg._gateway_config = None
@pytest.fixture(autouse=True)
def _clean_config():
"""Ensure gateway config cache is cleared before and after each test."""
_reset_gateway_config()
yield
_reset_gateway_config()
# ---------------------------------------------------------------------------
# Config parsing
# ---------------------------------------------------------------------------
def test_enable_docs_defaults_to_true():
"""When GATEWAY_ENABLE_DOCS is not set, enable_docs should be True."""
with patch.dict(os.environ, {}, clear=False):
if "GATEWAY_ENABLE_DOCS" in os.environ:
del os.environ["GATEWAY_ENABLE_DOCS"]
_reset_gateway_config()
from app.gateway.config import get_gateway_config
config = get_gateway_config()
assert config.enable_docs is True
def test_enable_docs_false():
"""GATEWAY_ENABLE_DOCS=false should disable docs."""
with patch.dict(os.environ, {"GATEWAY_ENABLE_DOCS": "false"}):
_reset_gateway_config()
from app.gateway.config import get_gateway_config
config = get_gateway_config()
assert config.enable_docs is False
def test_enable_docs_case_insensitive():
"""GATEWAY_ENABLE_DOCS is case-insensitive (FALSE, False, false)."""
for value in ("FALSE", "False", "false"):
with patch.dict(os.environ, {"GATEWAY_ENABLE_DOCS": value}):
_reset_gateway_config()
from app.gateway.config import get_gateway_config
config = get_gateway_config()
assert config.enable_docs is False, f"Expected False for GATEWAY_ENABLE_DOCS={value}"
def test_enable_docs_unexpected_value_disables():
"""Any non-'true' value should disable docs (fail-closed)."""
for value in ("0", "no", "off", "anything"):
with patch.dict(os.environ, {"GATEWAY_ENABLE_DOCS": value}):
_reset_gateway_config()
from app.gateway.config import get_gateway_config
config = get_gateway_config()
assert config.enable_docs is False, f"Expected False for GATEWAY_ENABLE_DOCS={value}"
# ---------------------------------------------------------------------------
# App-level endpoint visibility
# ---------------------------------------------------------------------------
def test_docs_endpoints_available_by_default():
"""With enable_docs=True (default), /docs, /redoc, /openapi.json return 200."""
with patch.dict(os.environ, {}, clear=False):
if "GATEWAY_ENABLE_DOCS" in os.environ:
del os.environ["GATEWAY_ENABLE_DOCS"]
_reset_gateway_config()
from app.gateway.app import create_app
app = create_app()
client = TestClient(app)
assert client.get("/docs").status_code == 200
assert client.get("/redoc").status_code == 200
assert client.get("/openapi.json").status_code == 200
def test_docs_endpoints_disabled_when_false():
"""With GATEWAY_ENABLE_DOCS=false, /docs, /redoc, /openapi.json return 404."""
with patch.dict(os.environ, {"GATEWAY_ENABLE_DOCS": "false"}):
_reset_gateway_config()
from app.gateway.app import create_app
app = create_app()
client = TestClient(app)
assert client.get("/docs").status_code == 404
assert client.get("/redoc").status_code == 404
assert client.get("/openapi.json").status_code == 404
def test_health_still_works_when_docs_disabled():
"""Disabling docs should NOT affect /health or other normal endpoints."""
with patch.dict(os.environ, {"GATEWAY_ENABLE_DOCS": "false"}):
_reset_gateway_config()
from app.gateway.app import create_app
app = create_app()
client = TestClient(app)
resp = client.get("/health")
assert resp.status_code == 200
assert resp.json()["status"] == "healthy"
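The config-parsing tests above describe a fail-closed toggle: unset means enabled, and only a case-insensitive `true` keeps docs on. A minimal sketch of parsing that satisfies those tests is shown below. `parse_enable_docs` is a hypothetical helper, not the function used in `app.gateway.config`, and the default of `"true"` is inferred from the tests.

```python
import os


def parse_enable_docs(environ=None):
    """Hypothetical parser: docs stay enabled only for the literal 'true'.

    Any other value ("0", "no", "off", typos) disables docs, so a
    misconfigured production deployment fails closed.
    """
    env = os.environ if environ is None else environ
    raw = env.get("GATEWAY_ENABLE_DOCS", "true")  # assumed default when unset
    return raw.lower() == "true"


# Passing a plain dict avoids touching the real process environment.
unset_result = parse_enable_docs({})
false_result = parse_enable_docs({"GATEWAY_ENABLE_DOCS": "FALSE"})
odd_result = parse_enable_docs({"GATEWAY_ENABLE_DOCS": "anything"})
```

Comparing against `"true"` instead of checking for `"false"` is what makes unexpected values disable the docs endpoints.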
+31
@@ -256,6 +256,37 @@ def test_context_merges_into_configurable():
assert "thread_id" not in {k for k in context if k in _CONTEXT_CONFIGURABLE_KEYS}
def test_merge_run_context_overrides_propagates_to_runtime_context():
"""Regression for issue #2677: ``agent_name`` (and other whitelisted keys) from
``body.context`` must be propagated into BOTH ``config['configurable']`` and
``config['context']``. Previously only ``configurable`` was populated, so after
the LangGraph 1.1.x upgrade removed the fallback from ``configurable``, the
``setup_agent`` tool read ``runtime.context`` with ``agent_name=None`` and
silently wrote SOUL.md to the global base_dir.
"""
from app.gateway.services import build_run_config, merge_run_context_overrides
config = build_run_config("thread-1", None, None)
merge_run_context_overrides(config, {"agent_name": "my-agent", "is_bootstrap": True, "thread_id": "ignored"})
assert config["configurable"]["agent_name"] == "my-agent"
assert config["configurable"]["is_bootstrap"] is True
assert config["context"]["agent_name"] == "my-agent"
assert config["context"]["is_bootstrap"] is True
# Non-whitelisted keys are not forwarded.
assert "thread_id" not in config["context"]
def test_merge_run_context_overrides_noop_for_empty_context():
from app.gateway.services import build_run_config, merge_run_context_overrides
config = build_run_config("thread-1", None, None)
before = {k: dict(v) if isinstance(v, dict) else v for k, v in config.items()}
merge_run_context_overrides(config, None)
merge_run_context_overrides(config, {})
assert config == before
def test_context_does_not_override_existing_configurable():
"""Values already in config.configurable must NOT be overridden by context."""
from app.gateway.services import build_run_config
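The regression test for issue #2677 requires whitelisted keys from `body.context` to land in both `config['configurable']` and `config['context']`, while `None` or `{}` leaves the config untouched. The sketch below shows one way to get that behavior. The whitelist contents, the function name, and the choice to let existing values win are assumptions, not the implementation in `app.gateway.services`.

```python
# Hypothetical whitelist; the real _CONTEXT_CONFIGURABLE_KEYS may differ.
CONTEXT_KEYS = frozenset({"agent_name", "is_bootstrap", "model_name"})


def merge_context_overrides(config, context):
    """Copy whitelisted context keys into both sections of a run config."""
    if not context:
        return  # None or {} leaves the config exactly as it was
    configurable = config.setdefault("configurable", {})
    runtime_context = config.setdefault("context", {})
    for key in CONTEXT_KEYS:
        if key in context:
            # setdefault keeps values that are already present (assumed policy).
            configurable.setdefault(key, context[key])
            runtime_context.setdefault(key, context[key])


config = {"configurable": {"thread_id": "thread-1"}}
merge_context_overrides(
    config,
    {"agent_name": "my-agent", "is_bootstrap": True, "thread_id": "ignored"},
)
```

Writing to both sections means a reader of `runtime.context` and a reader of `configurable` see the same `agent_name`.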
+17
@@ -22,6 +22,7 @@ _TEST_SECRET = "test-secret-key-initialize-admin-min-32"
def _setup_auth(tmp_path):
"""Fresh SQLite engine + auth config per test."""
from app.gateway import deps
from app.gateway.routers.auth import _SETUP_STATUS_COOLDOWN
from deerflow.persistence.engine import close_engine, init_engine
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
@@ -29,11 +30,13 @@ def _setup_auth(tmp_path):
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
deps._cached_local_provider = None
deps._cached_repo = None
_SETUP_STATUS_COOLDOWN.clear()
try:
yield
finally:
deps._cached_local_provider = None
deps._cached_repo = None
_SETUP_STATUS_COOLDOWN.clear()
asyncio.run(close_engine())
@@ -163,3 +166,17 @@ def test_setup_status_false_when_only_regular_user_exists(client):
resp = client.get("/api/v1/auth/setup-status")
assert resp.status_code == 200
assert resp.json()["needs_setup"] is True
def test_setup_status_rate_limited_on_second_call(client):
"""Second /setup-status call within the cooldown window returns 429 with Retry-After."""
# First call succeeds.
resp1 = client.get("/api/v1/auth/setup-status")
assert resp1.status_code == 200
# Immediate second call is rate-limited.
resp2 = client.get("/api/v1/auth/setup-status")
assert resp2.status_code == 429
assert "Retry-After" in resp2.headers
retry_after = int(resp2.headers["Retry-After"])
assert 1 <= retry_after <= 60
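The rate-limit test above expects the first `/setup-status` call to succeed and an immediate second call to return 429 with a `Retry-After` between 1 and 60. A minimal sketch of a per-client cooldown with that behavior follows. The 60-second window, the client key, and the helper name are assumptions; the router's actual `_SETUP_STATUS_COOLDOWN` logic is not shown in this diff.

```python
import math
import time

COOLDOWN_SECONDS = 60  # assumed window; the test only bounds Retry-After by 60
_cooldown: dict[str, float] = {}  # client key -> monotonic time of last allowed call


def check_setup_status_rate(client_key, now=None):
    """Return (status_code, headers) for one call made by client_key."""
    now = time.monotonic() if now is None else now
    last = _cooldown.get(client_key)
    if last is not None and now - last < COOLDOWN_SECONDS:
        remaining = COOLDOWN_SECONDS - (now - last)
        # Round up and clamp to at least 1 so clients never see "0".
        return 429, {"Retry-After": str(max(1, math.ceil(remaining)))}
    _cooldown[client_key] = now
    return 200, {}


first = check_setup_status_rate("203.0.113.7", now=100.0)
second = check_setup_status_rate("203.0.113.7", now=100.5)
```

The optional `now` argument exists only to make the sketch deterministic; clearing the dictionary between tests mirrors the `_SETUP_STATUS_COOLDOWN.clear()` calls in the fixture.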
@@ -697,3 +697,33 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
assert "invoke_acp_agent" in [tool.name for tool in tools]
load_acp_config_from_dict({})
def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch):
explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}
explicit_config = SimpleNamespace(
tools=[],
models=[],
tool_search=SimpleNamespace(enabled=False),
skill_evolution=SimpleNamespace(enabled=False),
get_model_config=lambda name: None,
acp_agents=explicit_agents,
)
sentinel_tool = SimpleNamespace(name="invoke_acp_agent")
captured: dict[str, object] = {}
def fail_get_acp_agents():
raise AssertionError("ambient get_acp_agents() must not be used when app_config is explicit")
def fake_build_invoke_acp_agent_tool(agents):
captured["agents"] = agents
return sentinel_tool
monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True)
monkeypatch.setattr("deerflow.config.acp_config.get_acp_agents", fail_get_acp_agents)
monkeypatch.setattr("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", fake_build_invoke_acp_agent_tool)
tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config)
assert captured["agents"] is explicit_agents
assert "invoke_acp_agent" in [tool.name for tool in tools]
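The test above asserts that an explicitly passed `app_config` is used for ACP agents and that the ambient `get_acp_agents()` lookup is never consulted. The sketch below illustrates that resolution order in isolation. `resolve_acp_agents` and the stand-in `get_acp_agents` are hypothetical and do not reproduce `deerflow.tools.tools`.

```python
from types import SimpleNamespace


def get_acp_agents():
    """Stand-in for the ambient (process-global) lookup."""
    return {"ambient": "agent"}


def resolve_acp_agents(app_config=None):
    """Prefer the explicit config; use the ambient lookup only as a fallback."""
    if app_config is not None:
        return app_config.acp_agents
    return get_acp_agents()


explicit_agents = {"codex": "codex-acp"}
explicit_config = SimpleNamespace(acp_agents=explicit_agents)

resolved = resolve_acp_agents(explicit_config)
fallback = resolve_acp_agents()
```

Returning the explicit object itself, rather than a copy, is what lets a test check identity with `is`, as the assertion on `captured["agents"]` does.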
+2 -2
@@ -63,7 +63,7 @@ def test_invalid_jwt_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": "garbage"})))
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
assert "Invalid token" in str(exc.value.detail)
def test_expired_jwt_raises_401():
@@ -295,7 +295,7 @@ def test_csrf_post_matching_token_proceeds_to_jwt():
)
# Past CSRF, rejected by JWT decode
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
assert "Invalid token" in str(exc.value.detail)
def test_csrf_put_requires_token():
+186 -13
@@ -2,6 +2,7 @@
from __future__ import annotations
import inspect
from unittest.mock import MagicMock
import pytest
@@ -33,6 +34,82 @@ def _make_model(name: str, *, supports_thinking: bool) -> ModelConfig:
)
def test_make_lead_agent_signature_matches_langgraph_server_factory_abi():
assert list(inspect.signature(lead_agent_module.make_lead_agent).parameters) == ["config"]
def test_internal_make_lead_agent_uses_explicit_app_config(monkeypatch):
app_config = _make_app_config([_make_model("explicit-model", supports_thinking=False)])
import deerflow.tools as tools_module
def _raise_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
result = lead_agent_module._make_lead_agent(
{"configurable": {"model_name": "explicit-model"}},
app_config=app_config,
)
assert captured == {
"name": "explicit-model",
"app_config": app_config,
}
assert result["model"] is not None
def test_make_lead_agent_uses_runtime_app_config_from_context_without_global_read(monkeypatch):
app_config = _make_app_config([_make_model("context-model", supports_thinking=False)])
import deerflow.tools as tools_module
def _raise_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when runtime context already carries app_config")
monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
result = lead_agent_module.make_lead_agent(
{
"context": {
"model_name": "context-model",
"app_config": app_config,
}
}
)
assert captured == {
"name": "context-model",
"app_config": app_config,
}
assert result["model"] is not None
def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
app_config = _make_app_config(
[
@@ -84,14 +161,15 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -110,6 +188,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
assert captured["name"] == "safe-model"
assert captured["thinking_enabled"] is False
assert captured["app_config"] is app_config
assert result["model"] is not None
@@ -126,14 +205,15 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
get_available_tools = MagicMock(return_value=[])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
@@ -156,8 +236,9 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
"name": "context-model",
"thinking_enabled": False,
"reasoning_effort": "high",
"app_config": app_config,
}
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True, app_config=app_config)
assert result["model"] is not None
@@ -198,23 +279,71 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
)
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda: None)
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
middlewares = lead_agent_module._build_middlewares({"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}}, model_name="vision-model", custom_middlewares=[MagicMock()])
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"model_name": "stale-model", "is_plan_mode": False, "subagent_enabled": False}},
model_name="vision-model",
custom_middlewares=[MagicMock()],
app_config=app_config,
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
captured: dict[str, object] = {}
def _raise_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
def _fake_build_lead_runtime_middlewares(*, app_config, lazy_init):
captured["app_config"] = app_config
captured["lazy_init"] = lazy_init
return ["base-middleware"]
monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
"build_lead_runtime_middlewares",
_fake_build_lead_runtime_middlewares,
)
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
monkeypatch.setattr(
lead_agent_module,
"TitleMiddleware",
lambda *, app_config: captured.setdefault("title_app_config", app_config) or "title-middleware",
)
monkeypatch.setattr(
lead_agent_module,
"MemoryMiddleware",
lambda agent_name=None, *, memory_config: captured.setdefault("memory_config", memory_config) or "memory-middleware",
)
middlewares = lead_agent_module._build_middlewares(
{"configurable": {"is_plan_mode": False, "subagent_enabled": False}},
model_name="safe-model",
app_config=app_config,
)
assert captured == {
"app_config": app_config,
"lazy_init": True,
"title_app_config": app_config,
"memory_config": app_config.memory,
}
assert middlewares[0] == "base-middleware"
def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
app_config.memory = MemoryConfig(enabled=False)
from unittest.mock import MagicMock
@@ -222,18 +351,62 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
captured["app_config"] = app_config
return fake_model
def _raise_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware()
middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert captured["app_config"] is app_config
assert middleware["model"] is fake_model
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
fallback_app_config.memory = MemoryConfig(enabled=False)
from unittest.mock import MagicMock
captured: dict[str, object] = {}
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
captured["app_config"] = app_config
return fake_model
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: fallback_app_config)
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
lead_agent_module._create_summarization_middleware()
assert captured["app_config"] is fallback_app_config
def test_memory_middleware_uses_explicit_memory_config_without_global_read(monkeypatch):
from deerflow.agents.middlewares import memory_middleware as memory_middleware_module
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
def _raise_get_memory_config():
raise AssertionError("ambient get_memory_config() must not be used when memory_config is explicit")
monkeypatch.setattr(memory_middleware_module, "get_memory_config", _raise_get_memory_config)
middleware = MemoryMiddleware(memory_config=MemoryConfig(enabled=False))
assert middleware.after_agent({"messages": []}, runtime=MagicMock(context={"thread_id": "thread-1"})) is None
+141 -8
@@ -4,6 +4,7 @@ from types import SimpleNamespace
import anyio
from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
from deerflow.skills.types import Skill
@@ -40,6 +41,21 @@ def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
assert "read-only" in section
def test_build_custom_mounts_section_uses_explicit_app_config_without_global_read(monkeypatch):
mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)]
config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))
def fail_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)
section = prompt_module._build_custom_mounts_section(app_config=config)
assert "`/home/user/shared`" in section
assert "read-write" in section
def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)]
config = SimpleNamespace(
@@ -48,9 +64,9 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None, **kwargs: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template()
@@ -66,9 +82,9 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
)
monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_build_acp_section", lambda **kwargs: "")
monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None, **kwargs: "")
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template()
@@ -77,6 +93,123 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt
def test_apply_prompt_template_threads_explicit_app_config_without_global_config(monkeypatch):
mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)]
explicit_config = SimpleNamespace(
sandbox=SimpleNamespace(mounts=mounts),
skills=SimpleNamespace(container_path="/mnt/explicit-skills"),
skill_evolution=SimpleNamespace(enabled=False),
tool_search=SimpleNamespace(enabled=False),
memory=SimpleNamespace(enabled=False, injection_enabled=True, max_injection_tokens=2000),
acp_agents={},
)
def fail_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
def fail_get_memory_config():
raise AssertionError("ambient get_memory_config() must not be used when app_config is explicit")
monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None: SimpleNamespace(load_skills=lambda enabled_only=True: []))
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template(app_config=explicit_config)
assert "`/home/user/shared`" in prompt
assert "Custom Mounted Directories" in prompt
def test_apply_prompt_template_threads_explicit_app_config_to_subagents_without_global_config(monkeypatch):
explicit_config = SimpleNamespace(
sandbox=SimpleNamespace(
use="deerflow.sandbox.local:LocalSandboxProvider",
allow_host_bash=False,
mounts=[],
),
subagents=SubagentsAppConfig(
custom_agents={
"researcher": CustomSubagentConfig(
description="Research agent\nwith details",
system_prompt="You research.",
)
}
),
skills=SimpleNamespace(container_path="/mnt/skills"),
skill_evolution=SimpleNamespace(enabled=False),
tool_search=SimpleNamespace(enabled=False),
memory=SimpleNamespace(enabled=False, injection_enabled=True, max_injection_tokens=2000),
acp_agents={},
)
def fail_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
def fail_get_subagents_app_config():
raise AssertionError("ambient get_subagents_app_config() must not be used when app_config is explicit")
monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)
monkeypatch.setattr("deerflow.config.subagents_config.get_subagents_app_config", fail_get_subagents_app_config)
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None: SimpleNamespace(load_skills=lambda enabled_only=True: []))
monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")
prompt = prompt_module.apply_prompt_template(subagent_enabled=True, app_config=explicit_config)
assert "**researcher**: Research agent" in prompt
assert "**bash**" not in prompt
def test_build_acp_section_uses_explicit_app_config_without_global_config(monkeypatch):
explicit_config = SimpleNamespace(acp_agents={"codex": object()})
def fail_get_acp_agents():
raise AssertionError("ambient get_acp_agents() must not be used when app_config is explicit")
monkeypatch.setattr("deerflow.config.acp_config.get_acp_agents", fail_get_acp_agents)
section = prompt_module._build_acp_section(app_config=explicit_config)
assert "ACP Agent Tasks" in section
assert "/mnt/acp-workspace/" in section
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
explicit_config = SimpleNamespace(
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234),
)
captured: dict[str, object] = {}
def fail_get_memory_config():
raise AssertionError("ambient get_memory_config() must not be used when app_config is explicit")
def fake_get_memory_data(agent_name=None, *, user_id=None):
captured["agent_name"] = agent_name
captured["user_id"] = user_id
return {"facts": []}
def fake_format_memory_for_injection(memory_data, *, max_tokens):
captured["memory_data"] = memory_data
captured["max_tokens"] = max_tokens
return "remember this"
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
monkeypatch.setattr("deerflow.runtime.user_context.get_effective_user_id", lambda: "user-1")
monkeypatch.setattr("deerflow.agents.memory.get_memory_data", fake_get_memory_data)
monkeypatch.setattr("deerflow.agents.memory.format_memory_for_injection", fake_format_memory_for_injection)
context = prompt_module._get_memory_context("agent-a", app_config=explicit_config)
assert "<memory>" in context
assert "remember this" in context
assert captured == {
"agent_name": "agent-a",
"user_id": "user-1",
"memory_data": {"facts": []},
"max_tokens": 1234,
}
def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatch, tmp_path):
def make_skill(name: str) -> Skill:
skill_dir = tmp_path / name
@@ -92,7 +225,7 @@ def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatc
)
state = {"skills": [make_skill("first-skill")]}
monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: list(state["skills"]))
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda **kwargs: __import__("types").SimpleNamespace(load_skills=lambda *, enabled_only: list(state["skills"])))
_set_skills_cache_state()
try:
@@ -145,7 +278,7 @@ def test_clear_cache_does_not_spawn_parallel_refresh_workers(monkeypatch, tmp_pa
return [make_skill(f"skill-{current_call}")]
monkeypatch.setattr(prompt_module, "load_skills", fake_load_skills)
monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda **kwargs: __import__("types").SimpleNamespace(load_skills=lambda *, enabled_only: fake_load_skills(enabled_only=enabled_only)))
_set_skills_cache_state()
try:
+23 -1
@@ -100,6 +100,28 @@ def test_get_skills_prompt_section_cache_respects_skill_evolution_toggle(monkeyp
assert "Skill Self-Evolution" not in disabled_result
def test_get_skills_prompt_section_uses_explicit_config_for_enabled_skills(monkeypatch):
explicit_config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/alt-skills"),
skill_evolution=SimpleNamespace(enabled=False),
)
def fail_get_app_config():
raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")
monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [_make_skill("global-skill")])
monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)
monkeypatch.setattr(
"deerflow.agents.lead_agent.prompt.get_or_new_skill_storage",
lambda app_config=None, **kwargs: __import__("types").SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("explicit-skill")] if app_config is explicit_config else []),
)
result = get_skills_prompt_section(app_config=explicit_config)
assert "explicit-skill" in result
assert "global-skill" not in result
def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
from unittest.mock import MagicMock
@@ -107,7 +129,7 @@ def test_make_lead_agent_empty_skills_passed_correctly(monkeypatch):
# Mock dependencies
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: MagicMock())
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None: "default-model")
monkeypatch.setattr(lead_agent_module, "_resolve_model_name", lambda x=None, **kwargs: "default-model")
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: "model")
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda *args, **kwargs: [])
@@ -11,6 +11,13 @@ from langgraph.errors import GraphBubbleUp
from deerflow.agents.middlewares.llm_error_handling_middleware import (
LLMErrorHandlingMiddleware,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_app_config() -> AppConfig:
"""Minimal AppConfig for middleware tests; circuit_breaker uses defaults."""
return AppConfig(sandbox=SandboxConfig(use="test"))
class FakeError(Exception):
@@ -31,7 +38,7 @@ class FakeError(Exception):
def _build_middleware(**attrs: int) -> LLMErrorHandlingMiddleware:
middleware = LLMErrorHandlingMiddleware()
middleware = LLMErrorHandlingMiddleware(app_config=_make_app_config())
for key, value in attrs.items():
setattr(middleware, key, value)
return middleware
@@ -226,9 +233,7 @@ def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) ->
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
request: Any = {"messages": []}
@@ -284,8 +289,7 @@ def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pyte
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware = _build_middleware(circuit_failure_threshold=3)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
request: Any = {"messages": []}
@@ -386,9 +390,7 @@ async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.Monk
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
middleware = _build_middleware(circuit_failure_threshold=3, circuit_recovery_timeout_sec=10)
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
async def async_failing_handler(request: Any) -> Any:

Some files were not shown because too many files have changed in this diff.