Compare commits

..

40 Commits

Author SHA1 Message Date
Hinotobi 80e210f5bb [security] fix(uploads): require explicit opt-in for host-side document conversion (#2332)
* fix: disable host-side upload conversion by default

* fix: address PR review comments on upload conversion gate
2026-04-18 22:47:42 +08:00
dependabot[bot] 5656f90792 chore(deps-dev): bump pytest from 9.0.2 to 9.0.3 in /backend (#2349)
Bumps [pytest](https://github.com/pytest-dev/pytest) from 9.0.2 to 9.0.3.
- [Release notes](https://github.com/pytest-dev/pytest/releases)
- [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst)
- [Commits](https://github.com/pytest-dev/pytest/compare/9.0.2...9.0.3)

---
updated-dependencies:
- dependency-name: pytest
  dependency-version: 9.0.3
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-18 22:22:40 +08:00
Shawn Jasper 55474011c9 fix(subagent): inherit parent agent's tool_groups in task_tool (#2305)
* fix(subagent): inherit parent agent's tool_groups in task_tool

When a custom agent defines tool_groups (e.g. [file:read, file:write, bash]),
the restriction is correctly applied to the lead agent. However, when the lead
agent delegates work to a subagent via the task tool, get_available_tools() is
called without the groups parameter, causing the subagent to receive ALL tools
(including web_search, web_fetch, image_search, etc.) regardless of the parent
agent's configuration.

This fix propagates tool_groups through run metadata so that task_tool passes
the same group filter when building the subagent's tool set.

Changes:
- agent.py: include tool_groups in run metadata
- task_tool.py: read tool_groups from metadata and pass to get_available_tools()

* fix: initialize metadata before conditional block and update tests for tool_groups propagation

- Initialize metadata = {} before the 'if runtime is not None' block to
  avoid Ruff F821 (possibly-undefined variable) and simplify the
  parent_tool_groups expression.
- Update existing test assertion to expect groups=None in
  get_available_tools call signature.
- Add 3 new test cases:
  - test_task_tool_propagates_tool_groups_to_subagent
  - test_task_tool_no_tool_groups_passes_none
  - test_task_tool_runtime_none_passes_groups_none
2026-04-18 22:17:37 +08:00
imhaoran 24fe5fbd8c fix(mcp): prevent RuntimeError from escaping except block in get_cach… (#2252)
* fix(mcp): prevent RuntimeError from escaping except block in get_cached_mcp_tools

When `asyncio.get_event_loop()` raises RuntimeError and the fallback
`asyncio.run()` also fails, the exception escapes unhandled because
Python does not route exceptions raised inside an `except` block to
sibling `except` clauses. Wrap the fallback call in its own try/except
so failures are logged and the function returns [] as intended.

* fix: use logger.exception to preserve stack traces on MCP init failure
2026-04-18 21:07:30 +08:00
Willem Jiang be4663505a chroe(script): disable the color log of langgraph 2026-04-18 20:03:05 +08:00
dependabot[bot] aa6098e6a4 chore(deps): bump langsmith from 0.6.4 to 0.7.31 in /backend (#2291)
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.6.4 to 0.7.31.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.6.4...v0.7.31)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.7.31
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-18 19:54:21 +08:00
Airene Fang 1221448029 fix(scripts): Cloud Provider Reports Security Issue(aliyun could) (#2323)
ATT&CK矩阵ID:T1059.004
数据来源:进程启动触发检测
告警原因:该进程的命令行显示出反弹shelI的特征
命令行:timeout 1 bash -c exec 3<>/dev/tcp/127.0.0.1/2024
进程路径:/usr/bin/timeout
进程链:-[337650] /usr/sbin/sshd -D
-[397971] /usr/sbin/sshd -D -R
-[397977]-bash
-[398903] make dev
-[398920] bash ./scripts/serve.sh --dev
-[399037]bash ./scripts/wait-for-port.sh 2024 60 LangGraph
2026-04-18 19:33:32 +08:00
Jason 3b91df2b18 fix(frontend): add catch-all API rewrite for gateway routes (#2335)
When NEXT_PUBLIC_BACKEND_BASE_URL is unset, the frontend proxies API
requests to the gateway. Only /api/agents and /api/skills had rewrite
rules, causing 404s for /api/models, /api/threads, /api/memory,
/api/mcp, /api/suggestions, /api/runs, etc.

Add a catch-all /api/:path* rewrite that proxies all remaining gateway
API routes. The existing /api/langgraph rewrite takes priority because
it is pushed to the array first (Next.js checks rewrites in order).

Fixes #2327

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-18 11:35:19 +08:00
Shawn Jasper ca1b7d5f48 fix(sandbox): add missing path masking in ls_tool output (#2317)
ls_tool was the only file-system tool that did not call
mask_local_paths_in_output() before returning its result, causing host
absolute paths (e.g. /Users/.../backend/.deer-flow/knowledge-base/...)
to leak to the LLM instead of the expected virtual paths
(/mnt/knowledge-base/...).

This patch:
- Adds the mask_local_paths_in_output() call to ls_tool, consistent
  with bash_tool, glob_tool and grep_tool.
- Initialises thread_data = None before the is_local_sandbox branch
  (same pattern as glob_tool) so the variable is always in scope.
- Adds three new tests covering user-data path masking, skills path
  masking and the empty-directory edge case.
2026-04-18 08:46:59 +08:00
yangzheli c6b0423558 feat(frontend): add Playwright E2E tests with CI workflow (#2279)
* feat(frontend): add Playwright E2E tests with CI workflow

Add end-to-end testing infrastructure using Playwright (Chromium only).
14 tests across 5 spec files cover landing page, chat workspace,
thread history, sidebar navigation, and agent chat — all with mocked
LangGraph/Backend APIs via network interception (zero backend dependency).

New files:
- playwright.config.ts — Chromium, 30s timeout, auto-start Next.js
- tests/e2e/utils/mock-api.ts — shared API mocks & SSE stream helpers
- tests/e2e/{landing,chat,thread-history,sidebar,agent-chat}.spec.ts
- .github/workflows/e2e-tests.yml — push main + PR trigger, paths filter

Updated: package.json, Makefile, .gitignore, CONTRIBUTING.md,
frontend/CLAUDE.md, frontend/AGENTS.md, frontend/README.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: apply Copilot suggestions

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-18 08:21:08 +08:00
DanielWalnut 898f4e8ac2 fix: Memory update system has cache corruption, data loss, and thread-safety bugs (#2251)
* fix(memory): cache corruption, thread-safety, and caller mutation bugs

Bug 1 (updater.py): deep-copy current_memory before passing to
_apply_updates() so a subsequent save() failure cannot leave a
partially-mutated object in the storage cache.

Bug 3 (storage.py): add _cache_lock (threading.Lock) to
FileMemoryStorage and acquire it around every read/write of
_memory_cache, fixing concurrent-access races between the background
timer thread and HTTP reload calls.

Bug 4 (storage.py): replace in-place mutation
  memory_data["lastUpdated"] = ...
with a shallow copy
  memory_data = {**memory_data, "lastUpdated": ...}
so save() no longer silently modifies the caller's dict.

Regression tests added for all three bugs in test_memory_storage.py
and test_memory_updater.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: format test_memory_updater.py with ruff

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: remove stale bug-number labels from code comments and docstrings

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-17 12:00:31 +08:00
dependabot[bot] 259a6844bf chore(deps): bump python-multipart from 0.0.22 to 0.0.26 in /backend (#2282)
Bumps [python-multipart](https://github.com/Kludex/python-multipart) from 0.0.22 to 0.0.26.
- [Release notes](https://github.com/Kludex/python-multipart/releases)
- [Changelog](https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Kludex/python-multipart/compare/0.0.22...0.0.26)

---
updated-dependencies:
- dependency-name: python-multipart
  dependency-version: 0.0.26
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-16 09:07:28 +08:00
d 🔹 a664d2f5c4 fix(checkpointer): create parent directory before opening SQLite in sync provider (#2272)
* fix(checkpointer): create parent directory before opening SQLite in sync provider

The sync checkpointer factory (_sync_checkpointer_cm) opens a SQLite
connection without first ensuring the parent directory exists.  The async
provider and both store providers already call ensure_sqlite_parent_dir(),
but this call was missing from the sync path.

When the deer-flow harness package is used from an external virtualenv
(where the .deer-flow directory is not pre-created), the missing parent
directory causes:

    sqlite3.OperationalError: unable to open database file

Add the missing ensure_sqlite_parent_dir() call in the sync SQLite
branch, consistent with the async provider, and add a regression test.

Closes #2259

* style: fix ruff format + add call-order assertion for ensure_parent_dir

- Fix formatting in test_checkpointer.py (ruff format)
- Add test_sqlite_ensure_parent_dir_before_connect to verify
  ensure_sqlite_parent_dir is called before from_conn_string
  (addresses Copilot review suggestion)

---------

Co-authored-by: voidborne-d <voidborne-d@users.noreply.github.com>
2026-04-16 09:06:38 +08:00
YuJitang 105db00987 feat: show token usage per assistant response (#2270)
* feat: show token usage per assistant response

* fix: align client models response with token usage

* fix: address token usage review feedback

* docs: clarify token usage config example

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-16 08:56:49 +08:00
Nan Gao 0e16a7fe55 fix(frontend): make Suggestion button opaque in dark mode (#2276)
* fix(frontend): make Suggestion button opaque in dark mode

The outline Button variant applies dark:bg-input/30, leaving Suggestion
pills ~70% transparent in dark mode. Scrolled chat content bled through
the buttons, making suggestion text unreadable. Override with
dark:bg-background so it matches the opaque light-mode appearance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix the lint error of commit

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-16 08:55:16 +08:00
Nan Gao 4d3038a7b6 fix(frontend): stop artifact panel from auto-opening on rehydrated write_file (#2278)
After a page refresh, the artifact panel's autoOpen/autoSelect state is
reset to true. Submitting a new question flips thread.isLoading to true,
which message-list passes to every MessageGroup — including historical
ones. The previous response's last write_file step then satisfies the
auto-open condition and re-pops the stale artifact.

Gate the auto-open on the tool call having no result yet, so only a
write_file that is still streaming in the current response can trigger
it; rehydrated tool calls always carry a result and are now skipped.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-16 08:46:47 +08:00
Hinotobi 2176b2bbfc fix: validate bootstrap agent names before filesystem writes (#2274)
* fix: validate bootstrap agent names before filesystem writes

* fix: tighten bootstrap agent-name validation
2026-04-16 08:36:42 +08:00
Wen 8e3591312a test: add unit tests for ViewImageMiddleware (#2256)
* test: add unit tests for ViewImageMiddleware

- Add 33 test cases covering all 7 internal methods plus sync/async
  before_model hooks
- Cover normal path, edge cases (missing keys, empty base64, stale
  ToolMessages before assistant turn), and deduplication logic
- Related to Q2 Roadmap #1669

* test: add unit tests for ViewImageMiddleware

Add 35 test cases covering all internal methods, before_model hooks,
and edge cases (missing attrs, list-content dedup, stale ToolMessages).

Related to #1669
2026-04-15 23:54:30 +08:00
Willem Jiang 242c654075 fix(frontend):lint error of message-list-item.tsx 2026-04-15 23:35:50 +08:00
Willem Jiang 0c21cbf01f fix(frontend): lint error of frontend 2026-04-15 23:27:46 +08:00
Jason 772538ddba fix(frontend): add skills API rewrite rule to prevent HTML fallback (#2241)
Fixes #2203

When NEXT_PUBLIC_BACKEND_BASE_URL is not set, the frontend uses Next.js
rewrites to proxy API calls to the gateway. Skills API routes were missing
from the rewrite config, causing /api/skills to return the SPA HTML instead
of JSON, which produced 'Unexpected token <' errors in the skill settings page.

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-15 23:21:40 +08:00
Jason 35fb3dd65a fix(frontend): resolve /mnt/ links in markdown to artifact API URLs (#2243)
* fix(gateway): forward agent_name and is_bootstrap from context to configurable

The frontend sends agent_name and is_bootstrap via the context field
in run requests, but services.py only forwards a hardcoded whitelist
of keys (_CONTEXT_CONFIGURABLE_KEYS) into the agent's configurable
dict.  Since agent_name was missing, custom agents never received
their name — make_lead_agent always fell back to the default lead
agent, skipping SOUL.md, per-agent config and skill filtering.

Similarly, is_bootstrap was dropped, so the bootstrap creation flow
could never activate the setup_agent tool path.

Add both keys to the whitelist so they reach make_lead_agent.

Fixes #2222

* fix(frontend): resolve /mnt/ links in markdown to artifact API URLs

AI agent messages contain links like /mnt/user-data/outputs/file.pdf
which were rendered as-is in the browser, resulting in 404 errors.
Images already got the correct treatment via MessageImage and
resolveArtifactURL, but anchor tags (<a>) were passed through
unchanged.

Add an 'a' component override in MessageContent_ that rewrites
/mnt/-prefixed hrefs to the artifact API endpoint, matching the
existing image handling pattern.

Fixes #2232

---------

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-15 23:12:21 +08:00
Jason 692f79452d fix(gateway): forward agent_name and is_bootstrap from context to configurable (#2242)
The frontend sends agent_name and is_bootstrap via the context field
in run requests, but services.py only forwards a hardcoded whitelist
of keys (_CONTEXT_CONFIGURABLE_KEYS) into the agent's configurable
dict.  Since agent_name was missing, custom agents never received
their name — make_lead_agent always fell back to the default lead
agent, skipping SOUL.md, per-agent config and skill filtering.

Similarly, is_bootstrap was dropped, so the bootstrap creation flow
could never activate the setup_agent tool path.

Add both keys to the whitelist so they reach make_lead_agent.

Fixes #2222

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-15 23:11:10 +08:00
DanielWalnut 8760937439 fix(memory): use asyncio.to_thread for blocking file I/O in aupdate_memory (#2220)
* fix(memory): use asyncio.to_thread for blocking file I/O in aupdate_memory

`_finalize_update` performs synchronous blocking operations (os.mkdir,
file open/write/rename/stat) that were called directly from the async
`aupdate_memory` method, causing `BlockingError` from blockbuster when
running under an ASGI server. Wrap the call with `asyncio.to_thread` to
offload all blocking I/O to a thread pool.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): use unique temp filename to prevent concurrent write collision

`file_path.with_suffix(".tmp")` produces a fixed path — concurrent saves
for the same agent (now possible after wrapping _finalize_update in
asyncio.to_thread) would clobber the same temp file. Use a UUID-suffixed
temp file so each write is isolated.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): also offload _prepare_update_prompt to thread pool

FileMemoryStorage.load() inside _prepare_update_prompt performs
synchronous stat() and file read, blocking the event loop just like
_finalize_update did. Wrap _prepare_update_prompt in asyncio.to_thread
for the same reason.

The async path now has no blocking file I/O on the event loop:
  to_thread(_prepare_update_prompt) → await model.ainvoke() → to_thread(_finalize_update)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-14 16:41:54 +08:00
DanielWalnut 4ba3167f48 feat: flush memory before summarization (#2176)
* feat: flush memory before summarization

* fix: keep agent-scoped memory on summarization flush

* fix: harden summarization hook plumbing

* fix: address summarization review feedback

* style: format memory middleware
2026-04-14 15:01:06 +08:00
Octopus e4f896e90d fix(todo-middleware): prevent premature agent exit with incomplete todos (#2135)
* fix(todo-middleware): prevent premature agent exit with incomplete todos

When plan mode is active (is_plan_mode=True), the agent occasionally
exits the loop and outputs a final response while todo items are still
incomplete. This happens because the routing edge only checks for
tool_calls, not todo completion state.

Fixes #2112

Add an after_model override to TodoMiddleware with
@hook_config(can_jump_to=["model"]). When the model produces a
response with no tool calls but there are still incomplete todos, the
middleware injects a todo_completion_reminder HumanMessage and returns
jump_to=model to force another model turn. A cap of 2 reminders
prevents infinite loops when the agent cannot make further progress.

Also adds _completion_reminder_count() helper and 14 new unit tests
covering all edge cases of the new after_model / aafter_model logic.

* Remove unnecessary blank line in test file

* Fix runtime argument annotation in before_model

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: octo-patch <octo-patch@github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-14 11:11:26 +08:00
luo jiyin 07fc25d285 feat: switch memory updater to async LLM calls (#2138)
* docs: mark memory updater async migration as completed

- Update TODO.md to mark the replacement of sync model.invoke()
  with async model.ainvoke() in title_middleware and memory updater
  as completed using [x] format

Addresses #2131

* feat: switch memory updater to async LLM calls

- Add async aupdate_memory() method using await model.ainvoke()
- Convert sync update_memory() to use async wrapper
- Add _run_async_update_sync() for nested loop context handling
- Maintain backward compatibility with existing sync API
- Add ThreadPoolExecutor for async execution from sync contexts

Addresses #2131

* test: add tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns

Addresses #2131

* fix: apply ruff formatting to memory updater

- Format multi-line expressions to single line
- Ensure code style consistency with project standards
- Fix lint issues caught by GitHub Actions

* test: add comprehensive tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns
- Ensure backward compatibility with sync API

* fix: satisfy ruff formatting in memory updater test

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-14 11:10:42 +08:00
Nan Gao 55bc09ac33 fix(backend): fix uploads for mounted sandbox providers (#2199)
* fix uploads for mounted sandbox providers

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-14 10:44:31 +08:00
dependabot[bot] c43a45ea40 chore(deps): bump pillow from 12.1.1 to 12.2.0 in /backend (#2206)
Bumps [pillow](https://github.com/python-pillow/Pillow) from 12.1.1 to 12.2.0.
- [Release notes](https://github.com/python-pillow/Pillow/releases)
- [Changelog](https://github.com/python-pillow/Pillow/blob/main/CHANGES.rst)
- [Commits](https://github.com/python-pillow/Pillow/compare/12.1.1...12.2.0)

---
updated-dependencies:
- dependency-name: pillow
  dependency-version: 12.2.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-14 10:35:59 +08:00
Admire 9cf7153b1d fix(check): windows pnpm version detection in check script (#2189)
* fix: resolve Windows pnpm detection in check script

* style: format check script regression test

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: resolve corepack fallback on windows

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-14 10:29:44 +08:00
Octopus c91785dd68 fix(title): strip <think> tags from title model responses and assistant context (#1927)
* fix(title): strip <think> tags from title model responses and assistant context

Reasoning models (e.g. minimax M2.7, DeepSeek-R1) emit <think>...</think>
blocks before their actual output. When such a model is used as the title
model (or as the main agent), the raw thinking content leaked into the thread
title stored in state, so the chat list showed the internal monologue instead
of a meaningful title.

Fixes #1884

- Add `_strip_think_tags()` helper using a regex to remove all <think>...</think> blocks
- Apply it in `_parse_title()` so the title model response is always clean
- Apply it to the assistant message in `_build_title_prompt()` so thinking
  content from the first AI turn is not fed back to the title model
- Add four new unit tests covering: stripping in parse, think-only response,
  assistant prompt stripping, and end-to-end async flow with think tags

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-14 09:51:39 +08:00
sqsge 053e18e1a6 fix(skills): avoid blocking custom skill deletion on readonly history writes (#2197) 2026-04-14 09:00:29 +08:00
Hinotobi a7e7c6d667 fix: disable custom-agent management API by default (#2161)
* fix: disable custom-agent management API by default

* style: format agents API hardening files

* fix: address review feedback for agents API hardening

* fix: add missing disabled API coverage
2026-04-14 00:03:38 +08:00
Nan Gao f4c17c66ce fix(middleware): fix present_files thread id fallback (#2181)
* fix present files thread id fallback

* fix: resolve present_files thread id from runtime config
2026-04-13 22:59:13 +08:00
lesliewangwyc-dev 1df389b9d0 fix: wrap blocking readability call with asyncio.to_thread in web_fetch (#2157)
* fix: wrap blocking readability call with asyncio.to_thread in web_fetch

The readability extractor internally spawns a Node.js subprocess via
readabilipy, which blocks the async event loop and causes a
BlockingError when web_fetch is invoked inside LangGraph's async
runtime.

Wrap the synchronous extract_article call with asyncio.to_thread to
offload it to a thread pool, unblocking the event loop.

Note: community/infoquest/tools.py has the same latent issue and
should be addressed in a follow-up PR.

Closes #2152

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: verify web_fetch offloads extraction via asyncio.to_thread

Add a regression test that monkeypatches asyncio.to_thread to confirm
readability extraction is offloaded to a worker thread, preventing
future refactors from reintroducing the blocking call.

Addresses Copilot review feedback on #2157.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-13 21:15:24 +08:00
5db71cb68c fix(middleware): repair dangling tool-call history after loop interru… (#2035)
* fix(middleware): repair dangling tool-call history after loop interruption (#2029)

* docs(backend): fix middleware chain ordering

---------

Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
2026-04-12 19:11:22 +08:00
yangzheli 4efc8d404f feat(frontend): set up Vitest frontend testing infrastructure with CI workflow (#2147)
* feat: set up Vitest frontend testing infrastructure with CI workflow

Migrate existing 4 frontend test files from Node.js native test runner
(node:test + node:assert/strict) to Vitest, reorganize test directory
structure under tests/unit/ mirroring src/ layout, and add a dedicated
CI workflow for frontend unit tests.

- Add vitest as devDependency, remove tsx
- Create vitest.config.ts with @/ path alias
- Migrate tests to Vitest API (test/expect/vi)
- Rename .mjs test files to .ts
- Move tests from src/ to tests/unit/ (mirrors src/ layout)
- Add frontend/Makefile `test` target
- Add .github/workflows/frontend-unit-tests.yml (parallel to backend)
- Update CONTRIBUTING.md, README.md, AGENTS.md, CLAUDE.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: fix the lint error

* style: fix the lint error

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-12 18:00:43 +08:00
Jin 4d4ddb3d3f feat(llm): introduce lightweight circuit breaker to prevent rate-limit bans and resource exhaustion (#2095) 2026-04-12 17:48:40 +08:00
luo jiyin 979a461af5 docs: move completed async migration to Completed Features (#2146)
- Move time.sleep() -> asyncio.sleep() from Planned to Completed Features
- Clean up duplicate entries in TODO.md

Ensures completed async optimizations are properly tracked.
2026-04-12 16:48:48 +08:00
Javen Fang ac04f2704f feat(subagents): allow model override per subagent in config.yaml (#2064)
* feat(subagents): allow model override per subagent in config.yaml

Wire the existing SubagentConfig.model field to config.yaml so users
can assign different models to different subagent types.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* test(subagents): cover model override in SubagentsAppConfig + registry

Addresses review feedback on #2064:

- registry.py: update stale inline comment — the block now applies
  timeout, max_turns AND model overrides, not just timeout.
- test_subagent_timeout_config.py: add coverage for model override
  resolution across SubagentOverrideConfig, SubagentsAppConfig
  (get_model_for + load), and registry.get_subagent_config:
  - per-agent model override is applied to registry-returned config
  - omitted `model` keeps the builtin value
  - explicit `model: null` in config.yaml is equivalent to omission
  - model override on one agent does not affect other agents
  - model override preserves all other fields (name, description,
    timeout_seconds, max_turns)
  - model override does not mutate BUILTIN_SUBAGENTS

Copilot's suggestion (3) "setting model to 'inherit' forces inheritance"
is skipped intentionally: there is no 'inherit' sentinel in the current
implementation — model is `str | None`, and None already means
"inherit from parent". Adding a sentinel would be a new feature, not
test coverage for this PR.

Tests run locally: 51 passed (37 existing + 14 new / expanded).

* test(subagents): reject empty-string model at config load time

Addresses WillemJiang's review comment on #2064 (empty-string edge case):

- subagents_config.py: add `min_length=1` to the `model` field on
  SubagentOverrideConfig. `model: ""` in config.yaml would otherwise
  bypass the `is not None` check and reach create_chat_model(name="")
  as a confusing runtime error. This is symmetric with the existing
  `ge=1` guards on timeout_seconds / max_turns, so the validation style
  stays consistent across all three override fields.
- test_subagent_timeout_config.py: add test_rejects_empty_model
  mirroring the existing test_rejects_zero / test_rejects_negative
  cases; update the docstring on test_model_accepts_any_string (now
  test_model_accepts_any_non_empty_string) to reflect the new guard.

Not addressing the first comment (validating `model` against the
`models:` section at load time) in this PR. `SubagentsAppConfig` is
scoped to the `subagents:` block and cannot see the sibling `models:`
section, so proper cross-section validation needs a second pass or a
structural change that is out of scope here — and the current behavior
is consistent with how timeout_seconds / max_turns work today. Happy to
track this as a follow-up issue covering cross-section validation
uniformly for all three fields.

Tests run locally: 52 passed in this file; 1847 passed, 18 skipped
across the full backend suite. Ruff check + format clean.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-12 16:40:21 +08:00
181 changed files with 9116 additions and 5567 deletions
+63
View File
@@ -0,0 +1,63 @@
name: E2E Tests
on:
push:
branches: [ 'main' ]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
concurrency:
group: e2e-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
e2e-tests:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }}
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Install Playwright Chromium
working-directory: frontend
run: npx playwright install chromium --with-deps
- name: Run E2E tests
working-directory: frontend
run: pnpm exec playwright test
env:
SKIP_ENV_VALIDATION: '1'
- name: Upload Playwright report
uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report
path: frontend/playwright-report/
retention-days: 7
+43
View File
@@ -0,0 +1,43 @@
name: Frontend Unit Tests
on:
push:
branches: [ 'main' ]
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
concurrency:
group: frontend-unit-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
frontend-unit-tests:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Run unit tests of frontend
working-directory: frontend
run: make test
+2
View File
@@ -55,5 +55,7 @@ web/
backend/Dockerfile.langgraph backend/Dockerfile.langgraph
config.yaml.bak config.yaml.bak
.playwright-mcp .playwright-mcp
/frontend/test-results/
/frontend/playwright-report/
.gstack/ .gstack/
.worktrees .worktrees
+11 -6
View File
@@ -298,19 +298,24 @@ Nginx (port 2026) ← Unified entry point
```bash ```bash
# Backend tests # Backend tests
cd backend cd backend
uv run pytest make test
# Frontend checks # Frontend unit tests
cd frontend cd frontend
pnpm check make test
# Frontend E2E tests (requires Chromium; builds and auto-starts the Next.js production server)
cd frontend
make test-e2e
``` ```
### PR Regression Checks ### PR Regression Checks
Every pull request runs the backend regression workflow at [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml), including: Every pull request triggers the following CI workflows:
- `tests/test_provisioner_kubeconfig.py` - **Backend unit tests** — [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml)
- `tests/test_docker_sandbox_mode_detection.py` - **Frontend unit tests** — [.github/workflows/frontend-unit-tests.yml](.github/workflows/frontend-unit-tests.yml)
- **Frontend E2E tests** — [.github/workflows/e2e-tests.yml](.github/workflows/e2e-tests.yml) (triggered only when `frontend/` files change)
## Code Style ## Code Style
+2
View File
@@ -658,6 +658,8 @@ This is the difference between a chatbot with tool access and an agent with an a
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window. **Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
### Long-Term Memory ### Long-Term Memory
Most agents forget everything the moment a conversation ends. DeerFlow remembers. Most agents forget everything the moment a conversation ends. DeerFlow remembers.
+17 -13
View File
@@ -156,20 +156,26 @@ from deerflow.config import get_app_config
### Middleware Chain ### Middleware Chain
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`: Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory 1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation 2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state 3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption) 4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider. 5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) 6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) 7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled) 11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last) 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
### Configuration System ### Configuration System
@@ -179,9 +185,7 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`. **Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`.
**Config Lifecycle**: All config models are `frozen=True` (immutable after construction). `AppConfig.from_file()` is a pure function — no side effects on sub-module globals. `get_app_config()` is backed by a single `ContextVar`, set once via `init_app_config()` at process startup. To update config at runtime (e.g., Gateway API updates MCP/Skills), construct a new `AppConfig.from_file()` and call `init_app_config()` again. No mtime detection, no auto-reload. **Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
**DeerFlowContext**: Per-invocation typed context for the agent execution path, injected via LangGraph `Runtime[DeerFlowContext]`. Holds `app_config: AppConfig`, `thread_id: str`, `agent_name: str | None`. Gateway runtime and `DeerFlowClient` construct full `DeerFlowContext` at invoke time; LangGraph Server path uses a fallback via `resolve_context()`. Middleware and tools access context through `resolve_context(runtime)` which returns a typed `DeerFlowContext` regardless of entry point. Mutable runtime state (`sandbox_id`) flows through `ThreadState.sandbox`, not context.
Configuration priority: Configuration priority:
1. Explicit `config_path` argument 1. Explicit `config_path` argument
+2 -2
View File
@@ -67,9 +67,9 @@ class ChannelService:
@classmethod @classmethod
def from_app_config(cls) -> ChannelService: def from_app_config(cls) -> ChannelService:
"""Create a ChannelService from the application config.""" """Create a ChannelService from the application config."""
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config
config = AppConfig.current() config = get_app_config()
channels_config = {} channels_config = {}
# extra fields are allowed by AppConfig (extra="allow") # extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {} extra = config.model_extra or {}
+2 -2
View File
@@ -21,7 +21,7 @@ from app.gateway.routers import (
threads, threads,
uploads, uploads,
) )
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
@@ -39,7 +39,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Load config and check necessary environment variables at startup # Load config and check necessary environment variables at startup
try: try:
AppConfig.current() get_app_config()
logger.info("Configuration loaded successfully") logger.info("Configuration loaded successfully")
except Exception as e: except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}" error_msg = f"Failed to load configuration during gateway startup: {e}"
+21
View File
@@ -8,6 +8,7 @@ import yaml
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
@@ -73,6 +74,15 @@ def _normalize_agent_name(name: str) -> str:
return name.lower() return name.lower()
def _require_agents_api_enabled() -> None:
"""Reject access unless the custom-agent management API is explicitly enabled."""
if not get_agents_api_config().enabled:
raise HTTPException(
status_code=403,
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse: def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse.""" """Convert AgentConfig to AgentResponse."""
soul: str | None = None soul: str | None = None
@@ -100,6 +110,8 @@ async def list_agents() -> AgentsListResponse:
Returns: Returns:
List of all custom agents with their metadata and soul content. List of all custom agents with their metadata and soul content.
""" """
_require_agents_api_enabled()
try: try:
agents = list_custom_agents() agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents]) return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
@@ -125,6 +137,7 @@ async def check_agent_name(name: str) -> dict:
Raises: Raises:
HTTPException: 422 if the name is invalid. HTTPException: 422 if the name is invalid.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
normalized = _normalize_agent_name(name) normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists() available = not get_paths().agent_dir(normalized).exists()
@@ -149,6 +162,7 @@ async def get_agent(name: str) -> AgentResponse:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
@@ -181,6 +195,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
Raises: Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid. HTTPException: 409 if agent already exists, 422 if name is invalid.
""" """
_require_agents_api_enabled()
_validate_agent_name(request.name) _validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name) normalized_name = _normalize_agent_name(request.name)
@@ -243,6 +258,7 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
@@ -315,6 +331,8 @@ async def get_user_profile() -> UserProfileResponse:
Returns: Returns:
UserProfileResponse with content=None if USER.md does not exist yet. UserProfileResponse with content=None if USER.md does not exist yet.
""" """
_require_agents_api_enabled()
try: try:
user_md_path = get_paths().user_md_file user_md_path = get_paths().user_md_file
if not user_md_path.exists(): if not user_md_path.exists():
@@ -341,6 +359,8 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
Returns: Returns:
UserProfileResponse with the saved content. UserProfileResponse with the saved content.
""" """
_require_agents_api_enabled()
try: try:
paths = get_paths() paths = get_paths()
paths.base_dir.mkdir(parents=True, exist_ok=True) paths.base_dir.mkdir(parents=True, exist_ok=True)
@@ -367,6 +387,7 @@ async def delete_agent(name: str) -> None:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
+7 -9
View File
@@ -6,8 +6,7 @@ from typing import Literal
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config.app_config import AppConfig from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
from deerflow.config.extensions_config import ExtensionsConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"]) router = APIRouter(prefix="/api", tags=["mcp"])
@@ -91,9 +90,9 @@ async def get_mcp_configuration() -> McpConfigResponse:
} }
``` ```
""" """
ext = AppConfig.current().extensions config = get_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()}) return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
@router.put( @router.put(
@@ -144,12 +143,12 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
logger.info(f"No existing extensions config found. Creating new config at: {config_path}") logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration # Load current config to preserve skills configuration
current_ext = AppConfig.current().extensions current_config = get_extensions_config()
# Convert request to dict format for JSON serialization # Convert request to dict format for JSON serialization
config_data = { config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()}, "mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()}, "skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
} }
# Write the configuration to file # Write the configuration to file
@@ -162,9 +161,8 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
# will detect config file changes via mtime and reinitialize MCP tools automatically # will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache # Reload the configuration and update the global cache
AppConfig.init(AppConfig.from_file()) reloaded_config = reload_extensions_config()
reloaded_ext = AppConfig.current().extensions return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_ext.mcp_servers.items()})
except Exception as e: except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True) logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
+3 -3
View File
@@ -12,7 +12,7 @@ from deerflow.agents.memory.updater import (
reload_memory_data, reload_memory_data,
update_memory_fact, update_memory_fact,
) )
from deerflow.config.app_config import AppConfig from deerflow.config.memory_config import get_memory_config
router = APIRouter(prefix="/api", tags=["memory"]) router = APIRouter(prefix="/api", tags=["memory"])
@@ -311,7 +311,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
} }
``` ```
""" """
config = AppConfig.current().memory config = get_memory_config()
return MemoryConfigResponse( return MemoryConfigResponse(
enabled=config.enabled, enabled=config.enabled,
storage_path=config.storage_path, storage_path=config.storage_path,
@@ -336,7 +336,7 @@ async def get_memory_status() -> MemoryStatusResponse:
Returns: Returns:
Combined memory configuration and current data. Combined memory configuration and current data.
""" """
config = AppConfig.current().memory config = get_memory_config()
memory_data = get_memory_data() memory_data = get_memory_data()
return MemoryStatusResponse( return MemoryStatusResponse(
+25 -8
View File
@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
router = APIRouter(prefix="/api", tags=["models"]) router = APIRouter(prefix="/api", tags=["models"])
@@ -17,10 +17,17 @@ class ModelResponse(BaseModel):
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort") supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
class TokenUsageResponse(BaseModel):
"""Token usage display configuration."""
enabled: bool = Field(default=False, description="Whether token usage display is enabled")
class ModelsListResponse(BaseModel): class ModelsListResponse(BaseModel):
"""Response model for listing all models.""" """Response model for listing all models."""
models: list[ModelResponse] models: list[ModelResponse]
token_usage: TokenUsageResponse
@router.get( @router.get(
@@ -36,7 +43,7 @@ async def list_models() -> ModelsListResponse:
excluding sensitive fields like API keys and internal configuration. excluding sensitive fields like API keys and internal configuration.
Returns: Returns:
A list of all configured models with their metadata. A list of all configured models with their metadata and token usage display settings.
Example Response: Example Response:
```json ```json
@@ -44,21 +51,28 @@ async def list_models() -> ModelsListResponse:
"models": [ "models": [
{ {
"name": "gpt-4", "name": "gpt-4",
"model": "gpt-4",
"display_name": "GPT-4", "display_name": "GPT-4",
"description": "OpenAI GPT-4 model", "description": "OpenAI GPT-4 model",
"supports_thinking": false "supports_thinking": false,
"supports_reasoning_effort": false
}, },
{ {
"name": "claude-3-opus", "name": "claude-3-opus",
"model": "claude-3-opus",
"display_name": "Claude 3 Opus", "display_name": "Claude 3 Opus",
"description": "Anthropic Claude 3 Opus model", "description": "Anthropic Claude 3 Opus model",
"supports_thinking": true "supports_thinking": true,
"supports_reasoning_effort": false
} }
] ],
"token_usage": {
"enabled": true
}
} }
``` ```
""" """
config = AppConfig.current() config = get_app_config()
models = [ models = [
ModelResponse( ModelResponse(
name=model.name, name=model.name,
@@ -70,7 +84,10 @@ async def list_models() -> ModelsListResponse:
) )
for model in config.models for model in config.models
] ]
return ModelsListResponse(models=models) return ModelsListResponse(
models=models,
token_usage=TokenUsageResponse(enabled=config.token_usage.enabled),
)
@router.get( @router.get(
@@ -101,7 +118,7 @@ async def get_model(model_name: str) -> ModelResponse:
} }
``` ```
""" """
config = AppConfig.current() config = get_app_config()
model = config.get_model_config(model_name) model = config.get_model_config(model_name)
if model is None: if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+24 -19
View File
@@ -1,3 +1,4 @@
import errno
import json import json
import logging import logging
import shutil import shutil
@@ -8,8 +9,7 @@ from pydantic import BaseModel, Field
from app.gateway.path_utils import resolve_thread_virtual_path from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.app_config import AppConfig from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.skills import Skill, load_skills from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import ( from deerflow.skills.manager import (
@@ -202,18 +202,23 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
ensure_custom_skill_is_editable(skill_name) ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name) skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name) prev_content = read_custom_skill_content(skill_name)
append_history( try:
skill_name, append_history(
{ skill_name,
"action": "human_delete", {
"author": "human", "action": "human_delete",
"thread_id": None, "author": "human",
"file_path": "SKILL.md", "thread_id": None,
"prev_content": prev_content, "file_path": "SKILL.md",
"new_content": None, "prev_content": prev_content,
"scanner": {"decision": "allow", "reason": "Deletion requested."}, "new_content": None,
}, "scanner": {"decision": "allow", "reason": "Deletion requested."},
) },
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
shutil.rmtree(skill_dir) shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async()
return {"success": True} return {"success": True}
@@ -326,19 +331,19 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
config_path = Path.cwd().parent / "extensions_config.json" config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}") logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
ext = AppConfig.current().extensions extensions_config = get_extensions_config()
ext.skills[skill_name] = SkillStateConfig(enabled=request.enabled) extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
config_data = { config_data = {
"mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()}, "mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()}, "skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
} }
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2) json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}") logger.info(f"Skills configuration updated and saved to: {config_path}")
AppConfig.init(AppConfig.from_file()) reload_extensions_config()
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False) skills = load_skills(enabled_only=False)
+39 -6
View File
@@ -7,8 +7,9 @@ import stat
from fastapi import APIRouter, File, HTTPException, UploadFile from fastapi import APIRouter, File, HTTPException, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from deerflow.config.app_config import get_app_config
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.uploads.manager import ( from deerflow.uploads.manager import (
PathTraversalError, PathTraversalError,
delete_file_safe, delete_file_safe,
@@ -53,6 +54,34 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
os.chmod(file_path, writable_mode, **chmod_kwargs) os.chmod(file_path, writable_mode, **chmod_kwargs)
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
except Exception:
return False
@router.post("", response_model=UploadResponse) @router.post("", response_model=UploadResponse)
async def upload_files( async def upload_files(
thread_id: str, thread_id: str,
@@ -70,8 +99,12 @@ async def upload_files(
uploaded_files = [] uploaded_files = []
sandbox_provider = get_sandbox_provider() sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id) sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = sandbox_provider.get(sandbox_id) sandbox = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
for file in files: for file in files:
if not file.filename: if not file.filename:
@@ -90,7 +123,7 @@ async def upload_files(
virtual_path = upload_virtual_path(safe_filename) virtual_path = upload_virtual_path(safe_filename)
if sandbox_id != "local": if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(file_path) _make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content) sandbox.update_file(virtual_path, content)
@@ -105,12 +138,12 @@ async def upload_files(
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}") logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower() file_ext = file_path.suffix.lower()
if file_ext in CONVERTIBLE_EXTENSIONS: if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path) md_path = await convert_file_to_markdown(file_path)
if md_path: if md_path:
md_virtual_path = upload_virtual_path(md_path.name) md_virtual_path = upload_virtual_path(md_path.name)
if sandbox_id != "local": if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(md_path) _make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes()) sandbox.update_file(md_virtual_path, md_path.read_bytes())
+2
View File
@@ -298,6 +298,8 @@ async def start_run(
"is_plan_mode", "is_plan_mode",
"subagent_enabled", "subagent_enabled",
"max_concurrent_subagents", "max_concurrent_subagents",
"agent_name",
"is_bootstrap",
} }
configurable = config.setdefault("configurable", {}) configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS: for key in _CONTEXT_CONFIGURABLE_KEYS:
+6 -3
View File
@@ -2,12 +2,12 @@
## 概述 ## 概述
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并自动将 Office 文档和 PDF 转换为 Markdown 格式。 DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并可选地将 Office 文档和 PDF 转换为 Markdown 格式。
## 功能特性 ## 功能特性
- ✅ 支持多文件同时上传 - ✅ 支持多文件同时上传
-自动转换文档为 MarkdownPDF、PPT、Excel、Word -可选地转换文档为 MarkdownPDF、PPT、Excel、Word
- ✅ 文件存储在线程隔离的目录中 - ✅ 文件存储在线程隔离的目录中
- ✅ Agent 自动感知已上传的文件 - ✅ Agent 自动感知已上传的文件
- ✅ 支持文件列表查询和删除 - ✅ 支持文件列表查询和删除
@@ -86,7 +86,7 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
## 支持的文档格式 ## 支持的文档格式
以下格式会自动转换为 Markdown: 以下格式在显式启用 `uploads.auto_convert_documents: true`会自动转换为 Markdown
- PDF (`.pdf`) - PDF (`.pdf`)
- PowerPoint (`.ppt`, `.pptx`) - PowerPoint (`.ppt`, `.pptx`)
- Excel (`.xls`, `.xlsx`) - Excel (`.xls`, `.xlsx`)
@@ -94,6 +94,8 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。 转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。
默认情况下,自动转换是关闭的,以避免在网关主机上对不受信任的 Office/PDF 上传执行解析。只有在受信任部署中明确接受此风险时,才应将 `uploads.auto_convert_documents` 设置为 `true`
## Agent 集成 ## Agent 集成
### 自动文件列举 ### 自动文件列举
@@ -207,6 +209,7 @@ backend/.deer-flow/threads/
- 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size` - 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size`
- 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击 - 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击
- 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问 - 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问
- 自动文档转换默认关闭;如需启用,需在 `config.yaml` 中显式设置 `uploads.auto_convert_documents: true`
## 技术实现 ## 技术实现
+3 -3
View File
@@ -11,6 +11,7 @@
- [x] Add Plan Mode with TodoList middleware - [x] Add Plan Mode with TodoList middleware
- [x] Add vision model support with ViewImageMiddleware - [x] Add vision model support with ViewImageMiddleware
- [x] Skills system with SKILL.md format - [x] Skills system with SKILL.md format
- [x] Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling)
## Planned Features ## Planned Features
@@ -21,10 +22,9 @@
- [ ] Support for more document formats in upload - [ ] Support for more document formats in upload
- [ ] Skill marketplace / remote skill installation - [ ] Skill marketplace / remote skill installation
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario) - [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
- Replace `time.sleep(5)` with `asyncio.sleep()` in `packages/harness/deerflow/tools/builtins/task_tool.py` (subagent polling) - [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search) - Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater - [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O - Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker) - For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
@@ -29,7 +29,7 @@ from deerflow.agents.checkpointer.provider import (
POSTGRES_INSTALL, POSTGRES_INSTALL,
SQLITE_INSTALL, SQLITE_INSTALL,
) )
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -94,7 +94,7 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
""" """
config = AppConfig.current() config = get_app_config()
if config.checkpointer is None: if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
@@ -25,9 +25,9 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -67,6 +67,7 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
raise ImportError(SQLITE_INSTALL) from exc raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver: with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup() saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str) logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
@@ -113,10 +114,25 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None: if _checkpointer is not None:
return _checkpointer return _checkpointer
try: # Ensure app config is loaded before checking checkpointer config
config = AppConfig.current().checkpointer # This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
except (LookupError, FileNotFoundError): # but hasn't been loaded yet
config = None from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None: if config is None:
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
@@ -165,7 +181,7 @@ def checkpointer_context() -> Iterator[Checkpointer]:
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
""" """
config = AppConfig.current() config = get_app_config()
if config.checkpointer is None: if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
@@ -1,24 +1,26 @@
import logging import logging
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config
from deerflow.config.deer_flow_context import DeerFlowContext from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,7 +28,7 @@ logger = logging.getLogger(__name__)
def _resolve_model_name(requested_model_name: str | None = None) -> str: def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured.""" """Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = AppConfig.current() app_config = get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None: if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.") raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -39,9 +41,9 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name return default_model_name
def _create_summarization_middleware() -> SummarizationMiddleware | None: def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config.""" """Create and configure the summarization middleware from config."""
config = AppConfig.current().summarization config = get_summarization_config()
if not config.enabled: if not config.enabled:
return None return None
@@ -78,7 +80,11 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
if config.summary_prompt is not None: if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt kwargs["summary_prompt"] = config.summary_prompt
return SummarizationMiddleware(**kwargs) hooks: list[BeforeSummarizationHook] = []
if get_memory_config().enabled:
hooks.append(memory_flush_hook)
return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None: def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
@@ -231,7 +237,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware) middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled # Add TokenUsageMiddleware when token_usage tracking is enabled
if AppConfig.current().token_usage.enabled: if get_app_config().token_usage.enabled:
middlewares.append(TokenUsageMiddleware()) middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware # Add TitleMiddleware
@@ -242,7 +248,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision. # Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values. # Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = AppConfig.current() app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision: if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware()) middlewares.append(ViewImageMiddleware())
@@ -271,7 +277,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares return middlewares
def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph: def make_lead_agent(config: RunnableConfig):
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent from deerflow.tools.builtins import setup_agent
@@ -285,7 +291,7 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
subagent_enabled = cfg.get("subagent_enabled", False) subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
is_bootstrap = cfg.get("is_bootstrap", False) is_bootstrap = cfg.get("is_bootstrap", False)
agent_name = cfg.get("agent_name") agent_name = validate_agent_name(cfg.get("agent_name"))
agent_config = load_agent_config(agent_name) if not is_bootstrap else None agent_config = load_agent_config(agent_name) if not is_bootstrap else None
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default # Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
@@ -294,7 +300,7 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
# Final model name resolution: request → agent config → global default, with fallback for unknown names # Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name) model_name = _resolve_model_name(requested_model_name or agent_model_name)
app_config = AppConfig.current() app_config = get_app_config()
model_config = app_config.get_model_config(model_name) model_config = app_config.get_model_config(model_name)
if model_config is None: if model_config is None:
@@ -326,6 +332,7 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
"reasoning_effort": reasoning_effort, "reasoning_effort": reasoning_effort,
"is_plan_mode": is_plan_mode, "is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled, "subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
} }
) )
@@ -337,7 +344,6 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
middleware=_build_middlewares(config, model_name=model_name), middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])), system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState, state_schema=ThreadState,
context_schema=DeerFlowContext,
) )
# Default lead agent (unchanged behavior) # Default lead agent (unchanged behavior)
@@ -349,5 +355,4 @@ def make_lead_agent(config: RunnableConfig) -> CompiledStateGraph:
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
), ),
state_schema=ThreadState, state_schema=ThreadState,
context_schema=DeerFlowContext,
) )
@@ -5,7 +5,6 @@ from datetime import datetime
from functools import lru_cache from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul from deerflow.config.agents_config import load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.skills import load_skills from deerflow.skills import load_skills
from deerflow.skills.types import Skill from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names from deerflow.subagents import get_available_subagent_names
@@ -519,8 +518,9 @@ def _get_memory_context(agent_name: str | None = None) -> str:
""" """
try: try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
config = AppConfig.current().memory config = get_memory_config()
if not config.enabled or not config.injection_enabled: if not config.enabled or not config.injection_enabled:
return "" return ""
@@ -576,7 +576,9 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
skills = _get_enabled_skills() skills = _get_enabled_skills()
try: try:
config = AppConfig.current() from deerflow.config import get_app_config
config = get_app_config()
container_base_path = config.skills.container_path container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled skill_evolution_enabled = config.skill_evolution.enabled
except Exception: except Exception:
@@ -615,7 +617,9 @@ def get_deferred_tools_prompt_section() -> str:
from deerflow.tools.builtins.tool_search import get_deferred_registry from deerflow.tools.builtins.tool_search import get_deferred_registry
try: try:
if not AppConfig.current().tool_search.enabled: from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
return "" return ""
except Exception: except Exception:
return "" return ""
@@ -631,7 +635,9 @@ def get_deferred_tools_prompt_section() -> str:
def _build_acp_section() -> str: def _build_acp_section() -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured.""" """Build the ACP agent prompt section, only if ACP agents are configured."""
try: try:
agents = AppConfig.current().acp_agents from deerflow.config.acp_config import get_acp_agents
agents = get_acp_agents()
if not agents: if not agents:
return "" return ""
except Exception: except Exception:
@@ -649,7 +655,9 @@ def _build_acp_section() -> str:
def _build_custom_mounts_section() -> str: def _build_custom_mounts_section() -> str:
"""Build a prompt section for explicitly configured sandbox mounts.""" """Build a prompt section for explicitly configured sandbox mounts."""
try: try:
mounts = AppConfig.current().sandbox.mounts or [] from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
except Exception: except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt") logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return "" return ""
@@ -0,0 +1,109 @@
"""Shared helpers for turning conversations into memory update inputs."""
from __future__ import annotations
import re
from copy import copy
from typing import Any
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
def extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Keep only user inputs and final assistant responses for memory updates."""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = extract_message_text(msg)
if "<uploaded_files>" in content_str:
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
skip_next_ai = True
continue
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from deerflow.config.app_config import AppConfig from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -56,53 +56,93 @@ class MemoryUpdateQueue:
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal.
""" """
config = AppConfig.current().memory config = get_memory_config()
if not config.enabled: if not config.enabled:
return return
with self._lock: with self._lock:
existing_context = next( self._enqueue_locked(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id, thread_id=thread_id,
messages=messages, messages=messages,
agent_name=agent_name, agent_name=agent_name,
correction_detected=merged_correction_detected, correction_detected=correction_detected,
reinforcement_detected=merged_reinforcement_detected, reinforcement_detected=reinforcement_detected,
) )
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
# Reset or start the debounce timer
self._reset_timer() self._reset_timer()
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue)) logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
def add_nowait(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation and start processing immediately in the background."""
config = get_memory_config()
if not config.enabled:
return
with self._lock:
self._enqueue_locked(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
self._schedule_timer(0)
logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
def _enqueue_locked(
self,
*,
thread_id: str,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
def _reset_timer(self) -> None: def _reset_timer(self) -> None:
"""Reset the debounce timer.""" """Reset the debounce timer."""
config = AppConfig.current().memory config = get_memory_config()
self._schedule_timer(config.debounce_seconds)
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _schedule_timer(self, delay_seconds: float) -> None:
"""Schedule queue processing after the provided delay."""
# Cancel existing timer if any # Cancel existing timer if any
if self._timer is not None: if self._timer is not None:
self._timer.cancel() self._timer.cancel()
# Start new timer
self._timer = threading.Timer( self._timer = threading.Timer(
config.debounce_seconds, delay_seconds,
self._process_queue, self._process_queue,
) )
self._timer.daemon = True self._timer.daemon = True
self._timer.start() self._timer.start()
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _process_queue(self) -> None: def _process_queue(self) -> None:
"""Process all queued conversation contexts.""" """Process all queued conversation contexts."""
# Import here to avoid circular dependency # Import here to avoid circular dependency
@@ -110,8 +150,8 @@ class MemoryUpdateQueue:
with self._lock: with self._lock:
if self._processing: if self._processing:
# Already processing, reschedule # Preserve immediate flush semantics even if another worker is active.
self._reset_timer() self._schedule_timer(0)
return return
if not self._queue: if not self._queue:
@@ -164,6 +204,13 @@ class MemoryUpdateQueue:
self._process_queue() self._process_queue()
def flush_nowait(self) -> None:
"""Start queue processing immediately in a background thread."""
with self._lock:
# Daemon thread: queued messages may be lost if the process exits
# before _process_queue completes. Acceptable for best-effort memory updates.
self._schedule_timer(0)
def clear(self) -> None: def clear(self) -> None:
"""Clear the queue without processing. """Clear the queue without processing.
@@ -4,12 +4,13 @@ import abc
import json import json
import logging import logging
import threading import threading
import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import AppConfig from deerflow.config.memory_config import get_memory_config
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -66,6 +67,8 @@ class FileMemoryStorage(MemoryStorage):
# Per-agent memory cache: keyed by agent_name (None = global) # Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime) # Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {} self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
def _validate_agent_name(self, agent_name: str) -> None: def _validate_agent_name(self, agent_name: str) -> None:
"""Validate that the agent name is safe to use in filesystem paths. """Validate that the agent name is safe to use in filesystem paths.
@@ -84,7 +87,7 @@ class FileMemoryStorage(MemoryStorage):
self._validate_agent_name(agent_name) self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name) return get_paths().agent_memory_file(agent_name)
config = AppConfig.current().memory config = get_memory_config()
if config.storage_path: if config.storage_path:
p = Path(config.storage_path) p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p return p if p.is_absolute() else get_paths().base_dir / p
@@ -114,14 +117,17 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
current_mtime = None current_mtime = None
cached = self._memory_cache.get(agent_name) with self._cache_lock:
cached = self._memory_cache.get(agent_name)
if cached is not None and cached[1] == current_mtime:
return cached[0]
if cached is None or cached[1] != current_mtime: memory_data = self._load_memory_from_file(agent_name)
memory_data = self._load_memory_from_file(agent_name)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, current_mtime) self._memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data
return cached[0] return memory_data
def reload(self, agent_name: str | None = None) -> dict[str, Any]: def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation.""" """Reload memory data from file, forcing cache invalidation."""
@@ -133,7 +139,8 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
self._memory_cache[agent_name] = (memory_data, mtime) with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
return memory_data return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool: def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
@@ -142,9 +149,12 @@ class FileMemoryStorage(MemoryStorage):
try: try:
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
memory_data["lastUpdated"] = utc_now_iso_z() # Shallow-copy before adding lastUpdated so the caller's dict is not
# mutated as a side-effect, and the cache reference is not silently
# updated before the file write succeeds.
memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
temp_path = file_path.with_suffix(".tmp") temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
with open(temp_path, "w", encoding="utf-8") as f: with open(temp_path, "w", encoding="utf-8") as f:
json.dump(memory_data, f, indent=2, ensure_ascii=False) json.dump(memory_data, f, indent=2, ensure_ascii=False)
@@ -155,7 +165,8 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
self._memory_cache[agent_name] = (memory_data, mtime) with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path) logger.info("Memory saved to %s", file_path)
return True return True
except OSError as e: except OSError as e:
@@ -177,7 +188,7 @@ def get_memory_storage() -> MemoryStorage:
if _storage_instance is not None: if _storage_instance is not None:
return _storage_instance return _storage_instance
config = AppConfig.current().memory config = get_memory_config()
storage_class_path = config.storage_class storage_class_path = config.storage_class
try: try:
@@ -0,0 +1,31 @@
"""Hooks fired before summarization removes messages from state."""
from __future__ import annotations
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
def memory_flush_hook(event: SummarizationEvent) -> None:
"""Flush messages about to be summarized into the memory queue."""
if not get_memory_config().enabled or not event.thread_id:
return
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
if not user_messages or not assistant_messages:
return
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
agent_name=event.agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -1,10 +1,15 @@
"""Memory updater for reading, writing, and updating memory data.""" """Memory updater for reading, writing, and updating memory data."""
import asyncio
import atexit
import concurrent.futures
import copy
import json import json
import logging import logging
import math import math
import re import re
import uuid import uuid
from collections.abc import Awaitable
from typing import Any from typing import Any
from deerflow.agents.memory.prompt import ( from deerflow.agents.memory.prompt import (
@@ -16,11 +21,17 @@ from deerflow.agents.memory.storage import (
get_memory_storage, get_memory_storage,
utc_now_iso_z, utc_now_iso_z,
) )
from deerflow.config.app_config import AppConfig from deerflow.config.memory_config import get_memory_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="memory-updater-sync",
)
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
def _create_empty_memory() -> dict[str, Any]: def _create_empty_memory() -> dict[str, Any]:
"""Backward-compatible wrapper around the storage-layer empty-memory factory.""" """Backward-compatible wrapper around the storage-layer empty-memory factory."""
@@ -206,6 +217,39 @@ def _extract_text(content: Any) -> str:
return str(content) return str(content)
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
"""Run an async memory update from sync code, including nested-loop contexts."""
handed_off = False
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
handed_off = True
return future.result()
handed_off = True
return asyncio.run(coro)
except Exception:
if not handed_off:
close = getattr(coro, "close", None)
if callable(close):
try:
close()
except Exception:
logger.debug(
"Failed to close un-awaited memory update coroutine",
exc_info=True,
)
logger.exception("Failed to run async memory update from sync context")
return False
# Matches sentences that describe a file-upload *event* rather than general # Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts # file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export". # such as "User works with CSV files" or "prefers PDF export".
@@ -265,10 +309,121 @@ class MemoryUpdater:
def _get_model(self): def _get_model(self):
"""Get the model for memory updates.""" """Get the model for memory updates."""
config = AppConfig.current().memory config = get_memory_config()
model_name = self._model_name or config.model_name model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False) return create_chat_model(name=model_name, thinking_enabled=False)
def _build_correction_hint(
self,
correction_detected: bool,
reinforcement_detected: bool,
) -> str:
"""Build optional prompt hints for correction and reinforcement signals."""
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
return correction_hint
def _prepare_update_prompt(
self,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
correction_hint = self._build_correction_hint(
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
return current_memory, prompt
def _finalize_update(
self,
current_memory: dict[str, Any],
response_content: Any,
thread_id: str | None,
agent_name: str | None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Deep-copy before in-place mutation so a subsequent save() failure
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
async def aupdate_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
try:
prepared = await asyncio.to_thread(
self._prepare_update_prompt,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
if prepared is None:
return False
current_memory, prompt = prepared
model = self._get_model()
response = await model.ainvoke(prompt)
return await asyncio.to_thread(
self._finalize_update,
current_memory=current_memory,
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def update_memory( def update_memory(
self, self,
messages: list[Any], messages: list[Any],
@@ -277,7 +432,7 @@ class MemoryUpdater:
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
) -> bool: ) -> bool:
"""Update memory based on conversation messages. """Synchronously update memory via the async updater path.
Args: Args:
messages: List of conversation messages. messages: List of conversation messages.
@@ -289,78 +444,15 @@ class MemoryUpdater:
Returns: Returns:
True if update was successful, False otherwise. True if update was successful, False otherwise.
""" """
config = AppConfig.current().memory return _run_async_update_sync(
if not config.enabled: self.aupdate_memory(
return False messages=messages,
thread_id=thread_id,
if not messages: agent_name=agent_name,
return False correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
try:
# Get current memory
current_memory = get_memory_data(agent_name)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
) )
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def _apply_updates( def _apply_updates(
self, self,
@@ -378,7 +470,7 @@ class MemoryUpdater:
Returns: Returns:
Updated memory data. Updated memory data.
""" """
config = AppConfig.current().memory config = get_memory_config()
now = utc_now_iso_z() now = utc_now_iso_z()
# Update user sections # Update user sections
@@ -13,6 +13,7 @@ at the correct positions (immediately after each dangling AIMessage), not append
to the end of the message list as before_model + add_messages reducer would do. to the end of the message list as before_model + add_messages reducer would do.
""" """
import json
import logging import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import override from typing import override
@@ -33,6 +34,44 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
offending AIMessage so the LLM receives a well-formed conversation. offending AIMessage so the LLM receives a well-formed conversation.
""" """
@staticmethod
def _message_tool_calls(msg) -> list[dict]:
"""Return normalized tool calls from structured fields or raw provider payloads."""
tool_calls = getattr(msg, "tool_calls", None) or []
if tool_calls:
return list(tool_calls)
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
normalized: list[dict] = []
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
}
)
return normalized
def _build_patched_messages(self, messages: list) -> list | None: def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions. """Return a new message list with patches inserted at the correct positions.
@@ -51,7 +90,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for msg in messages: for msg in messages:
if getattr(msg, "type", None) != "ai": if getattr(msg, "type", None) != "ai":
continue continue
for tc in getattr(msg, "tool_calls", None) or []: for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids: if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True needs_patch = True
@@ -70,7 +109,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
patched.append(msg) patched.append(msg)
if getattr(msg, "type", None) != "ai": if getattr(msg, "type", None) != "ai":
continue continue
for tc in getattr(msg, "tool_calls", None) or []: for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append( patched.append(
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import threading
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
@@ -19,6 +20,8 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504} _RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
@@ -67,6 +70,80 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000 retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000 retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Load Circuit Breaker configs from app config if available, fall back to defaults
try:
app_config = get_app_config()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
# Gracefully fall back to class defaults in test environments
pass
# Circuit Breaker state
self._circuit_lock = threading.Lock()
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _check_circuit(self) -> bool:
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
with self._circuit_lock:
now = time.time()
if self._circuit_state == "open":
if now < self._circuit_open_until:
return True
self._circuit_state = "half_open"
self._circuit_probe_in_flight = False
if self._circuit_state == "half_open":
if self._circuit_probe_in_flight:
return True
self._circuit_probe_in_flight = True
return False
return False
def _record_success(self) -> None:
with self._circuit_lock:
if self._circuit_state != "closed" or self._circuit_failure_count > 0:
logger.info("Circuit breaker reset (Closed). LLM service recovered.")
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _record_failure(self) -> None:
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker probe failed (Open). Will probe again after %ds.",
self.circuit_recovery_timeout_sec,
)
return
self._circuit_failure_count += 1
if self._circuit_failure_count >= self.circuit_failure_threshold:
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
if self._circuit_state != "open":
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
self.circuit_failure_threshold,
self.circuit_recovery_timeout_sec,
)
def _classify_error(self, exc: BaseException) -> tuple[bool, str]: def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc) detail = _extract_error_detail(exc)
lowered = detail.lower() lowered = detail.lower()
@@ -104,6 +181,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily" reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s." return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_circuit_breaker_message(self) -> str:
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
def _build_user_message(self, exc: BaseException, reason: str) -> str: def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc) detail = _extract_error_detail(exc)
if reason == "quota": if reason == "quota":
@@ -138,12 +218,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse], handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult: ) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1 attempt = 1
while True: while True:
try: try:
return handler(request) response = handler(request)
self._record_success()
return response
except GraphBubbleUp: except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume). # Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise raise
except Exception as exc: except Exception as exc:
retriable, reason = self._classify_error(exc) retriable, reason = self._classify_error(exc)
@@ -166,6 +254,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc), _extract_error_detail(exc),
exc_info=exc, exc_info=exc,
) )
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason)) return AIMessage(content=self._build_user_message(exc, reason))
@override @override
@@ -174,12 +264,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult: ) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1 attempt = 1
while True: while True:
try: try:
return await handler(request) response = await handler(request)
self._record_success()
return response
except GraphBubbleUp: except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume). # Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise raise
except Exception as exc: except Exception as exc:
retriable, reason = self._classify_error(exc) retriable, reason = self._classify_error(exc)
@@ -202,6 +300,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc), _extract_error_detail(exc),
exc_info=exc, exc_info=exc,
) )
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason)) return AIMessage(content=self._build_user_message(exc, reason))
@@ -17,6 +17,7 @@ import json
import logging import logging
import threading import threading
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import override from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
@@ -24,8 +25,6 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor # Defaults — can be overridden via constructor
@@ -182,9 +181,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime[DeerFlowContext]) -> str: def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking.""" """Extract thread_id from runtime context for per-thread tracking."""
return runtime.context.thread_id or "default" thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None: def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit. """Evict least recently used threads if over the limit.
@@ -322,6 +324,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Fallback: coerce unexpected types to str to avoid TypeError # Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}" return str(content) + f"\n\n{text}"
@staticmethod
def _build_hard_stop_update(last_msg, content: str | list) -> dict:
"""Clear tool-call metadata so forced-stop messages serialize as plain assistant text."""
update = {
"tool_calls": [],
"content": content,
}
additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
for key in ("tool_calls", "function_call"):
additional_kwargs.pop(key, None)
update["additional_kwargs"] = additional_kwargs
response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
if response_metadata.get("finish_reason") == "tool_calls":
response_metadata["finish_reason"] = "stop"
update["response_metadata"] = response_metadata
return update
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime) warning, hard_stop = self._track_and_check(state, runtime)
@@ -329,12 +351,8 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Strip tool_calls from the last AIMessage to force text output # Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", []) messages = state.get("messages", [])
last_msg = messages[-1] last_msg = messages[-1]
stripped_msg = last_msg.model_copy( content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
update={ stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
"tool_calls": [],
"content": self._append_text(last_msg.content, warning),
}
)
return {"messages": [stripped_msg]} return {"messages": [stripped_msg]}
if warning: if warning:
@@ -349,11 +367,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None return None
@override @override
def after_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None: def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime) return self._apply(state, runtime)
@override @override
async def aafter_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None: async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime) return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None: def reset(self, thread_id: str | None = None) -> None:
@@ -1,49 +1,19 @@
"""Middleware for memory mechanism.""" """Middleware for memory mechanism."""
import logging import logging
import re from typing import override
from typing import Any, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.deer_flow_context import DeerFlowContext from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState): class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema.""" """Compatible with the `ThreadState` schema."""
@@ -51,125 +21,6 @@ class MemoryMiddlewareState(AgentState):
pass pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
This filters out:
- Tool messages (intermediate tool call results)
- AI messages with tool_calls (intermediate steps, not final responses)
- The <uploaded_files> block injected by UploadsMiddleware into human messages
(file paths are session-scoped and must not persist in long-term memory).
The user's actual question is preserved; only turns whose content is entirely
the upload block (nothing remains after stripping) are dropped along with
their paired assistant response.
Only keeps:
- Human messages (with the ephemeral upload block removed)
- AI messages without tool_calls (final assistant responses), unless the
paired human turn was upload-only and had no real user text.
Args:
messages: List of all conversation messages.
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
# Nothing left — the entire turn was upload bookkeeping;
# skip it and the paired assistant response.
skip_next_ai = True
continue
# Rebuild the message with cleaned content so the user's question
# is still available for memory summarisation.
from copy import copy
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
# Skip tool messages and AI messages with tool_calls
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution. """Middleware that queues conversation for memory update after agent execution.
@@ -192,7 +43,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
self._agent_name = agent_name self._agent_name = agent_name
@override @override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None: def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
"""Queue conversation for memory update after agent completes. """Queue conversation for memory update after agent completes.
Args: Args:
@@ -202,11 +53,15 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns: Returns:
None (no state changes needed from this middleware). None (no state changes needed from this middleware).
""" """
memory_config = runtime.context.app_config.memory config = get_memory_config()
if not memory_config.enabled: if not config.enabled:
return None return None
thread_id = runtime.context.thread_id # Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if not thread_id: if not thread_id:
logger.debug("No thread_id in context, skipping memory update") logger.debug("No thread_id in context, skipping memory update")
return None return None
@@ -218,7 +73,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None return None
# Filter to only keep user inputs and final assistant responses # Filter to only keep user inputs and final assistant responses
filtered_messages = _filter_messages_for_memory(messages) filtered_messages = filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation # Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response # At minimum need one user message and one assistant response
@@ -0,0 +1,151 @@
"""Summarization middleware extensions for DeerFlow."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AnyMessage, RemoveMessage
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class SummarizationEvent:
"""Context emitted before conversation history is summarized away."""
messages_to_summarize: tuple[AnyMessage, ...]
preserved_messages: tuple[AnyMessage, ...]
thread_id: str | None
agent_name: str | None
runtime: Runtime
@runtime_checkable
class BeforeSummarizationHook(Protocol):
"""Hook invoked before summarization removes messages from state."""
def __call__(self, event: SummarizationEvent) -> None: ...
def _resolve_thread_id(runtime: Runtime) -> str | None:
"""Resolve the current thread ID from runtime context or LangGraph config."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
try:
config_data = get_config()
except RuntimeError:
return None
thread_id = config_data.get("configurable", {}).get("thread_id")
return thread_id
def _resolve_agent_name(runtime: Runtime) -> str | None:
"""Resolve the current agent name from runtime context or LangGraph config."""
agent_name = runtime.context.get("agent_name") if runtime.context else None
if agent_name is None:
try:
config_data = get_config()
except RuntimeError:
return None
agent_name = config_data.get("configurable", {}).get("agent_name")
return agent_name
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""Summarization middleware with pre-compression hook dispatch."""
def __init__(
self,
*args,
before_summarization: list[BeforeSummarizationHook] | None = None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._before_summarization_hooks = before_summarization or []
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime)
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return await self._amaybe_summarize(state, runtime)
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _fire_hooks(
self,
messages_to_summarize: list[AnyMessage],
preserved_messages: list[AnyMessage],
runtime: Runtime,
) -> None:
if not self._before_summarization_hooks:
return
event = SummarizationEvent(
messages_to_summarize=tuple(messages_to_summarize),
preserved_messages=tuple(preserved_messages),
thread_id=_resolve_thread_id(runtime),
agent_name=_resolve_agent_name(runtime),
runtime=runtime,
)
for hook in self._before_summarization_hooks:
try:
hook(event)
except Exception:
hook_name = getattr(hook, "__name__", None) or type(hook).__name__
logger.exception("before_summarization hook %s failed", hook_name)
@@ -3,10 +3,10 @@ from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -74,10 +74,14 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
return self._get_thread_paths(thread_id) return self._get_thread_paths(thread_id)
@override @override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None: def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
thread_id = runtime.context.thread_id context = runtime.context or {}
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if not thread_id: if thread_id is None:
raise ValueError("Thread ID is required in runtime context or config.configurable") raise ValueError("Thread ID is required in runtime context or config.configurable")
if self._lazy_init: if self._lazy_init:
@@ -1,13 +1,14 @@
"""Middleware for automatic thread title generation.""" """Middleware for automatic thread title generation."""
import logging import logging
import re
from typing import NotRequired, override from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.app_config import AppConfig from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,7 +46,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
def _should_generate_title(self, state: TitleMiddlewareState) -> bool: def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread.""" """Check if we should generate a title for this thread."""
config = AppConfig.current().title config = get_title_config()
if not config.enabled: if not config.enabled:
return False return False
@@ -70,14 +71,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
Returns (prompt_string, user_msg) so callers can use user_msg as fallback. Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
""" """
config = AppConfig.current().title config = get_title_config()
messages = state.get("messages", []) messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "") user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "") assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content) user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._normalize_content(assistant_msg_content) assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
prompt = config.prompt_template.format( prompt = config.prompt_template.format(
max_words=config.max_words, max_words=config.max_words,
@@ -86,15 +87,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
) )
return prompt, user_msg return prompt, user_msg
def _strip_think_tags(self, text: str) -> str:
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
def _parse_title(self, content: object) -> str: def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string.""" """Normalize model output into a clean title string."""
config = AppConfig.current().title config = get_title_config()
title_content = self._normalize_content(content) title_content = self._normalize_content(content)
title_content = self._strip_think_tags(title_content)
title = title_content.strip().strip('"').strip("'") title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str: def _fallback_title(self, user_msg: str) -> str:
config = AppConfig.current().title config = get_title_config()
fallback_chars = min(config.max_chars, 50) fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars: if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..." return user_msg[:fallback_chars].rstrip() + "..."
@@ -113,7 +119,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
if not self._should_generate_title(state): if not self._should_generate_title(state):
return None return None
config = AppConfig.current().title config = get_title_config()
prompt, user_msg = self._build_title_prompt(state) prompt, user_msg = self._build_title_prompt(state)
try: try:
@@ -1,9 +1,14 @@
"""Middleware that extends TodoListMiddleware with context-loss detection. """Middleware that extends TodoListMiddleware with context-loss detection and premature-exit prevention.
When the message history is truncated (e.g., by SummarizationMiddleware), the When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list. reminder message so the model still knows about the outstanding todo list.
Additionally, this middleware prevents the agent from exiting the loop while
there are still incomplete todo items. When the model produces a final response
(no tool calls) but todos are not yet complete, the middleware injects a reminder
and jumps back to the model node to force continued engagement.
""" """
from __future__ import annotations from __future__ import annotations
@@ -12,6 +17,7 @@ from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.types import hook_config
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
@@ -34,6 +40,11 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
return False return False
def _completion_reminder_count(messages: list[Any]) -> int:
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
def _format_todos(todos: list[Todo]) -> str: def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string.""" """Format a list of Todo items into a human-readable string."""
lines: list[str] = [] lines: list[str] = []
@@ -57,7 +68,7 @@ class TodoMiddleware(TodoListMiddleware):
def before_model( def before_model(
self, self,
state: PlanningState, state: PlanningState,
runtime: Runtime, # noqa: ARG002 runtime: Runtime,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window.""" """Inject a todo-list reminder when write_todos has left the context window."""
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment] todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
@@ -98,3 +109,71 @@ class TodoMiddleware(TodoListMiddleware):
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Async version of before_model.""" """Async version of before_model."""
return self.before_model(state, runtime) return self.before_model(state, runtime)
# Maximum number of completion reminders before allowing the agent to exit.
# This prevents infinite loops when the agent cannot make further progress.
_MAX_COMPLETION_REMINDERS = 2
@hook_config(can_jump_to=["model"])
@override
def after_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Prevent premature agent exit when todo items are still incomplete.
In addition to the base class check for parallel ``write_todos`` calls,
this override intercepts model responses that have no tool calls while
there are still incomplete todo items. It injects a reminder
``HumanMessage`` and jumps back to the model node so the agent
continues working through the todo list.
A retry cap of ``_MAX_COMPLETION_REMINDERS`` (default 2) prevents
infinite loops when the agent cannot make further progress.
"""
# 1. Preserve base class logic (parallel write_todos detection).
base_result = super().after_model(state, runtime)
if base_result is not None:
return base_result
# 2. Only intervene when the agent wants to exit (no tool calls).
messages = state.get("messages") or []
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
if not last_ai or last_ai.tool_calls:
return None
# 3. Allow exit when all todos are completed or there are no todos.
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
if not todos or all(t.get("status") == "completed" for t in todos):
return None
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
return None
# 5. Inject a reminder and force the agent back to the model.
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
reminder = HumanMessage(
name="todo_completion_reminder",
content=(
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
),
)
return {"jump_to": "model", "messages": [reminder]}
@override
@hook_config(can_jump_to=["model"])
async def aafter_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of after_model."""
return self.after_model(state, runtime)
@@ -94,9 +94,9 @@ def _build_runtime_middlewares(
middlewares.append(LLMErrorHandlingMiddleware()) middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured) # Guardrail middleware (if configured)
from deerflow.config.app_config import AppConfig from deerflow.config.guardrails_config import get_guardrails_config
guardrails_config = AppConfig.current().guardrails guardrails_config = get_guardrails_config()
if guardrails_config.enabled and guardrails_config.provider: if guardrails_config.enabled and guardrails_config.provider:
import inspect import inspect
@@ -9,7 +9,6 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline from deerflow.utils.file_conversion import extract_outline
@@ -185,7 +184,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return files if files else None return files if files else None
@override @override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None: def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject uploaded files information before agent execution. """Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files. New files come from the current message's additional_kwargs.files.
@@ -214,7 +213,14 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None return None
# Resolve uploads directory for existence checks # Resolve uploads directory for existence checks
thread_id = runtime.context.thread_id thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files # Get newly uploaded files from the current message's additional_kwargs.files
+26 -19
View File
@@ -36,9 +36,8 @@ from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config, reload_app_config
from deerflow.config.deer_flow_context import DeerFlowContext from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
from deerflow.skills.installer import install_skill_from_archive from deerflow.skills.installer import install_skill_from_archive
@@ -142,8 +141,8 @@ class DeerFlowClient:
middlewares: Optional list of custom middlewares to inject into the agent. middlewares: Optional list of custom middlewares to inject into the agent.
""" """
if config_path is not None: if config_path is not None:
AppConfig.init(AppConfig.from_file(config_path)) reload_app_config(config_path)
self._app_config = AppConfig.current() self._app_config = get_app_config()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name): if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}") raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@@ -552,7 +551,9 @@ class DeerFlowClient:
self._ensure_agent(config) self._ensure_agent(config)
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]} state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = DeerFlowContext(app_config=self._app_config, thread_id=thread_id, agent_name=self._agent_name) context = {"thread_id": thread_id}
if self._agent_name:
context["agent_name"] = self._agent_name
seen_ids: set[str] = set() seen_ids: set[str] = set()
# Cross-mode handoff: ids already streamed via LangGraph ``messages`` # Cross-mode handoff: ids already streamed via LangGraph ``messages``
@@ -721,6 +722,10 @@ class DeerFlowClient:
Dict with "models" key containing list of model info dicts, Dict with "models" key containing list of model info dicts,
matching the Gateway API ``ModelsListResponse`` schema. matching the Gateway API ``ModelsListResponse`` schema.
""" """
token_usage_enabled = getattr(getattr(self._app_config, "token_usage", None), "enabled", False)
if not isinstance(token_usage_enabled, bool):
token_usage_enabled = False
return { return {
"models": [ "models": [
{ {
@@ -732,7 +737,8 @@ class DeerFlowClient:
"supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False), "supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
} }
for model in self._app_config.models for model in self._app_config.models
] ],
"token_usage": {"enabled": token_usage_enabled},
} }
def list_skills(self, enabled_only: bool = False) -> dict: def list_skills(self, enabled_only: bool = False) -> dict:
@@ -815,8 +821,8 @@ class DeerFlowClient:
Dict with "mcp_servers" key mapping server name to config, Dict with "mcp_servers" key mapping server name to config,
matching the Gateway API ``McpConfigResponse`` schema. matching the Gateway API ``McpConfigResponse`` schema.
""" """
ext = AppConfig.current().extensions config = get_extensions_config()
return {"mcp_servers": {name: server.model_dump() for name, server in ext.mcp_servers.items()}} return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict: def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
"""Update MCP server configurations. """Update MCP server configurations.
@@ -838,19 +844,18 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.") raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
current_ext = AppConfig.current().extensions current_config = get_extensions_config()
config_data = { config_data = {
"mcpServers": mcp_servers, "mcpServers": mcp_servers,
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()}, "skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
} }
self._atomic_write_json(config_path, config_data) self._atomic_write_json(config_path, config_data)
self._agent = None self._agent = None
self._agent_config_key = None self._agent_config_key = None
AppConfig.init(AppConfig.from_file()) reloaded = reload_extensions_config()
reloaded = AppConfig.current().extensions
return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}} return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -904,19 +909,19 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.") raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
ext = AppConfig.current().extensions extensions_config = get_extensions_config()
ext.skills[name] = SkillStateConfig(enabled=enabled) extensions_config.skills[name] = SkillStateConfig(enabled=enabled)
config_data = { config_data = {
"mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()}, "mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()}, "skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()},
} }
self._atomic_write_json(config_path, config_data) self._atomic_write_json(config_path, config_data)
self._agent = None self._agent = None
self._agent_config_key = None self._agent_config_key = None
AppConfig.init(AppConfig.from_file()) reload_extensions_config()
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None) updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
if updated is None: if updated is None:
@@ -999,7 +1004,9 @@ class DeerFlowClient:
Returns: Returns:
Memory config dict. Memory config dict.
""" """
config = AppConfig.current().memory from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
return { return {
"enabled": config.enabled, "enabled": config.enabled,
"storage_path": config.storage_path, "storage_path": config.storage_path,
@@ -25,7 +25,7 @@ except ImportError: # pragma: no cover - Windows fallback
fcntl = None # type: ignore[assignment] fcntl = None # type: ignore[assignment]
import msvcrt import msvcrt
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider from deerflow.sandbox.sandbox_provider import SandboxProvider
@@ -119,6 +119,16 @@ class AioSandboxProvider(SandboxProvider):
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0: if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
self._start_idle_checker() self._start_idle_checker()
@property
def uses_thread_data_mounts(self) -> bool:
"""Whether thread workspace/uploads/outputs are visible via mounts.
Local container backends bind-mount the thread data directories, so files
written by the gateway are already visible when the sandbox starts.
Remote backends may require explicit file sync.
"""
return isinstance(self._backend, LocalContainerBackend)
# ── Factory methods ────────────────────────────────────────────────── # ── Factory methods ──────────────────────────────────────────────────
def _create_backend(self) -> SandboxBackend: def _create_backend(self) -> SandboxBackend:
@@ -148,7 +158,7 @@ class AioSandboxProvider(SandboxProvider):
def _load_config(self) -> dict: def _load_config(self) -> dict:
"""Load sandbox configuration from app config.""" """Load sandbox configuration from app config."""
config = AppConfig.current() config = get_app_config()
sandbox_config = config.sandbox sandbox_config = config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None) idle_timeout = getattr(sandbox_config, "idle_timeout", None)
@@ -279,7 +289,7 @@ class AioSandboxProvider(SandboxProvider):
so the host Docker daemon can resolve the path. so the host Docker daemon can resolve the path.
""" """
try: try:
config = AppConfig.current() config = get_app_config()
skills_path = config.skills.get_skills_path() skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path container_path = config.skills.container_path
@@ -7,7 +7,7 @@ import logging
from langchain.tools import tool from langchain.tools import tool
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ def web_search_tool(
query: Search keywords describing what you want to find. Be specific for better results. query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of results to return. Default is 5. max_results: Maximum number of results to return. Default is 5.
""" """
config = AppConfig.current().get_tool_config("web_search") config = get_app_config().get_tool_config("web_search")
# Override max_results from config if set # Override max_results from config if set
if config is not None and "max_results" in config.model_extra: if config is not None and "max_results" in config.model_extra:
@@ -3,11 +3,11 @@ import json
from exa_py import Exa from exa_py import Exa
from langchain.tools import tool from langchain.tools import tool
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
def _get_exa_client(tool_name: str = "web_search") -> Exa: def _get_exa_client(tool_name: str = "web_search") -> Exa:
config = AppConfig.current().get_tool_config(tool_name) config = get_app_config().get_tool_config(tool_name)
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key") api_key = config.model_extra.get("api_key")
@@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for. query: The query to search for.
""" """
try: try:
config = AppConfig.current().get_tool_config("web_search") config = get_app_config().get_tool_config("web_search")
max_results = 5 max_results = 5
search_type = "auto" search_type = "auto"
contents_max_characters = 1000 contents_max_characters = 1000
@@ -3,11 +3,11 @@ import json
from firecrawl import FirecrawlApp from firecrawl import FirecrawlApp
from langchain.tools import tool from langchain.tools import tool
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp: def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
config = AppConfig.current().get_tool_config(tool_name) config = get_app_config().get_tool_config(tool_name)
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key") api_key = config.model_extra.get("api_key")
@@ -22,7 +22,7 @@ def web_search_tool(query: str) -> str:
query: The query to search for. query: The query to search for.
""" """
try: try:
config = AppConfig.current().get_tool_config("web_search") config = get_app_config().get_tool_config("web_search")
max_results = 5 max_results = 5
if config is not None: if config is not None:
max_results = config.model_extra.get("max_results", max_results) max_results = config.model_extra.get("max_results", max_results)
@@ -7,7 +7,7 @@ import logging
from langchain.tools import tool from langchain.tools import tool
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ def image_search_tool(
type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references. type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs. layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
""" """
config = AppConfig.current().get_tool_config("image_search") config = get_app_config().get_tool_config("image_search")
# Override max_results from config if set # Override max_results from config if set
if config is not None and "max_results" in config.model_extra: if config is not None and "max_results" in config.model_extra:
@@ -1,6 +1,6 @@
from langchain.tools import tool from langchain.tools import tool
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.utils.readability import ReadabilityExtractor from deerflow.utils.readability import ReadabilityExtractor
from .infoquest_client import InfoQuestClient from .infoquest_client import InfoQuestClient
@@ -9,12 +9,12 @@ readability_extractor = ReadabilityExtractor()
def _get_infoquest_client() -> InfoQuestClient: def _get_infoquest_client() -> InfoQuestClient:
search_config = AppConfig.current().get_tool_config("web_search") search_config = get_app_config().get_tool_config("web_search")
search_time_range = -1 search_time_range = -1
if search_config is not None and "search_time_range" in search_config.model_extra: if search_config is not None and "search_time_range" in search_config.model_extra:
search_time_range = search_config.model_extra.get("search_time_range") search_time_range = search_config.model_extra.get("search_time_range")
fetch_config = AppConfig.current().get_tool_config("web_fetch") fetch_config = get_app_config().get_tool_config("web_fetch")
fetch_time = -1 fetch_time = -1
if fetch_config is not None and "fetch_time" in fetch_config.model_extra: if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
fetch_time = fetch_config.model_extra.get("fetch_time") fetch_time = fetch_config.model_extra.get("fetch_time")
@@ -25,7 +25,7 @@ def _get_infoquest_client() -> InfoQuestClient:
if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra: if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
navigation_timeout = fetch_config.model_extra.get("navigation_timeout") navigation_timeout = fetch_config.model_extra.get("navigation_timeout")
image_search_config = AppConfig.current().get_tool_config("image_search") image_search_config = get_app_config().get_tool_config("image_search")
image_search_time_range = -1 image_search_time_range = -1
if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra: if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
image_search_time_range = image_search_config.model_extra.get("image_search_time_range") image_search_time_range = image_search_config.model_extra.get("image_search_time_range")
@@ -1,7 +1,9 @@
import asyncio
from langchain.tools import tool from langchain.tools import tool
from deerflow.community.jina_ai.jina_client import JinaClient from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.utils.readability import ReadabilityExtractor from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor() readability_extractor = ReadabilityExtractor()
@@ -20,11 +22,11 @@ async def web_fetch_tool(url: str) -> str:
""" """
jina_client = JinaClient() jina_client = JinaClient()
timeout = 10 timeout = 10
config = AppConfig.current().get_tool_config("web_fetch") config = get_app_config().get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra: if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout") timeout = config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout) html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"): if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content return html_content
article = readability_extractor.extract_article(html_content) article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
return article.to_markdown()[:4096] return article.to_markdown()[:4096]
@@ -3,11 +3,11 @@ import json
from langchain.tools import tool from langchain.tools import tool
from tavily import TavilyClient from tavily import TavilyClient
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
def _get_tavily_client() -> TavilyClient: def _get_tavily_client() -> TavilyClient:
config = AppConfig.current().get_tool_config("web_search") config = get_app_config().get_tool_config("web_search")
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key") api_key = config.model_extra.get("api_key")
@@ -21,7 +21,7 @@ def web_search_tool(query: str) -> str:
Args: Args:
query: The query to search for. query: The query to search for.
""" """
config = AppConfig.current().get_tool_config("web_search") config = get_app_config().get_tool_config("web_search")
max_results = 5 max_results = 5
if config is not None and "max_results" in config.model_extra: if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results") max_results = config.model_extra.get("max_results")
@@ -1,6 +1,6 @@
from .app_config import AppConfig from .app_config import get_app_config
from .extensions_config import ExtensionsConfig from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig from .skill_evolution_config import SkillEvolutionConfig
from .skills_config import SkillsConfig from .skills_config import SkillsConfig
@@ -13,16 +13,18 @@ from .tracing_config import (
) )
__all__ = [ __all__ = [
"AppConfig", "get_app_config",
"ExtensionsConfig",
"MemoryConfig",
"Paths",
"SkillEvolutionConfig", "SkillEvolutionConfig",
"SkillsConfig", "Paths",
"get_enabled_tracing_providers",
"get_explicitly_enabled_tracing_providers",
"get_paths", "get_paths",
"SkillsConfig",
"ExtensionsConfig",
"get_extensions_config",
"MemoryConfig",
"get_memory_config",
"get_tracing_config", "get_tracing_config",
"get_explicitly_enabled_tracing_providers",
"get_enabled_tracing_providers",
"is_tracing_enabled", "is_tracing_enabled",
"validate_enabled_tracing_providers", "validate_enabled_tracing_providers",
] ]
@@ -1,13 +1,16 @@
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml.""" """ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
from pydantic import BaseModel, ConfigDict, Field import logging
from collections.abc import Mapping
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class ACPAgentConfig(BaseModel): class ACPAgentConfig(BaseModel):
"""Configuration for a single ACP-compatible agent.""" """Configuration for a single ACP-compatible agent."""
model_config = ConfigDict(frozen=True)
command: str = Field(description="Command to launch the ACP agent subprocess") command: str = Field(description="Command to launch the ACP agent subprocess")
args: list[str] = Field(default_factory=list, description="Additional command arguments") args: list[str] = Field(default_factory=list, description="Additional command arguments")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.") env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
@@ -21,3 +24,28 @@ class ACPAgentConfig(BaseModel):
"are denied — the agent must be configured to operate without requesting permissions." "are denied — the agent must be configured to operate without requesting permissions."
), ),
) )
_acp_agents: dict[str, ACPAgentConfig] = {}
def get_acp_agents() -> dict[str, ACPAgentConfig]:
"""Get the currently configured ACP agents.
Returns:
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
"""
return _acp_agents
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
Args:
config_dict: Mapping of agent name -> config fields.
"""
global _acp_agents
if config_dict is None:
config_dict = {}
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))
@@ -0,0 +1,32 @@
"""Configuration for the custom agents management API."""
from pydantic import BaseModel, Field
class AgentsApiConfig(BaseModel):
"""Configuration for custom-agent and user-profile management routes."""
enabled: bool = Field(
default=False,
description=("Whether to expose the custom-agent management API over HTTP. When disabled, the gateway rejects read/write access to custom agent SOUL.md, config, and USER.md prompt-management routes."),
)
_agents_api_config: AgentsApiConfig = AgentsApiConfig()
def get_agents_api_config() -> AgentsApiConfig:
"""Get the current agents API configuration."""
return _agents_api_config
def set_agents_api_config(config: AgentsApiConfig) -> None:
"""Set the agents API configuration."""
global _agents_api_config
_agents_api_config = config
def load_agents_api_config_from_dict(config_dict: dict) -> None:
"""Load agents API configuration from a dictionary."""
global _agents_api_config
_agents_api_config = AgentsApiConfig(**config_dict)
@@ -5,7 +5,7 @@ import re
from typing import Any from typing import Any
import yaml import yaml
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
@@ -15,11 +15,20 @@ SOUL_FILENAME = "SOUL.md"
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$") AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
def validate_agent_name(name: str | None) -> str | None:
"""Validate a custom agent name before using it in filesystem paths."""
if name is None:
return None
if not isinstance(name, str):
raise ValueError("Invalid agent name. Expected a string or None.")
if not AGENT_NAME_PATTERN.fullmatch(name):
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
return name
class AgentConfig(BaseModel): class AgentConfig(BaseModel):
"""Configuration for a custom agent.""" """Configuration for a custom agent."""
model_config = ConfigDict(frozen=True)
name: str name: str
description: str = "" description: str = ""
model: str | None = None model: str | None = None
@@ -48,8 +57,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
if name is None: if name is None:
return None return None
if not AGENT_NAME_PATTERN.match(name): name = validate_agent_name(name)
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
agent_dir = get_paths().agent_dir(name) agent_dir = get_paths().agent_dir(name)
config_file = agent_dir / "config.yaml" config_file = agent_dir / "config.yaml"
@@ -1,37 +1,43 @@
from __future__ import annotations
import logging import logging
import os import os
from contextvars import ContextVar from contextvars import ContextVar
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, Self from typing import Any, Self
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import ACPAgentConfig from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import SubagentsAppConfig from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
from deerflow.config.summarization_config import SummarizationConfig from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
from deerflow.config.title_config import TitleConfig from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.token_usage_config import TokenUsageConfig from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CircuitBreakerConfig(BaseModel):
"""Configuration for the LLM Circuit Breaker."""
failure_threshold: int = Field(default=5, description="Number of consecutive failures before tripping the circuit")
recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit")
def _default_config_candidates() -> tuple[Path, ...]: def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd.""" """Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4] backend_dir = Path(__file__).resolve().parents[4]
@@ -55,12 +61,13 @@ class AppConfig(BaseModel):
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration") title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration") summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration") memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
agents_api: AgentsApiConfig = Field(default_factory=AgentsApiConfig, description="Custom-agent management API configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
model_config = ConfigDict(extra="allow", frozen=True) circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP agent configurations keyed by agent name")
@classmethod @classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path: def resolve_config_path(cls, config_path: str | None = None) -> Path:
@@ -108,6 +115,49 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data) config_data = cls.resolve_env_variables(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file) # Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file() extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump() config_data["extensions"] = extensions_config.model_dump()
@@ -218,26 +268,130 @@ class AppConfig(BaseModel):
""" """
return next((group for group in self.tool_groups if group.name == name), None) return next((group for group in self.tool_groups if group.name == name), None)
# -- Lifecycle (class-level singleton via ContextVar) --
_current: ClassVar[ContextVar[AppConfig]] = ContextVar("deerflow_app_config") _app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
@classmethod
def init(cls, config: AppConfig) -> None:
"""Set the AppConfig for the current context. Call once at process startup."""
cls._current.set(config)
@classmethod def _get_config_mtime(config_path: Path) -> float | None:
def current(cls) -> AppConfig: """Get the modification time of a config file if it exists."""
"""Get the current AppConfig. try:
return config_path.stat().st_mtime
except OSError:
return None
Auto-initializes from config file on first access for backward compatibility.
Prefer calling AppConfig.init() explicitly at process startup. def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
""" """Load config from disk and refresh cache metadata."""
try: global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
return cls._current.get()
except LookupError: resolved_path = AppConfig.resolve_config_path(config_path)
logger.debug("AppConfig not initialized, auto-loading from file") _app_config = AppConfig.from_file(str(resolved_path))
config = cls.from_file() _app_config_path = resolved_path
cls._current.set(config) _app_config_mtime = _get_config_mtime(resolved_path)
return config _app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Reload the config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded AppConfig instance.
"""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Reset the cached config instance.
This clears the singleton cache, causing the next call to
`get_app_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Set a custom config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The AppConfig instance to use.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"] CheckpointerType = Literal["memory", "sqlite", "postgres"]
@@ -10,8 +10,6 @@ CheckpointerType = Literal["memory", "sqlite", "postgres"]
class CheckpointerConfig(BaseModel): class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer.""" """Configuration for LangGraph state persistence checkpointer."""
model_config = ConfigDict(frozen=True)
type: CheckpointerType = Field( type: CheckpointerType = Field(
description="Checkpointer backend type. " description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). " "'memory' is in-process only (lost on restart). "
@@ -25,3 +23,24 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. " "For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.", "For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
) )
# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None
def get_checkpointer_config() -> CheckpointerConfig | None:
"""Get the current checkpointer configuration, or None if not configured."""
return _checkpointer_config
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
"""Set the checkpointer configuration."""
global _checkpointer_config
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -1,59 +0,0 @@
"""Per-invocation context for DeerFlow agent execution.
Injected via LangGraph Runtime. Middleware and tools access this
via Runtime[DeerFlowContext] parameters, through resolve_context().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class DeerFlowContext:
"""Typed, immutable, per-invocation context injected via LangGraph Runtime.
Fields are all known at run start and never change during execution.
Mutable runtime state (e.g. sandbox_id) flows through ThreadState, not here.
"""
app_config: Any # AppConfig — typed as Any to avoid circular import at module level
thread_id: str
agent_name: str | None = None
def resolve_context(runtime: Any) -> DeerFlowContext:
"""Extract or construct DeerFlowContext from runtime.
Gateway/Client paths: runtime.context is already DeerFlowContext → return directly.
LangGraph Server / legacy dict path: construct from dict context or configurable fallback.
"""
ctx = getattr(runtime, "context", None)
if isinstance(ctx, DeerFlowContext):
return ctx
from deerflow.config.app_config import AppConfig
# Try dict context first (legacy path, tests), then configurable
if isinstance(ctx, dict):
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=ctx.get("thread_id", ""),
agent_name=ctx.get("agent_name"),
)
# No context at all — fall back to LangGraph configurable
try:
from langgraph.config import get_config
cfg = get_config().get("configurable", {})
except RuntimeError:
# Outside runnable context (e.g. unit tests)
cfg = {}
return DeerFlowContext(
app_config=AppConfig.current(),
thread_id=cfg.get("thread_id", ""),
agent_name=cfg.get("agent_name"),
)
@@ -11,8 +11,6 @@ from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel): class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports).""" """OAuth configuration for an MCP server (HTTP/SSE transports)."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled") enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL") token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field( grant_type: Literal["client_credentials", "refresh_token"] = Field(
@@ -30,13 +28,12 @@ class McpOAuthConfig(BaseModel):
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response") default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry") refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint") extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel): class McpServerConfig(BaseModel):
"""Configuration for a single MCP server.""" """Configuration for a single MCP server."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether this MCP server is enabled") enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'") type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)") command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
@@ -46,13 +43,12 @@ class McpServerConfig(BaseModel):
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)") headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)") oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides") description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
class SkillStateConfig(BaseModel): class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state.""" """Configuration for a single skill's state."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=True, description="Whether this skill is enabled") enabled: bool = Field(default=True, description="Whether this skill is enabled")
@@ -68,7 +64,7 @@ class ExtensionsConfig(BaseModel):
default_factory=dict, default_factory=dict,
description="Map of skill name to state configuration", description="Map of skill name to state configuration",
) )
model_config = ConfigDict(extra="allow", frozen=True, populate_by_name=True) model_config = ConfigDict(extra="allow", populate_by_name=True)
@classmethod @classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path | None: def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
@@ -199,3 +195,62 @@ class ExtensionsConfig(BaseModel):
# Default to enable for public & custom skill # Default to enable for public & custom skill
return skill_category in ("public", "custom") return skill_category in ("public", "custom")
return skill_config.enabled return skill_config.enabled
_extensions_config: ExtensionsConfig | None = None
def get_extensions_config() -> ExtensionsConfig:
"""Get the extensions config instance.
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
from file, or `reset_extensions_config()` to clear the cache.
Returns:
The cached ExtensionsConfig instance.
"""
global _extensions_config
if _extensions_config is None:
_extensions_config = ExtensionsConfig.from_file()
return _extensions_config
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
"""Reload the extensions config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to extensions config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded ExtensionsConfig instance.
"""
global _extensions_config
_extensions_config = ExtensionsConfig.from_file(config_path)
return _extensions_config
def reset_extensions_config() -> None:
"""Reset the cached extensions config instance.
This clears the singleton cache, causing the next call to
`get_extensions_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _extensions_config
_extensions_config = None
def set_extensions_config(config: ExtensionsConfig) -> None:
"""Set a custom extensions config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The ExtensionsConfig instance to use.
"""
global _extensions_config
_extensions_config = config
@@ -1,13 +1,11 @@
"""Configuration for pre-tool-call authorization.""" """Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
class GuardrailProviderConfig(BaseModel): class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider.""" """Configuration for a guardrail provider."""
model_config = ConfigDict(frozen=True)
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')") use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs") config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
@@ -20,9 +18,31 @@ class GuardrailsConfig(BaseModel):
agent's passport reference, and returns an allow/deny decision. agent's passport reference, and returns an allow/deny decision.
""" """
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable guardrail middleware") enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors") fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID") passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration") provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None
@@ -1,13 +1,11 @@
"""Configuration for memory mechanism.""" """Configuration for memory mechanism."""
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
class MemoryConfig(BaseModel): class MemoryConfig(BaseModel):
"""Configuration for global memory mechanism.""" """Configuration for global memory mechanism."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=True, default=True,
description="Whether to enable memory mechanism", description="Whether to enable memory mechanism",
@@ -61,3 +59,24 @@ class MemoryConfig(BaseModel):
le=8000, le=8000,
description="Maximum tokens to use for memory injection", description="Maximum tokens to use for memory injection",
) )
# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()
def get_memory_config() -> MemoryConfig:
"""Get the current memory configuration."""
return _memory_config
def set_memory_config(config: MemoryConfig) -> None:
"""Set the memory configuration."""
global _memory_config
_memory_config = config
def load_memory_config_from_dict(config_dict: dict) -> None:
"""Load memory configuration from a dictionary."""
global _memory_config
_memory_config = MemoryConfig(**config_dict)
@@ -12,7 +12,7 @@ class ModelConfig(BaseModel):
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)", description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
) )
model: str = Field(..., description="Model name") model: str = Field(..., description="Model name")
model_config = ConfigDict(extra="allow", frozen=True) model_config = ConfigDict(extra="allow")
use_responses_api: bool | None = Field( use_responses_api: bool | None = Field(
default=None, default=None,
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API", description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",
@@ -4,8 +4,6 @@ from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel): class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount.""" """Configuration for a volume mount."""
model_config = ConfigDict(frozen=True)
host_path: str = Field(..., description="Path on the host machine") host_path: str = Field(..., description="Path on the host machine")
container_path: str = Field(..., description="Path inside the container") container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only") read_only: bool = Field(default=False, description="Whether the mount is read-only")
@@ -82,4 +80,4 @@ class SandboxConfig(BaseModel):
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.", description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
) )
model_config = ConfigDict(extra="allow", frozen=True) model_config = ConfigDict(extra="allow")
@@ -1,11 +1,9 @@
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
class SkillEvolutionConfig(BaseModel): class SkillEvolutionConfig(BaseModel):
"""Configuration for agent-managed skill evolution.""" """Configuration for agent-managed skill evolution."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Whether the agent can create and modify skills under skills/custom.", description="Whether the agent can create and modify skills under skills/custom.",
@@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
def _default_repo_root() -> Path: def _default_repo_root() -> Path:
@@ -11,8 +11,6 @@ def _default_repo_root() -> Path:
class SkillsConfig(BaseModel): class SkillsConfig(BaseModel):
"""Configuration for skills system""" """Configuration for skills system"""
model_config = ConfigDict(frozen=True)
path: str | None = Field( path: str | None = Field(
default=None, default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory", description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
StreamBridgeType = Literal["memory", "redis"] StreamBridgeType = Literal["memory", "redis"]
@@ -10,8 +10,6 @@ StreamBridgeType = Literal["memory", "redis"]
class StreamBridgeConfig(BaseModel): class StreamBridgeConfig(BaseModel):
"""Configuration for the stream bridge that connects agent workers to SSE endpoints.""" """Configuration for the stream bridge that connects agent workers to SSE endpoints."""
model_config = ConfigDict(frozen=True)
type: StreamBridgeType = Field( type: StreamBridgeType = Field(
default="memory", default="memory",
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).", description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
@@ -24,3 +22,25 @@ class StreamBridgeConfig(BaseModel):
default=256, default=256,
description="Maximum number of events buffered per run in the memory bridge.", description="Maximum number of events buffered per run in the memory bridge.",
) )
# Global configuration instance — None means no stream bridge is configured
# (falls back to memory with defaults).
_stream_bridge_config: StreamBridgeConfig | None = None
def get_stream_bridge_config() -> StreamBridgeConfig | None:
"""Get the current stream bridge configuration, or None if not configured."""
return _stream_bridge_config
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
"""Set the stream bridge configuration."""
global _stream_bridge_config
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -1,13 +1,15 @@
"""Configuration for the subagent system loaded from config.yaml.""" """Configuration for the subagent system loaded from config.yaml."""
from pydantic import BaseModel, ConfigDict, Field import logging
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class SubagentOverrideConfig(BaseModel): class SubagentOverrideConfig(BaseModel):
"""Per-agent configuration overrides.""" """Per-agent configuration overrides."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int | None = Field( timeout_seconds: int | None = Field(
default=None, default=None,
ge=1, ge=1,
@@ -18,13 +20,16 @@ class SubagentOverrideConfig(BaseModel):
ge=1, ge=1,
description="Maximum turns for this subagent (None = use global or builtin default)", description="Maximum turns for this subagent (None = use global or builtin default)",
) )
model: str | None = Field(
default=None,
min_length=1,
description="Model name for this subagent (None = inherit from parent agent)",
)
class SubagentsAppConfig(BaseModel): class SubagentsAppConfig(BaseModel):
"""Configuration for the subagent system.""" """Configuration for the subagent system."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int = Field( timeout_seconds: int = Field(
default=900, default=900,
ge=1, ge=1,
@@ -54,6 +59,20 @@ class SubagentsAppConfig(BaseModel):
return override.timeout_seconds return override.timeout_seconds
return self.timeout_seconds return self.timeout_seconds
def get_model_for(self, agent_name: str) -> str | None:
"""Get the model override for a specific agent.
Args:
agent_name: The name of the subagent.
Returns:
Model name if overridden, None otherwise (subagent will inherit parent model).
"""
override = self.agents.get(agent_name)
if override is not None and override.model is not None:
return override.model
return None
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int: def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
"""Get the effective max_turns for a specific agent.""" """Get the effective max_turns for a specific agent."""
override = self.agents.get(agent_name) override = self.agents.get(agent_name)
@@ -62,3 +81,43 @@ class SubagentsAppConfig(BaseModel):
if self.max_turns is not None: if self.max_turns is not None:
return self.max_turns return self.max_turns
return builtin_default return builtin_default
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
def get_subagents_app_config() -> SubagentsAppConfig:
"""Get the current subagents configuration."""
return _subagents_config
def load_subagents_config_from_dict(config_dict: dict) -> None:
"""Load subagents configuration from a dictionary."""
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if override.model is not None:
parts.append(f"model={override.model}")
if parts:
overrides_summary[name] = ", ".join(parts)
if overrides_summary:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary,
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
ContextSizeType = Literal["fraction", "tokens", "messages"] ContextSizeType = Literal["fraction", "tokens", "messages"]
@@ -10,8 +10,6 @@ ContextSizeType = Literal["fraction", "tokens", "messages"]
class ContextSize(BaseModel): class ContextSize(BaseModel):
"""Context size specification for trigger or keep parameters.""" """Context size specification for trigger or keep parameters."""
model_config = ConfigDict(frozen=True)
type: ContextSizeType = Field(description="Type of context size specification") type: ContextSizeType = Field(description="Type of context size specification")
value: int | float = Field(description="Value for the context size specification") value: int | float = Field(description="Value for the context size specification")
@@ -23,8 +21,6 @@ class ContextSize(BaseModel):
class SummarizationConfig(BaseModel): class SummarizationConfig(BaseModel):
"""Configuration for automatic conversation summarization.""" """Configuration for automatic conversation summarization."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Whether to enable automatic conversation summarization", description="Whether to enable automatic conversation summarization",
@@ -55,3 +51,24 @@ class SummarizationConfig(BaseModel):
default=None, default=None,
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.", description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
) )
# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()
def get_summarization_config() -> SummarizationConfig:
"""Get the current summarization configuration."""
return _summarization_config
def set_summarization_config(config: SummarizationConfig) -> None:
"""Set the summarization configuration."""
global _summarization_config
_summarization_config = config
def load_summarization_config_from_dict(config_dict: dict) -> None:
"""Load summarization configuration from a dictionary."""
global _summarization_config
_summarization_config = SummarizationConfig(**config_dict)
@@ -1,13 +1,11 @@
"""Configuration for automatic thread title generation.""" """Configuration for automatic thread title generation."""
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
class TitleConfig(BaseModel): class TitleConfig(BaseModel):
"""Configuration for automatic thread title generation.""" """Configuration for automatic thread title generation."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=True, default=True,
description="Whether to enable automatic title generation", description="Whether to enable automatic title generation",
@@ -32,3 +30,24 @@ class TitleConfig(BaseModel):
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."), default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
description="Prompt template for title generation", description="Prompt template for title generation",
) )
# Global configuration instance
_title_config: TitleConfig = TitleConfig()
def get_title_config() -> TitleConfig:
"""Get the current title configuration."""
return _title_config
def set_title_config(config: TitleConfig) -> None:
"""Set the title configuration."""
global _title_config
_title_config = config
def load_title_config_from_dict(config_dict: dict) -> None:
"""Load title configuration from a dictionary."""
global _title_config
_title_config = TitleConfig(**config_dict)
@@ -1,9 +1,7 @@
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
class TokenUsageConfig(BaseModel): class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking.""" """Configuration for token usage tracking."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable token usage tracking middleware") enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
@@ -5,7 +5,7 @@ class ToolGroupConfig(BaseModel):
"""Config section for a tool group""" """Config section for a tool group"""
name: str = Field(..., description="Unique name for the tool group") name: str = Field(..., description="Unique name for the tool group")
model_config = ConfigDict(extra="allow", frozen=True) model_config = ConfigDict(extra="allow")
class ToolConfig(BaseModel): class ToolConfig(BaseModel):
@@ -17,4 +17,4 @@ class ToolConfig(BaseModel):
..., ...,
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)", description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
) )
model_config = ConfigDict(extra="allow", frozen=True) model_config = ConfigDict(extra="allow")
@@ -1,6 +1,6 @@
"""Configuration for deferred tool loading via tool_search.""" """Configuration for deferred tool loading via tool_search."""
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
class ToolSearchConfig(BaseModel): class ToolSearchConfig(BaseModel):
@@ -11,9 +11,25 @@ class ToolSearchConfig(BaseModel):
via the tool_search tool at runtime. via the tool_search tool at runtime.
""" """
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Defer tools and enable tool_search", description="Defer tools and enable tool_search",
) )
_tool_search_config: ToolSearchConfig | None = None
def get_tool_search_config() -> ToolSearchConfig:
"""Get the tool search config, loading from AppConfig if needed."""
global _tool_search_config
if _tool_search_config is None:
_tool_search_config = ToolSearchConfig()
return _tool_search_config
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
"""Load tool search config from a dict (called during AppConfig loading)."""
global _tool_search_config
_tool_search_config = ToolSearchConfig.model_validate(data)
return _tool_search_config
@@ -1,7 +1,7 @@
import os import os
import threading import threading
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
_config_lock = threading.Lock() _config_lock = threading.Lock()
@@ -9,8 +9,6 @@ _config_lock = threading.Lock()
class LangSmithTracingConfig(BaseModel): class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing.""" """Configuration for LangSmith tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...) enabled: bool = Field(...)
api_key: str | None = Field(...) api_key: str | None = Field(...)
project: str = Field(...) project: str = Field(...)
@@ -28,8 +26,6 @@ class LangSmithTracingConfig(BaseModel):
class LangfuseTracingConfig(BaseModel): class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing.""" """Configuration for Langfuse tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...) enabled: bool = Field(...)
public_key: str | None = Field(...) public_key: str | None = Field(...)
secret_key: str | None = Field(...) secret_key: str | None = Field(...)
@@ -54,8 +50,6 @@ class LangfuseTracingConfig(BaseModel):
class TracingConfig(BaseModel): class TracingConfig(BaseModel):
"""Tracing configuration for supported providers.""" """Tracing configuration for supported providers."""
model_config = ConfigDict(frozen=True)
langsmith: LangSmithTracingConfig = Field(...) langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...) langfuse: LangfuseTracingConfig = Field(...)
@@ -118,9 +118,13 @@ def get_cached_mcp_tools() -> list[BaseTool]:
loop.run_until_complete(initialize_mcp_tools()) loop.run_until_complete(initialize_mcp_tools())
except RuntimeError: except RuntimeError:
# No event loop exists, create one # No event loop exists, create one
asyncio.run(initialize_mcp_tools()) try:
except Exception as e: asyncio.run(initialize_mcp_tools())
logger.error(f"Failed to lazy-initialize MCP tools: {e}") except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
return []
except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
return [] return []
return _mcp_tools_cache or [] return _mcp_tools_cache or []
@@ -2,7 +2,7 @@ import logging
from langchain.chat_models import BaseChatModel from langchain.chat_models import BaseChatModel
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.reflection import resolve_class from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks from deerflow.tracing import build_tracing_callbacks
@@ -39,7 +39,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns: Returns:
A chat model instance. A chat model instance.
""" """
config = AppConfig.current() config = get_app_config()
if name is None: if name is None:
name = config.models[0].name name = config.models[0].name
model_config = config.get_model_config(name) model_config = config.get_model_config(name)
@@ -21,8 +21,6 @@ import inspect
import logging import logging
from typing import Any, Literal from typing import Any, Literal
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.serialization import serialize from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge from deerflow.runtime.stream_bridge import StreamBridge
@@ -100,14 +98,17 @@ async def run_agent(
# 3. Build the agent # 3. Build the agent
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Construct typed context for the agent run. # Inject runtime context so middlewares can access thread_id
# LangGraph's astream(context=...) injects this into Runtime.context # (langgraph-cli does this automatically; we must do it manually)
# so middleware/tools can access it via resolve_context(). runtime = Runtime(context={"thread_id": thread_id}, store=store)
deer_flow_context = DeerFlowContext( # If the caller already set a ``context`` key (LangGraph >= 0.6.0
app_config=AppConfig.current(), # prefers it over ``configurable`` for thread-level data), make
thread_id=thread_id, # sure ``thread_id`` is available there too.
) if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
runnable_config = RunnableConfig(**config) runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config) agent = agent_factory(config=runnable_config)
@@ -154,7 +155,7 @@ async def run_agent(
if len(lg_modes) == 1 and not stream_subgraphs: if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks # Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0] single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode): async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
if record.abort_event.is_set(): if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id) logger.info("Run %s abort requested — stopping", run_id)
break break
@@ -165,7 +166,6 @@ async def run_agent(
async for item in agent.astream( async for item in agent.astream(
graph_input, graph_input,
config=runnable_config, config=runnable_config,
context=deer_flow_context,
stream_mode=lg_modes, stream_mode=lg_modes,
subgraphs=stream_subgraphs, subgraphs=stream_subgraphs,
): ):
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore from langgraph.store.base import BaseStore
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,7 +100,7 @@ async def make_store() -> AsyncIterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case). ``checkpointer`` section is configured (emits a WARNING in that case).
""" """
config = AppConfig.current() config = get_app_config()
if config.checkpointer is None: if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore from langgraph.store.base import BaseStore
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -115,10 +115,19 @@ def get_store() -> BaseStore:
if _store is not None: if _store is not None:
return _store return _store
try: # Lazily load app config, mirroring the checkpointer singleton pattern so
config = AppConfig.current().checkpointer # that tests that set the global checkpointer config explicitly remain isolated.
except (LookupError, FileNotFoundError): from deerflow.config.app_config import _app_config
config = None from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
if config is None: if config is None:
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
@@ -167,7 +176,7 @@ def store_context() -> Iterator[BaseStore]:
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*. checkpointer is configured in *config.yaml*.
""" """
config = AppConfig.current() config = get_app_config()
if config.checkpointer is None: if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
@@ -17,7 +17,7 @@ import contextlib
import logging import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from deerflow.config.app_config import AppConfig from deerflow.config.stream_bridge_config import get_stream_bridge_config
from .base import StreamBridge from .base import StreamBridge
@@ -32,7 +32,7 @@ async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
provided and nothing is set globally. provided and nothing is set globally.
""" """
if config is None: if config is None:
config = AppConfig.current().stream_bridge config = get_stream_bridge_config()
if config is None or config.type == "memory": if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
@@ -11,6 +11,8 @@ _singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider): class LocalSandboxProvider(SandboxProvider):
uses_thread_data_mounts = True
def __init__(self): def __init__(self):
"""Initialize the local sandbox provider with path mappings.""" """Initialize the local sandbox provider with path mappings."""
self._path_mappings = self._setup_path_mappings() self._path_mappings = self._setup_path_mappings()
@@ -29,9 +31,9 @@ class LocalSandboxProvider(SandboxProvider):
# Map skills container path to local skills directory # Map skills container path to local skills directory
try: try:
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
config = AppConfig.current() config = get_app_config()
skills_path = config.skills.get_skills_path() skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path container_path = config.skills.container_path
@@ -6,7 +6,6 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.thread_state import SandboxState, ThreadDataState from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.sandbox import get_sandbox_provider from deerflow.sandbox import get_sandbox_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,15 +49,15 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
return sandbox_id return sandbox_id
@override @override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None: def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
# Skip acquisition if lazy_init is enabled # Skip acquisition if lazy_init is enabled
if self._lazy_init: if self._lazy_init:
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
# Eager initialization (original behavior) # Eager initialization (original behavior)
if "sandbox" not in state or state["sandbox"] is None: if "sandbox" not in state or state["sandbox"] is None:
thread_id = runtime.context.thread_id thread_id = (runtime.context or {}).get("thread_id")
if not thread_id: if thread_id is None:
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id) sandbox_id = self._acquire_sandbox(thread_id)
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}") logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
@@ -66,7 +65,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
@override @override
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None: def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
sandbox = state.get("sandbox") sandbox = state.get("sandbox")
if sandbox is not None: if sandbox is not None:
sandbox_id = sandbox["sandbox_id"] sandbox_id = sandbox["sandbox_id"]
@@ -74,5 +73,11 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
get_sandbox_provider().release(sandbox_id) get_sandbox_provider().release(sandbox_id)
return None return None
if (runtime.context or {}).get("sandbox_id") is not None:
sandbox_id = runtime.context.get("sandbox_id")
logger.info(f"Releasing sandbox {sandbox_id} from context")
get_sandbox_provider().release(sandbox_id)
return None
# No sandbox to release # No sandbox to release
return super().after_agent(state, runtime) return super().after_agent(state, runtime)
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.reflection import resolve_class from deerflow.reflection import resolve_class
from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox import Sandbox
@@ -8,6 +8,8 @@ from deerflow.sandbox.sandbox import Sandbox
class SandboxProvider(ABC): class SandboxProvider(ABC):
"""Abstract base class for sandbox providers""" """Abstract base class for sandbox providers"""
uses_thread_data_mounts: bool = False
@abstractmethod @abstractmethod
def acquire(self, thread_id: str | None = None) -> str: def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID. """Acquire a sandbox environment and return its ID.
@@ -50,7 +52,7 @@ def get_sandbox_provider(**kwargs) -> SandboxProvider:
""" """
global _default_sandbox_provider global _default_sandbox_provider
if _default_sandbox_provider is None: if _default_sandbox_provider is None:
config = AppConfig.current() config = get_app_config()
cls = resolve_class(config.sandbox.use, SandboxProvider) cls = resolve_class(config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls(**kwargs) _default_sandbox_provider = cls(**kwargs)
return _default_sandbox_provider return _default_sandbox_provider
@@ -1,6 +1,6 @@
"""Security helpers for sandbox capability gating.""" """Security helpers for sandbox capability gating."""
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
_LOCAL_SANDBOX_PROVIDER_MARKERS = ( _LOCAL_SANDBOX_PROVIDER_MARKERS = (
"deerflow.sandbox.local:LocalSandboxProvider", "deerflow.sandbox.local:LocalSandboxProvider",
@@ -23,7 +23,7 @@ LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE = (
def uses_local_sandbox_provider(config=None) -> bool: def uses_local_sandbox_provider(config=None) -> bool:
"""Return True when the active sandbox provider is the host-local provider.""" """Return True when the active sandbox provider is the host-local provider."""
if config is None: if config is None:
config = AppConfig.current() config = get_app_config()
sandbox_cfg = getattr(config, "sandbox", None) sandbox_cfg = getattr(config, "sandbox", None)
sandbox_use = getattr(sandbox_cfg, "use", "") sandbox_use = getattr(sandbox_cfg, "use", "")
@@ -35,7 +35,7 @@ def uses_local_sandbox_provider(config=None) -> bool:
def is_host_bash_allowed(config=None) -> bool: def is_host_bash_allowed(config=None) -> bool:
"""Return whether host bash execution is explicitly allowed.""" """Return whether host bash execution is explicitly allowed."""
if config is None: if config is None:
config = AppConfig.current() config = get_app_config()
sandbox_cfg = getattr(config, "sandbox", None) sandbox_cfg = getattr(config, "sandbox", None)
if sandbox_cfg is None: if sandbox_cfg is None:
@@ -7,7 +7,7 @@ from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import ( from deerflow.sandbox.exceptions import (
SandboxError, SandboxError,
@@ -50,7 +50,9 @@ def _get_skills_container_path() -> str:
if cached is not None: if cached is not None:
return cached return cached
try: try:
value = AppConfig.current().skills.container_path from deerflow.config import get_app_config
value = get_app_config().skills.container_path
_get_skills_container_path._cached = value # type: ignore[attr-defined] _get_skills_container_path._cached = value # type: ignore[attr-defined]
return value return value
except Exception: except Exception:
@@ -69,7 +71,9 @@ def _get_skills_host_path() -> str | None:
if cached is not None: if cached is not None:
return cached return cached
try: try:
config = AppConfig.current() from deerflow.config import get_app_config
config = get_app_config()
skills_path = config.skills.get_skills_path() skills_path = config.skills.get_skills_path()
if skills_path.exists(): if skills_path.exists():
value = str(skills_path) value = str(skills_path)
@@ -128,7 +132,9 @@ def _get_custom_mounts():
try: try:
from pathlib import Path from pathlib import Path
config = AppConfig.current() from deerflow.config import get_app_config
config = get_app_config()
mounts = [] mounts = []
if config.sandbox and config.sandbox.mounts: if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with # Only include mounts whose host_path exists, consistent with
@@ -268,7 +274,9 @@ def _get_mcp_allowed_paths() -> list[str]:
"""Get the list of allowed paths from MCP config for file system server.""" """Get the list of allowed paths from MCP config for file system server."""
allowed_paths = [] allowed_paths = []
try: try:
extensions_config = AppConfig.current().extensions from deerflow.config.extensions_config import get_extensions_config
extensions_config = get_extensions_config()
for _, server in extensions_config.mcp_servers.items(): for _, server in extensions_config.mcp_servers.items():
if not server.enabled: if not server.enabled:
@@ -293,7 +301,7 @@ def _get_mcp_allowed_paths() -> list[str]:
def _get_tool_config_int(name: str, key: str, default: int) -> int: def _get_tool_config_int(name: str, key: str, default: int) -> int:
try: try:
tool_config = AppConfig.current().get_tool_config(name) tool_config = get_app_config().get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra: if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key) value = tool_config.model_extra.get(key)
if isinstance(value, int): if isinstance(value, int):
@@ -801,6 +809,8 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
if sandbox is None: if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id) raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
return sandbox return sandbox
@@ -835,12 +845,16 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox_id is not None: if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id) sandbox = get_sandbox_provider().get(sandbox_id)
if sandbox is not None: if sandbox is not None:
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox return sandbox
# Sandbox was released, fall through to acquire new one # Sandbox was released, fall through to acquire new one
# Lazy acquisition: get thread_id and acquire sandbox # Lazy acquisition: get thread_id and acquire sandbox
thread_id = runtime.context.thread_id thread_id = runtime.context.get("thread_id") if runtime.context else None
if not thread_id: if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None
if thread_id is None:
raise SandboxRuntimeError("Thread ID not available in runtime context") raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider() provider = get_sandbox_provider()
@@ -854,6 +868,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox is None: if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id) raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox return sandbox
@@ -995,14 +1011,18 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
command = _apply_cwd_prefix(command, thread_data) command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command) output = sandbox.execute_command(command)
try: try:
sandbox_cfg = AppConfig.current().sandbox from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000 max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception: except Exception:
max_chars = 20000 max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars) return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
try: try:
sandbox_cfg = AppConfig.current().sandbox from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000 max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception: except Exception:
max_chars = 20000 max_chars = 20000
@@ -1027,6 +1047,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
requested_path = path requested_path = path
thread_data = None
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True) validate_local_tool_path(path, thread_data, read_only=True)
@@ -1041,8 +1062,12 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
if not children: if not children:
return "(empty)" return "(empty)"
output = "\n".join(children) output = "\n".join(children)
if thread_data is not None:
output = mask_local_paths_in_output(output, thread_data)
try: try:
sandbox_cfg = AppConfig.current().sandbox from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000 max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception: except Exception:
max_chars = 20000 max_chars = 20000
@@ -1213,7 +1238,9 @@ def read_file_tool(
if start_line is not None and end_line is not None: if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line]) content = "\n".join(content.splitlines()[start_line - 1 : end_line])
try: try:
sandbox_cfg = AppConfig.current().sandbox from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000 max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception: except Exception:
max_chars = 50000 max_chars = 50000
@@ -42,9 +42,9 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
if skills_path is None: if skills_path is None:
if use_config: if use_config:
try: try:
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
config = AppConfig.current() config = get_app_config()
skills_path = config.skills.get_skills_path() skills_path = config.skills.get_skills_path()
except Exception: except Exception:
# Fallback to default if config fails # Fallback to default if config fails
@@ -9,7 +9,7 @@ from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.skills.loader import load_skills from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter from deerflow.skills.validation import _validate_skill_frontmatter
@@ -21,7 +21,7 @@ _SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path: def get_skills_root_dir() -> Path:
return AppConfig.current().skills.get_skills_path() return get_app_config().skills.get_skills_path()
def get_public_skills_dir() -> Path: def get_public_skills_dir() -> Path:
@@ -7,7 +7,7 @@ import logging
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----" prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try: try:
config = AppConfig.current() config = get_app_config()
model_name = config.skill_evolution.moderation_model_name model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False) model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
response = await model.ainvoke( response = await model.ainvoke(
@@ -23,10 +23,11 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
if config is None: if config is None:
return None return None
# Apply timeout override from config.yaml (lazy import to avoid circular deps) # Apply runtime overrides (timeout, max_turns, model) from config.yaml
from deerflow.config.app_config import AppConfig # Lazy import to avoid circular deps.
from deerflow.config.subagents_config import get_subagents_app_config
app_config = AppConfig.current().subagents app_config = get_subagents_app_config()
effective_timeout = app_config.get_timeout_for(name) effective_timeout = app_config.get_timeout_for(name)
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns) effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
@@ -47,6 +48,15 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
effective_max_turns, effective_max_turns,
) )
overrides["max_turns"] = effective_max_turns overrides["max_turns"] = effective_max_turns
effective_model = app_config.get_model_for(name)
if effective_model is not None and effective_model != config.model:
logger.debug(
"Subagent '%s': model overridden by config.yaml (%s -> %s)",
name,
config.model,
effective_model,
)
overrides["model"] = effective_model
if overrides: if overrides:
config = replace(config, **overrides) config = replace(config, **overrides)
@@ -3,6 +3,7 @@ from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langgraph.config import get_config
from langgraph.types import Command from langgraph.types import Command
from langgraph.typing import ContextT from langgraph.typing import ContextT
@@ -12,6 +13,23 @@ from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs" OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None:
"""Resolve the current thread id from runtime context or RunnableConfig."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
runtime_config = getattr(runtime, "config", None) or {}
thread_id = runtime_config.get("configurable", {}).get("thread_id")
if thread_id:
return thread_id
try:
return get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
return None
def _normalize_presented_filepath( def _normalize_presented_filepath(
runtime: ToolRuntime[ContextT, ThreadState], runtime: ToolRuntime[ContextT, ThreadState],
filepath: str, filepath: str,
@@ -33,9 +51,9 @@ def _normalize_presented_filepath(
if runtime.state is None: if runtime.state is None:
raise ValueError("Thread runtime state is not available") raise ValueError("Thread runtime state is not available")
thread_id = runtime.context.thread_id thread_id = _get_thread_id(runtime)
if not thread_id: if not thread_id:
raise ValueError("Thread ID is not available in runtime context") raise ValueError("Thread ID is not available in runtime context or runtime config")
thread_data = runtime.state.get("thread_data") or {} thread_data = runtime.state.get("thread_data") or {}
outputs_path = thread_data.get("outputs_path") outputs_path = thread_data.get("outputs_path")
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from langgraph.prebuilt import ToolRuntime from langgraph.prebuilt import ToolRuntime
from langgraph.types import Command from langgraph.types import Command
from deerflow.config.agents_config import validate_agent_name
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -24,9 +25,11 @@ def setup_agent(
description: One-line description of what the agent does. description: One-line description of what the agent does.
""" """
agent_name: str | None = runtime.context.agent_name agent_name: str | None = runtime.context.get("agent_name") if runtime.context else None
agent_dir = None
try: try:
agent_name = validate_agent_name(agent_name)
paths = get_paths() paths = get_paths()
agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
agent_dir.mkdir(parents=True, exist_ok=True) agent_dir.mkdir(parents=True, exist_ok=True)
@@ -55,7 +58,7 @@ def setup_agent(
except Exception as e: except Exception as e:
import shutil import shutil
if agent_name and agent_dir.exists(): if agent_name and agent_dir is not None and agent_dir.exists():
# Cleanup the custom agent directory only if it was created but an error occurred during setup # Cleanup the custom agent directory only if it was created but an error occurred during setup
shutil.rmtree(agent_dir) shutil.rmtree(agent_dir)
logger.error(f"[agent_creator] Failed to create agent '{agent_name}': {e}", exc_info=True) logger.error(f"[agent_creator] Failed to create agent '{agent_name}': {e}", exc_info=True)
@@ -88,11 +88,14 @@ async def task_tool(
thread_id = None thread_id = None
parent_model = None parent_model = None
trace_id = None trace_id = None
metadata: dict = {}
if runtime is not None: if runtime is not None:
sandbox_state = runtime.state.get("sandbox") sandbox_state = runtime.state.get("sandbox")
thread_data = runtime.state.get("thread_data") thread_data = runtime.state.get("thread_data")
thread_id = runtime.context.thread_id thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id")
# Try to get parent model from configurable # Try to get parent model from configurable
metadata = runtime.config.get("metadata", {}) metadata = runtime.config.get("metadata", {})
@@ -105,8 +108,11 @@ async def task_tool(
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
# Inherit parent agent's tool_groups so subagents respect the same restrictions
parent_tool_groups = metadata.get("tool_groups")
# Subagents should not have subagent tools enabled (prevent recursive nesting) # Subagents should not have subagent tools enabled (prevent recursive nesting)
tools = get_available_tools(model_name=parent_model, subagent_enabled=False) tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False)
# Create executor # Create executor
executor = SubagentExecutor( executor = SubagentExecutor(
@@ -45,7 +45,9 @@ def _get_lock(name: str) -> asyncio.Lock:
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None: def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
if runtime is None: if runtime is None:
return None return None
return runtime.context.thread_id or None if runtime.context and runtime.context.get("thread_id"):
return runtime.context.get("thread_id")
return runtime.config.get("configurable", {}).get("thread_id")
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]: def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:
@@ -2,7 +2,7 @@ import logging
from langchain.tools import BaseTool from langchain.tools import BaseTool
from deerflow.config.app_config import AppConfig from deerflow.config import get_app_config
from deerflow.reflection import resolve_variable from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
@@ -52,7 +52,7 @@ def get_available_tools(
Returns: Returns:
List of available tools. List of available tools.
""" """
config = AppConfig.current() config = get_app_config()
tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups] tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]
# Do not expose host bash by default when LocalSandboxProvider is active. # Do not expose host bash by default when LocalSandboxProvider is active.
@@ -123,9 +123,10 @@ def get_available_tools(
# Add invoke_acp_agent tool if any ACP agents are configured # Add invoke_acp_agent tool if any ACP agents are configured
acp_tools: list[BaseTool] = [] acp_tools: list[BaseTool] = []
try: try:
from deerflow.config.acp_config import get_acp_agents
from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool
acp_agents = AppConfig.current().acp_agents acp_agents = get_acp_agents()
if acp_agents: if acp_agents:
acp_tools.append(build_invoke_acp_agent_tool(acp_agents)) acp_tools.append(build_invoke_acp_agent_tool(acp_agents))
logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})") logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})")
@@ -19,6 +19,8 @@ import logging
import re import re
from pathlib import Path from pathlib import Path
from deerflow.config.app_config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# File extensions that should be converted to markdown # File extensions that should be converted to markdown
@@ -286,6 +288,15 @@ def extract_outline(md_path: Path) -> list[dict]:
return outline return outline
def _get_uploads_config_value(key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _get_pdf_converter() -> str: def _get_pdf_converter() -> str:
"""Read pdf_converter setting from app config, defaulting to 'auto'. """Read pdf_converter setting from app config, defaulting to 'auto'.
@@ -294,16 +305,11 @@ def _get_pdf_converter() -> str:
fall through to unexpected behaviour. fall through to unexpected behaviour.
""" """
try: try:
from deerflow.config.app_config import AppConfig raw = str(_get_uploads_config_value("pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
cfg = AppConfig.current() logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
uploads_cfg = getattr(cfg, "uploads", None) return "auto"
if uploads_cfg is not None: return raw
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
return "auto"
return raw
except Exception: except Exception:
pass pass
return "auto" return "auto"
+2 -2
View File
@@ -8,7 +8,7 @@ dependencies = [
"deerflow-harness", "deerflow-harness",
"fastapi>=0.115.0", "fastapi>=0.115.0",
"httpx>=0.28.0", "httpx>=0.28.0",
"python-multipart>=0.0.20", "python-multipart>=0.0.26",
"sse-starlette>=2.1.0", "sse-starlette>=2.1.0",
"uvicorn[standard]>=0.34.0", "uvicorn[standard]>=0.34.0",
"lark-oapi>=1.4.0", "lark-oapi>=1.4.0",
@@ -20,7 +20,7 @@ dependencies = [
] ]
[dependency-groups] [dependency-groups]
dev = ["pytest>=8.0.0", "ruff>=0.14.11"] dev = ["pytest>=9.0.3", "ruff>=0.14.11"]
[tool.uv.workspace] [tool.uv.workspace]
members = ["packages/harness"] members = ["packages/harness"]
+37 -27
View File
@@ -6,20 +6,17 @@ import pytest
import yaml import yaml
from pydantic import ValidationError from pydantic import ValidationError
from deerflow.config.acp_config import ACPAgentConfig from deerflow.config.acp_config import ACPAgentConfig, get_acp_agents, load_acp_config_from_dict
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def _make_config(acp_agents: dict | None = None) -> AppConfig: def setup_function():
return AppConfig( """Reset ACP config before each test."""
sandbox=SandboxConfig(use="test"), load_acp_config_from_dict({})
acp_agents={name: ACPAgentConfig(**cfg) for name, cfg in (acp_agents or {}).items()},
)
def test_acp_agents_via_app_config(): def test_load_acp_config_sets_agents():
cfg = _make_config( load_acp_config_from_dict(
{ {
"claude_code": { "claude_code": {
"command": "claude-code-acp", "command": "claude-code-acp",
@@ -29,33 +26,39 @@ def test_acp_agents_via_app_config():
} }
} }
) )
agents = cfg.acp_agents agents = get_acp_agents()
assert "claude_code" in agents assert "claude_code" in agents
assert agents["claude_code"].command == "claude-code-acp" assert agents["claude_code"].command == "claude-code-acp"
assert agents["claude_code"].description == "Claude Code for coding tasks" assert agents["claude_code"].description == "Claude Code for coding tasks"
assert agents["claude_code"].model is None assert agents["claude_code"].model is None
def test_multiple_agents(): def test_load_acp_config_multiple_agents():
cfg = _make_config( load_acp_config_from_dict(
{ {
"claude_code": {"command": "claude-code-acp", "args": [], "description": "Claude Code"}, "claude_code": {"command": "claude-code-acp", "args": [], "description": "Claude Code"},
"codex": {"command": "codex-acp", "args": ["--flag"], "description": "Codex CLI"}, "codex": {"command": "codex-acp", "args": ["--flag"], "description": "Codex CLI"},
} }
) )
agents = cfg.acp_agents agents = get_acp_agents()
assert len(agents) == 2 assert len(agents) == 2
assert agents["codex"].args == ["--flag"] assert agents["codex"].args == ["--flag"]
def test_empty_acp_agents(): def test_load_acp_config_empty_clears_agents():
cfg = _make_config({}) load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert cfg.acp_agents == {} assert len(get_acp_agents()) == 1
load_acp_config_from_dict({})
assert len(get_acp_agents()) == 0
def test_default_acp_agents_empty(): def test_load_acp_config_none_clears_agents():
cfg = AppConfig(sandbox=SandboxConfig(use="test")) load_acp_config_from_dict({"agent": {"command": "cmd", "args": [], "description": "desc"}})
assert cfg.acp_agents == {} assert len(get_acp_agents()) == 1
load_acp_config_from_dict(None)
assert get_acp_agents() == {}
def test_acp_agent_config_defaults(): def test_acp_agent_config_defaults():
@@ -76,8 +79,8 @@ def test_acp_agent_config_env_default_is_empty():
assert cfg.env == {} assert cfg.env == {}
def test_acp_agent_preserves_env(): def test_load_acp_config_preserves_env():
cfg = _make_config( load_acp_config_from_dict(
{ {
"codex": { "codex": {
"command": "codex-acp", "command": "codex-acp",
@@ -87,7 +90,8 @@ def test_acp_agent_preserves_env():
} }
} }
) )
assert cfg.acp_agents["codex"].env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"} cfg = get_acp_agents()["codex"]
assert cfg.env == {"OPENAI_API_KEY": "$OPENAI_API_KEY", "FOO": "bar"}
def test_acp_agent_config_with_model(): def test_acp_agent_config_with_model():
@@ -111,7 +115,13 @@ def test_acp_agent_config_missing_description_raises():
ACPAgentConfig(command="my-agent") ACPAgentConfig(command="my-agent")
def test_app_config_from_file_with_acp_agents(tmp_path, monkeypatch): def test_get_acp_agents_returns_empty_by_default():
"""After clearing, should return empty dict."""
load_acp_config_from_dict({})
assert get_acp_agents() == {}
def test_app_config_reload_without_acp_agents_clears_previous_state(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml" config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json" extensions_path = tmp_path / "extensions_config.json"
extensions_path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8") extensions_path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@@ -147,9 +157,9 @@ def test_app_config_from_file_with_acp_agents(tmp_path, monkeypatch):
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path)) monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
config_path.write_text(yaml.safe_dump(config_with_acp), encoding="utf-8") config_path.write_text(yaml.safe_dump(config_with_acp), encoding="utf-8")
app = AppConfig.from_file(str(config_path)) AppConfig.from_file(str(config_path))
assert set(app.acp_agents) == {"codex"} assert set(get_acp_agents()) == {"codex"}
config_path.write_text(yaml.safe_dump(config_without_acp), encoding="utf-8") config_path.write_text(yaml.safe_dump(config_without_acp), encoding="utf-8")
app = AppConfig.from_file(str(config_path)) AppConfig.from_file(str(config_path))
assert app.acp_agents == {} assert get_acp_agents() == {}
+83 -33
View File
@@ -1,11 +1,13 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
from pathlib import Path from pathlib import Path
import yaml import yaml
from deerflow.config.app_config import AppConfig from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.app_config import get_app_config, reset_app_config
def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None: def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> None:
@@ -27,65 +29,113 @@ def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> No
) )
def _write_config_with_agents_api(
path: Path,
*,
model_name: str,
supports_thinking: bool,
agents_api: dict | None = None,
) -> None:
config = {
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [
{
"name": model_name,
"use": "langchain_openai:ChatOpenAI",
"model": "gpt-test",
"supports_thinking": supports_thinking,
}
],
}
if agents_api is not None:
config["agents_api"] = agents_api
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_extensions_config(path: Path) -> None: def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8") path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
def test_init_then_get(tmp_path, monkeypatch): def test_get_app_config_reloads_when_file_changes(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml" config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json" extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path) _write_extensions_config(extensions_path)
_write_config(config_path, model_name="test-model", supports_thinking=False) _write_config(config_path, model_name="first-model", supports_thinking=False)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path)) monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path)) monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
config = AppConfig.from_file(str(config_path)) try:
AppConfig.init(config) initial = get_app_config()
assert initial.models[0].supports_thinking is False
result = AppConfig.current() _write_config(config_path, model_name="first-model", supports_thinking=True)
assert result is config next_mtime = config_path.stat().st_mtime + 5
assert result.models[0].name == "test-model" os.utime(config_path, (next_mtime, next_mtime))
reloaded = get_app_config()
assert reloaded.models[0].supports_thinking is True
assert reloaded is not initial
finally:
reset_app_config()
def test_init_replaces_previous(tmp_path, monkeypatch): def test_get_app_config_reloads_when_config_path_changes(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml" config_a = tmp_path / "config-a.yaml"
config_b = tmp_path / "config-b.yaml"
extensions_path = tmp_path / "extensions_config.json" extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path) _write_extensions_config(extensions_path)
_write_config(config_path, model_name="model-a", supports_thinking=False) _write_config(config_a, model_name="model-a", supports_thinking=False)
_write_config(config_b, model_name="model-b", supports_thinking=True)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path)) monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_a))
reset_app_config()
config_a = AppConfig.from_file(str(config_path)) try:
AppConfig.init(config_a) first = get_app_config()
assert AppConfig.current().models[0].name == "model-a" assert first.models[0].name == "model-a"
_write_config(config_path, model_name="model-b", supports_thinking=True) monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_b))
config_b = AppConfig.from_file(str(config_path)) second = get_app_config()
AppConfig.init(config_b) assert second.models[0].name == "model-b"
assert AppConfig.current().models[0].name == "model-b" assert second is not first
assert AppConfig.current() is config_b finally:
reset_app_config()
def test_config_version_check(tmp_path, monkeypatch): def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml" config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json" extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path) _write_extensions_config(extensions_path)
_write_config_with_agents_api(
config_path.write_text( config_path,
yaml.safe_dump( model_name="first-model",
{ supports_thinking=False,
"config_version": 1, agents_api={"enabled": True},
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [],
}
),
encoding="utf-8",
) )
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path)) monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path)) monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
config = AppConfig.from_file(str(config_path)) try:
assert config is not None initial = get_app_config()
assert initial.models[0].name == "first-model"
assert get_agents_api_config().enabled is True
_write_config_with_agents_api(
config_path,
model_name="first-model",
supports_thinking=False,
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
reloaded = get_app_config()
assert reloaded is not initial
assert get_agents_api_config().enabled is False
finally:
reset_app_config()
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
import importlib.util
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[2]
CHECK_SCRIPT_PATH = REPO_ROOT / "scripts" / "check.py"
spec = importlib.util.spec_from_file_location("deerflow_check_script", CHECK_SCRIPT_PATH)
assert spec is not None
assert spec.loader is not None
check_script = importlib.util.module_from_spec(spec)
spec.loader.exec_module(check_script)
def test_find_pnpm_command_prefers_resolved_executable(monkeypatch):
def fake_which(name: str) -> str | None:
if name == "pnpm":
return r"C:\Users\tester\AppData\Roaming\npm\pnpm.CMD"
if name == "pnpm.cmd":
return r"C:\Users\tester\AppData\Roaming\npm\pnpm.cmd"
return None
monkeypatch.setattr(check_script.shutil, "which", fake_which)
assert check_script.find_pnpm_command() == [r"C:\Users\tester\AppData\Roaming\npm\pnpm.CMD"]
def test_find_pnpm_command_falls_back_to_corepack(monkeypatch):
def fake_which(name: str) -> str | None:
if name == "corepack":
return r"C:\Program Files\nodejs\corepack.exe"
return None
monkeypatch.setattr(check_script.shutil, "which", fake_which)
assert check_script.find_pnpm_command() == [
r"C:\Program Files\nodejs\corepack.exe",
"pnpm",
]
def test_find_pnpm_command_falls_back_to_corepack_cmd(monkeypatch):
def fake_which(name: str) -> str | None:
if name == "corepack":
return None
if name == "corepack.cmd":
return r"C:\Program Files\nodejs\corepack.cmd"
return None
monkeypatch.setattr(check_script.shutil, "which", fake_which)
assert check_script.find_pnpm_command() == [
r"C:\Program Files\nodejs\corepack.cmd",
"pnpm",
]
+193 -67
View File
@@ -5,21 +5,25 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import deerflow.config.app_config as app_config_module
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.config.app_config import AppConfig from deerflow.config.checkpointer_config import (
from deerflow.config.checkpointer_config import CheckpointerConfig CheckpointerConfig,
from deerflow.config.sandbox_config import SandboxConfig get_checkpointer_config,
load_checkpointer_config_from_dict,
set_checkpointer_config,
def _make_config(checkpointer: CheckpointerConfig | None = None) -> AppConfig: )
return AppConfig(sandbox=SandboxConfig(use="test"), checkpointer=checkpointer)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_state(): def reset_state():
"""Reset singleton state before each test.""" """Reset singleton state before each test."""
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer() reset_checkpointer()
yield yield
app_config_module._app_config = None
set_checkpointer_config(None)
reset_checkpointer() reset_checkpointer()
@@ -29,18 +33,24 @@ def reset_state():
class TestCheckpointerConfig: class TestCheckpointerConfig:
def test_memory_config(self): def test_load_memory_config(self):
config = CheckpointerConfig(type="memory") load_checkpointer_config_from_dict({"type": "memory"})
config = get_checkpointer_config()
assert config is not None
assert config.type == "memory" assert config.type == "memory"
assert config.connection_string is None assert config.connection_string is None
def test_sqlite_config(self): def test_load_sqlite_config(self):
config = CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db") load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
config = get_checkpointer_config()
assert config is not None
assert config.type == "sqlite" assert config.type == "sqlite"
assert config.connection_string == "/tmp/test.db" assert config.connection_string == "/tmp/test.db"
def test_postgres_config(self): def test_load_postgres_config(self):
config = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db") load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
config = get_checkpointer_config()
assert config is not None
assert config.type == "postgres" assert config.type == "postgres"
assert config.connection_string == "postgresql://localhost/db" assert config.connection_string == "postgresql://localhost/db"
@@ -48,9 +58,14 @@ class TestCheckpointerConfig:
config = CheckpointerConfig(type="memory") config = CheckpointerConfig(type="memory")
assert config.connection_string is None assert config.connection_string is None
def test_set_config_to_none(self):
load_checkpointer_config_from_dict({"type": "memory"})
set_checkpointer_config(None)
assert get_checkpointer_config() is None
def test_invalid_type_raises(self): def test_invalid_type_raises(self):
with pytest.raises(Exception): with pytest.raises(Exception):
CheckpointerConfig(type="unknown") load_checkpointer_config_from_dict({"type": "unknown"})
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -63,78 +78,88 @@ class TestGetCheckpointer:
"""get_checkpointer should return InMemorySaver when not configured.""" """get_checkpointer should return InMemorySaver when not configured."""
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
with patch.object(AppConfig, "current", return_value=_make_config()): with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
cp = get_checkpointer()
assert cp is not None
assert isinstance(cp, InMemorySaver)
def test_returns_in_memory_saver_when_config_not_found(self):
from langgraph.checkpoint.memory import InMemorySaver
with patch.object(AppConfig, "current", side_effect=FileNotFoundError):
cp = get_checkpointer() cp = get_checkpointer()
assert cp is not None assert cp is not None
assert isinstance(cp, InMemorySaver) assert isinstance(cp, InMemorySaver)
def test_memory_returns_in_memory_saver(self): def test_memory_returns_in_memory_saver(self):
load_checkpointer_config_from_dict({"type": "memory"})
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
cfg = _make_config(CheckpointerConfig(type="memory")) cp = get_checkpointer()
with patch.object(AppConfig, "current", return_value=cfg):
cp = get_checkpointer()
assert isinstance(cp, InMemorySaver) assert isinstance(cp, InMemorySaver)
def test_memory_singleton(self): def test_memory_singleton(self):
cfg = _make_config(CheckpointerConfig(type="memory")) load_checkpointer_config_from_dict({"type": "memory"})
with patch.object(AppConfig, "current", return_value=cfg): cp1 = get_checkpointer()
cp1 = get_checkpointer() cp2 = get_checkpointer()
cp2 = get_checkpointer()
assert cp1 is cp2 assert cp1 is cp2
def test_reset_clears_singleton(self): def test_reset_clears_singleton(self):
cfg = _make_config(CheckpointerConfig(type="memory")) load_checkpointer_config_from_dict({"type": "memory"})
with patch.object(AppConfig, "current", return_value=cfg): cp1 = get_checkpointer()
cp1 = get_checkpointer() reset_checkpointer()
reset_checkpointer() cp2 = get_checkpointer()
cp2 = get_checkpointer()
assert cp1 is not cp2 assert cp1 is not cp2
def test_sqlite_raises_when_package_missing(self): def test_sqlite_raises_when_package_missing(self):
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db")) load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
with ( with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}):
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": None}),
):
reset_checkpointer() reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"): with pytest.raises(ImportError, match="langgraph-checkpoint-sqlite"):
get_checkpointer() get_checkpointer()
def test_postgres_raises_when_package_missing(self): def test_postgres_raises_when_package_missing(self):
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")) load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
with ( with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}):
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": None}),
):
reset_checkpointer() reset_checkpointer()
with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"): with pytest.raises(ImportError, match="langgraph-checkpoint-postgres"):
get_checkpointer() get_checkpointer()
def test_postgres_raises_when_connection_string_missing(self): def test_postgres_raises_when_connection_string_missing(self):
cfg = _make_config(CheckpointerConfig(type="postgres")) load_checkpointer_config_from_dict({"type": "postgres"})
mock_saver = MagicMock() mock_saver = MagicMock()
mock_module = MagicMock() mock_module = MagicMock()
mock_module.PostgresSaver = mock_saver mock_module.PostgresSaver = mock_saver
with ( with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}):
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_module}),
):
reset_checkpointer() reset_checkpointer()
with pytest.raises(ValueError, match="connection_string is required"): with pytest.raises(ValueError, match="connection_string is required"):
get_checkpointer() get_checkpointer()
def test_sqlite_creates_saver(self): def test_sqlite_creates_saver(self):
"""SQLite checkpointer is created when package is available.""" """SQLite checkpointer is created when package is available."""
cfg = _make_config(CheckpointerConfig(type="sqlite", connection_string="/tmp/test.db")) load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "/tmp/test.db"})
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}):
reset_checkpointer()
cp = get_checkpointer()
assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once()
mock_saver_instance.setup.assert_called_once()
def test_sqlite_creates_parent_dir(self):
"""Sync SQLite checkpointer should call ensure_sqlite_parent_dir before connecting.
This mirrors the async checkpointer's behaviour and prevents
'sqlite3.OperationalError: unable to open database file' when the
parent directory for the database file does not yet exist (e.g. when
using the harness package from an external virtualenv where the
.deer-flow directory has not been created).
"""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "relative/test.db"})
mock_saver_instance = MagicMock() mock_saver_instance = MagicMock()
mock_cm = MagicMock() mock_cm = MagicMock()
@@ -148,19 +173,59 @@ class TestGetCheckpointer:
mock_module.SqliteSaver = mock_saver_cls mock_module.SqliteSaver = mock_saver_cls
with ( with (
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}), patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch("deerflow.agents.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
patch(
"deerflow.agents.checkpointer.provider.resolve_sqlite_conn_str",
return_value="/tmp/resolved/relative/test.db",
),
): ):
reset_checkpointer() reset_checkpointer()
cp = get_checkpointer() cp = get_checkpointer()
assert cp is mock_saver_instance assert cp is mock_saver_instance
mock_saver_cls.from_conn_string.assert_called_once() mock_ensure.assert_called_once_with("/tmp/resolved/relative/test.db")
mock_saver_instance.setup.assert_called_once() mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/relative/test.db")
def test_sqlite_ensure_parent_dir_before_connect(self):
"""ensure_sqlite_parent_dir must be called before from_conn_string."""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "relative/test.db"})
call_order = []
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(side_effect=lambda *a, **kw: (call_order.append("connect"), mock_cm)[1])
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
def record_ensure(*a, **kw):
call_order.append("ensure")
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch(
"deerflow.agents.checkpointer.provider.ensure_sqlite_parent_dir",
side_effect=record_ensure,
),
patch(
"deerflow.agents.checkpointer.provider.resolve_sqlite_conn_str",
return_value="/tmp/resolved/relative/test.db",
),
):
reset_checkpointer()
get_checkpointer()
assert call_order == ["ensure", "connect"]
def test_postgres_creates_saver(self): def test_postgres_creates_saver(self):
"""Postgres checkpointer is created when packages are available.""" """Postgres checkpointer is created when packages are available."""
cfg = _make_config(CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")) load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
mock_saver_instance = MagicMock() mock_saver_instance = MagicMock()
mock_cm = MagicMock() mock_cm = MagicMock()
@@ -173,10 +238,7 @@ class TestGetCheckpointer:
mock_pg_module = MagicMock() mock_pg_module = MagicMock()
mock_pg_module.PostgresSaver = mock_saver_cls mock_pg_module.PostgresSaver = mock_saver_cls
with ( with patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}):
patch.object(AppConfig, "current", return_value=cfg),
patch.dict(sys.modules, {"langgraph.checkpoint.postgres": mock_pg_module}),
):
reset_checkpointer() reset_checkpointer()
cp = get_checkpointer() cp = get_checkpointer()
@@ -206,7 +268,7 @@ class TestAsyncCheckpointer:
mock_module.AsyncSqliteSaver = mock_saver_cls mock_module.AsyncSqliteSaver = mock_saver_cls
with ( with (
patch.object(AppConfig, "current", return_value=mock_config), patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config),
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}), patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread, patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
patch( patch(
@@ -232,10 +294,12 @@ class TestAsyncCheckpointer:
class TestAppConfigLoadsCheckpointer: class TestAppConfigLoadsCheckpointer:
def test_load_checkpointer_section(self): def test_load_checkpointer_section(self):
"""AppConfig with checkpointer section has the correct config.""" """load_checkpointer_config_from_dict populates the global config."""
cfg = _make_config(CheckpointerConfig(type="memory")) set_checkpointer_config(None)
assert cfg.checkpointer is not None load_checkpointer_config_from_dict({"type": "memory"})
assert cfg.checkpointer.type == "memory" cfg = get_checkpointer_config()
assert cfg is not None
assert cfg.type == "memory"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -246,6 +310,68 @@ class TestAppConfigLoadsCheckpointer:
class TestClientCheckpointerFallback: class TestClientCheckpointerFallback:
def test_client_uses_config_checkpointer_when_none_provided(self): def test_client_uses_config_checkpointer_when_none_provided(self):
"""DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None.""" """DeerFlowClient._ensure_agent falls back to get_checkpointer() when checkpointer=None."""
# This is a structural test — verifying the fallback path exists. from langgraph.checkpoint.memory import InMemorySaver
cfg = _make_config(CheckpointerConfig(type="memory"))
assert cfg.checkpointer is not None from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=None)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert "checkpointer" in captured_kwargs
assert isinstance(captured_kwargs["checkpointer"], InMemorySaver)
def test_client_explicit_checkpointer_takes_precedence(self):
"""An explicitly provided checkpointer is used even when config checkpointer is set."""
from deerflow.client import DeerFlowClient
load_checkpointer_config_from_dict({"type": "memory"})
explicit_cp = MagicMock()
captured_kwargs = {}
def fake_create_agent(**kwargs):
captured_kwargs.update(kwargs)
return MagicMock()
model_mock = MagicMock()
config_mock = MagicMock()
config_mock.models = [model_mock]
config_mock.get_model_config.return_value = MagicMock(supports_vision=False)
config_mock.checkpointer = None
with (
patch("deerflow.client.get_app_config", return_value=config_mock),
patch("deerflow.client.create_agent", side_effect=fake_create_agent),
patch("deerflow.client.create_chat_model", return_value=MagicMock()),
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value=""),
patch("deerflow.client.DeerFlowClient._get_tools", return_value=[]),
):
client = DeerFlowClient(checkpointer=explicit_cp)
config = client._get_runnable_config("test-thread")
client._ensure_agent(config)
assert captured_kwargs["checkpointer"] is explicit_cp
+4 -6
View File
@@ -5,8 +5,6 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from deerflow.config.app_config import AppConfig
class TestCheckpointerNoneFix: class TestCheckpointerNoneFix:
"""Tests that checkpointer context managers return InMemorySaver instead of None.""" """Tests that checkpointer context managers return InMemorySaver instead of None."""
@@ -16,11 +14,11 @@ class TestCheckpointerNoneFix:
"""make_checkpointer should return InMemorySaver when config.checkpointer is None.""" """make_checkpointer should return InMemorySaver when config.checkpointer is None."""
from deerflow.agents.checkpointer.async_provider import make_checkpointer from deerflow.agents.checkpointer.async_provider import make_checkpointer
# Mock AppConfig.get to return a config with checkpointer=None # Mock get_app_config to return a config with checkpointer=None
mock_config = MagicMock() mock_config = MagicMock()
mock_config.checkpointer = None mock_config.checkpointer = None
with patch.object(AppConfig, "current", return_value=mock_config): with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
async with make_checkpointer() as checkpointer: async with make_checkpointer() as checkpointer:
# Should return InMemorySaver, not None # Should return InMemorySaver, not None
assert checkpointer is not None assert checkpointer is not None
@@ -39,11 +37,11 @@ class TestCheckpointerNoneFix:
"""checkpointer_context should return InMemorySaver when config.checkpointer is None.""" """checkpointer_context should return InMemorySaver when config.checkpointer is None."""
from deerflow.agents.checkpointer.provider import checkpointer_context from deerflow.agents.checkpointer.provider import checkpointer_context
# Mock AppConfig.get to return a config with checkpointer=None # Mock get_app_config to return a config with checkpointer=None
mock_config = MagicMock() mock_config = MagicMock()
mock_config.checkpointer = None mock_config.checkpointer = None
with patch.object(AppConfig, "current", return_value=mock_config): with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config):
with checkpointer_context() as checkpointer: with checkpointer_context() as checkpointer:
# Should return InMemorySaver, not None # Should return InMemorySaver, not None
assert checkpointer is not None assert checkpointer is not None
+57 -69
View File
@@ -18,7 +18,6 @@ from app.gateway.routers.models import ModelResponse, ModelsListResponse
from app.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse from app.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
from app.gateway.routers.uploads import UploadResponse from app.gateway.routers.uploads import UploadResponse
from deerflow.client import DeerFlowClient from deerflow.client import DeerFlowClient
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import Paths from deerflow.config.paths import Paths
from deerflow.uploads.manager import PathTraversalError from deerflow.uploads.manager import PathTraversalError
@@ -39,13 +38,14 @@ def mock_app_config():
config = MagicMock() config = MagicMock()
config.models = [model] config.models = [model]
config.token_usage.enabled = False
return config return config
@pytest.fixture @pytest.fixture
def client(mock_app_config): def client(mock_app_config):
"""Create a DeerFlowClient with mocked config loading.""" """Create a DeerFlowClient with mocked config loading."""
with patch.object(AppConfig, "current", return_value=mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config):
return DeerFlowClient() return DeerFlowClient()
@@ -67,7 +67,7 @@ class TestClientInit:
def test_custom_params(self, mock_app_config): def test_custom_params(self, mock_app_config):
mock_middleware = MagicMock() mock_middleware = MagicMock()
with patch.object(AppConfig, "current", return_value=mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware]) c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
assert c._model_name == "gpt-4" assert c._model_name == "gpt-4"
assert c._thinking_enabled is False assert c._thinking_enabled is False
@@ -78,7 +78,7 @@ class TestClientInit:
assert c._middlewares == [mock_middleware] assert c._middlewares == [mock_middleware]
def test_invalid_agent_name(self, mock_app_config): def test_invalid_agent_name(self, mock_app_config):
with patch.object(AppConfig, "current", return_value=mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config):
with pytest.raises(ValueError, match="Invalid agent name"): with pytest.raises(ValueError, match="Invalid agent name"):
DeerFlowClient(agent_name="invalid name with spaces!") DeerFlowClient(agent_name="invalid name with spaces!")
with pytest.raises(ValueError, match="Invalid agent name"): with pytest.raises(ValueError, match="Invalid agent name"):
@@ -86,17 +86,15 @@ class TestClientInit:
def test_custom_config_path(self, mock_app_config): def test_custom_config_path(self, mock_app_config):
with ( with (
patch.object(AppConfig, "from_file", return_value=mock_app_config) as mock_from_file, patch("deerflow.client.reload_app_config") as mock_reload,
patch.object(AppConfig, "init") as mock_init, patch("deerflow.client.get_app_config", return_value=mock_app_config),
patch.object(AppConfig, "current", return_value=mock_app_config),
): ):
DeerFlowClient(config_path="/tmp/custom.yaml") DeerFlowClient(config_path="/tmp/custom.yaml")
mock_from_file.assert_called_once_with("/tmp/custom.yaml") mock_reload.assert_called_once_with("/tmp/custom.yaml")
mock_init.assert_called_once_with(mock_app_config)
def test_checkpointer_stored(self, mock_app_config): def test_checkpointer_stored(self, mock_app_config):
cp = MagicMock() cp = MagicMock()
with patch.object(AppConfig, "current", return_value=mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient(checkpointer=cp) c = DeerFlowClient(checkpointer=cp)
assert c._checkpointer is cp assert c._checkpointer is cp
@@ -110,6 +108,7 @@ class TestConfigQueries:
def test_list_models(self, client): def test_list_models(self, client):
result = client.list_models() result = client.list_models()
assert "models" in result assert "models" in result
assert result["token_usage"] == {"enabled": False}
assert len(result["models"]) == 1 assert len(result["models"]) == 1
assert result["models"][0]["name"] == "test-model" assert result["models"][0]["name"] == "test-model"
# Verify Gateway-aligned fields are present # Verify Gateway-aligned fields are present
@@ -252,8 +251,8 @@ class TestStream:
# Verify context passed to agent.stream # Verify context passed to agent.stream
agent.stream.assert_called_once() agent.stream.assert_called_once()
call_kwargs = agent.stream.call_args.kwargs call_kwargs = agent.stream.call_args.kwargs
ctx = call_kwargs["context"] assert call_kwargs["context"]["thread_id"] == "t1"
assert ctx.app_config is client._app_config assert call_kwargs["context"]["agent_name"] == "test-agent-1"
def test_custom_mode_is_normalized_to_string(self, client): def test_custom_mode_is_normalized_to_string(self, client):
"""stream() forwards custom events even when the mode is not a plain string.""" """stream() forwards custom events even when the mode is not a plain string."""
@@ -1092,7 +1091,7 @@ class TestMcpConfig:
ext_config = MagicMock() ext_config = MagicMock()
ext_config.mcp_servers = {"github": server} ext_config.mcp_servers = {"github": server}
with patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)): with patch("deerflow.client.get_extensions_config", return_value=ext_config):
result = client.get_mcp_config() result = client.get_mcp_config()
assert "mcp_servers" in result assert "mcp_servers" in result
@@ -1117,12 +1116,10 @@ class TestMcpConfig:
# Pre-set agent to verify it gets invalidated # Pre-set agent to verify it gets invalidated
client._agent = MagicMock() client._agent = MagicMock()
# Set initial AppConfig with current extensions
AppConfig.init(MagicMock(extensions=current_config))
with ( with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)), patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
): ):
result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}}) result = client.update_mcp_config({"new-server": {"enabled": True, "type": "sse"}})
@@ -1184,8 +1181,8 @@ class TestSkillsManagement:
with ( with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]), patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=tmp_path),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)), patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()), patch("deerflow.client.reload_extensions_config"),
): ):
result = client.update_skill("test-skill", enabled=False) result = client.update_skill("test-skill", enabled=False)
assert result["enabled"] is False assert result["enabled"] is False
@@ -1316,40 +1313,35 @@ class TestMemoryManagement:
assert result == data assert result == data
def test_get_memory_config(self, client): def test_get_memory_config(self, client):
mem_config = MagicMock() config = MagicMock()
mem_config.enabled = True config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json" config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30 config.debounce_seconds = 30
mem_config.max_facts = 100 config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7 config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True config.injection_enabled = True
mem_config.max_injection_tokens = 2000 config.max_injection_tokens = 2000
app_cfg = MagicMock() with patch("deerflow.config.memory_config.get_memory_config", return_value=config):
app_cfg.memory = mem_config
with patch.object(AppConfig, "current", return_value=app_cfg):
result = client.get_memory_config() result = client.get_memory_config()
assert result["enabled"] is True assert result["enabled"] is True
assert result["max_facts"] == 100 assert result["max_facts"] == 100
def test_get_memory_status(self, client): def test_get_memory_status(self, client):
mem_config = MagicMock() config = MagicMock()
mem_config.enabled = True config.enabled = True
mem_config.storage_path = ".deer-flow/memory.json" config.storage_path = ".deer-flow/memory.json"
mem_config.debounce_seconds = 30 config.debounce_seconds = 30
mem_config.max_facts = 100 config.max_facts = 100
mem_config.fact_confidence_threshold = 0.7 config.fact_confidence_threshold = 0.7
mem_config.injection_enabled = True config.injection_enabled = True
mem_config.max_injection_tokens = 2000 config.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_config
data = {"version": "1.0", "facts": []} data = {"version": "1.0", "facts": []}
with ( with (
patch.object(AppConfig, "current", return_value=app_cfg), patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=data), patch("deerflow.agents.memory.updater.get_memory_data", return_value=data),
): ):
result = client.get_memory_status() result = client.get_memory_status()
@@ -1793,10 +1785,10 @@ class TestScenarioConfigManagement:
reloaded_config.mcp_servers = {"my-mcp": reloaded_server} reloaded_config.mcp_servers = {"my-mcp": reloaded_server}
client._agent = MagicMock() # Simulate existing agent client._agent = MagicMock() # Simulate existing agent
AppConfig.init(MagicMock(extensions=current_config))
with ( with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded_config)), patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded_config),
): ):
mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}}) mcp_result = client.update_mcp_config({"my-mcp": {"enabled": True}})
assert "my-mcp" in mcp_result["mcp_servers"] assert "my-mcp" in mcp_result["mcp_servers"]
@@ -1825,8 +1817,8 @@ class TestScenarioConfigManagement:
with ( with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]), patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [toggled]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)), patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()), patch("deerflow.client.reload_extensions_config"),
): ):
skill_result = client.update_skill("code-gen", enabled=False) skill_result = client.update_skill("code-gen", enabled=False)
assert skill_result["enabled"] is False assert skill_result["enabled"] is False
@@ -2011,10 +2003,8 @@ class TestScenarioMemoryWorkflow:
refreshed = client.reload_memory() refreshed = client.reload_memory()
assert len(refreshed["facts"]) == 2 assert len(refreshed["facts"]) == 2
app_cfg = MagicMock()
app_cfg.memory = config
with ( with (
patch.object(AppConfig, "current", return_value=app_cfg), patch("deerflow.config.memory_config.get_memory_config", return_value=config),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data), patch("deerflow.agents.memory.updater.get_memory_data", return_value=updated_data),
): ):
status = client.get_memory_status() status = client.get_memory_status()
@@ -2077,8 +2067,8 @@ class TestScenarioSkillInstallAndUse:
with ( with (
patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]), patch("deerflow.skills.loader.load_skills", side_effect=[[installed_skill], [disabled_skill]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)), patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()), patch("deerflow.client.reload_extensions_config"),
): ):
toggled = client.update_skill("my-analyzer", enabled=False) toggled = client.update_skill("my-analyzer", enabled=False)
assert toggled["enabled"] is False assert toggled["enabled"] is False
@@ -2208,9 +2198,11 @@ class TestGatewayConformance:
model.display_name = "Test Model" model.display_name = "Test Model"
model.description = "A test model" model.description = "A test model"
model.supports_thinking = False model.supports_thinking = False
model.supports_reasoning_effort = False
mock_app_config.models = [model] mock_app_config.models = [model]
mock_app_config.token_usage.enabled = True
with patch.object(AppConfig, "current", return_value=mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config):
client = DeerFlowClient() client = DeerFlowClient()
result = client.list_models() result = client.list_models()
@@ -2218,6 +2210,7 @@ class TestGatewayConformance:
assert len(parsed.models) == 1 assert len(parsed.models) == 1
assert parsed.models[0].name == "test-model" assert parsed.models[0].name == "test-model"
assert parsed.models[0].model == "gpt-test" assert parsed.models[0].model == "gpt-test"
assert parsed.token_usage.enabled is True
def test_get_model(self, mock_app_config): def test_get_model(self, mock_app_config):
model = MagicMock() model = MagicMock()
@@ -2229,7 +2222,7 @@ class TestGatewayConformance:
mock_app_config.models = [model] mock_app_config.models = [model]
mock_app_config.get_model_config.return_value = model mock_app_config.get_model_config.return_value = model
with patch.object(AppConfig, "current", return_value=mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config):
client = DeerFlowClient() client = DeerFlowClient()
result = client.get_model("test-model") result = client.get_model("test-model")
@@ -2299,7 +2292,7 @@ class TestGatewayConformance:
ext_config = MagicMock() ext_config = MagicMock()
ext_config.mcp_servers = {"test": server} ext_config.mcp_servers = {"test": server}
with patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)): with patch("deerflow.client.get_extensions_config", return_value=ext_config):
result = client.get_mcp_config() result = client.get_mcp_config()
parsed = McpConfigResponse(**result) parsed = McpConfigResponse(**result)
@@ -2325,9 +2318,9 @@ class TestGatewayConformance:
config_file.write_text("{}") config_file.write_text("{}")
with ( with (
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)), patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=ext_config)), patch("deerflow.client.reload_extensions_config", return_value=ext_config),
): ):
result = client.update_mcp_config({"srv": server.model_dump.return_value}) result = client.update_mcp_config({"srv": server.model_dump.return_value})
@@ -2358,10 +2351,7 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000 mem_cfg.max_injection_tokens = 2000
app_cfg = MagicMock() with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
app_cfg.memory = mem_cfg
with patch.object(AppConfig, "current", return_value=app_cfg):
result = client.get_memory_config() result = client.get_memory_config()
parsed = MemoryConfigResponse(**result) parsed = MemoryConfigResponse(**result)
@@ -2378,8 +2368,6 @@ class TestGatewayConformance:
mem_cfg.injection_enabled = True mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000 mem_cfg.max_injection_tokens = 2000
app_cfg = MagicMock()
app_cfg.memory = mem_cfg
memory_data = { memory_data = {
"version": "1.0", "version": "1.0",
"lastUpdated": "", "lastUpdated": "",
@@ -2397,7 +2385,7 @@ class TestGatewayConformance:
} }
with ( with (
patch.object(AppConfig, "current", return_value=app_cfg), patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data), patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory_data),
): ):
result = client.get_memory_status() result = client.get_memory_status()
@@ -2688,8 +2676,8 @@ class TestConfigUpdateErrors:
with ( with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]), patch("deerflow.skills.loader.load_skills", side_effect=[[skill], []]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)), patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()), patch("deerflow.client.reload_extensions_config"),
): ):
with pytest.raises(RuntimeError, match="disappeared"): with pytest.raises(RuntimeError, match="disappeared"):
client.update_skill("ghost-skill", enabled=False) client.update_skill("ghost-skill", enabled=False)
@@ -3059,10 +3047,10 @@ class TestBugAgentInvalidationInconsistency:
config_file = Path(tmp) / "ext.json" config_file = Path(tmp) / "ext.json"
config_file.write_text("{}") config_file.write_text("{}")
AppConfig.init(MagicMock(extensions=current_config))
with ( with (
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock(extensions=reloaded)), patch("deerflow.client.get_extensions_config", return_value=current_config),
patch("deerflow.client.reload_extensions_config", return_value=reloaded),
): ):
client.update_mcp_config({}) client.update_mcp_config({})
@@ -3094,8 +3082,8 @@ class TestBugAgentInvalidationInconsistency:
with ( with (
patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]), patch("deerflow.skills.loader.load_skills", side_effect=[[skill], [updated]]),
patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file), patch("deerflow.client.ExtensionsConfig.resolve_config_path", return_value=config_file),
patch.object(AppConfig, "current", return_value=MagicMock(extensions=ext_config)), patch("deerflow.client.get_extensions_config", return_value=ext_config),
patch("deerflow.config.app_config.AppConfig.from_file", return_value=MagicMock()), patch("deerflow.client.reload_extensions_config"),
): ):
client.update_skill("s1", enabled=False) client.update_skill("s1", enabled=False)
-73
View File
@@ -1,73 +0,0 @@
"""Verify that all sub-config Pydantic models are frozen (immutable).
Frozen models reject attribute assignment after construction, raising
pydantic.ValidationError. This test collects every BaseModel subclass
defined in the deerflow.config package and asserts that mutation is
blocked.
"""
import inspect
import pkgutil
import pytest
from pydantic import BaseModel, ValidationError
import deerflow.config as config_pkg
def _collect_config_models() -> list[type[BaseModel]]:
"""Walk deerflow.config.* and return all concrete BaseModel subclasses."""
import importlib
models: list[type[BaseModel]] = []
package_path = config_pkg.__path__
package_prefix = config_pkg.__name__ + "."
for _importer, modname, _ispkg in pkgutil.walk_packages(package_path, prefix=package_prefix):
try:
mod = importlib.import_module(modname)
except Exception:
continue
for _name, obj in inspect.getmembers(mod, inspect.isclass):
if (
issubclass(obj, BaseModel)
and obj is not BaseModel
and obj.__module__ == mod.__name__
):
models.append(obj)
return models
_EXCLUDED: set[str] = set()
_ALL_MODELS = [m for m in _collect_config_models() if m.__name__ not in _EXCLUDED]
# Sanity: make sure we actually collected a meaningful set.
assert len(_ALL_MODELS) >= 15, f"Expected at least 15 config models, found {len(_ALL_MODELS)}: {[m.__name__ for m in _ALL_MODELS]}"
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_is_frozen(model_cls: type[BaseModel]):
"""Every sub-config model must have frozen=True in its model_config."""
cfg = model_cls.model_config
assert cfg.get("frozen") is True, (
f"{model_cls.__name__} is not frozen. "
f"Add `model_config = ConfigDict(frozen=True)` or add `frozen=True` to the existing ConfigDict."
)
@pytest.mark.parametrize("model_cls", _ALL_MODELS, ids=lambda cls: cls.__name__)
def test_config_model_rejects_mutation(model_cls: type[BaseModel]):
"""Constructing then mutating any field must raise ValidationError."""
# Build a minimal instance -- use model_construct to skip validation for
# required fields, then pick the first field to try mutating.
fields = list(model_cls.model_fields.keys())
if not fields:
pytest.skip(f"{model_cls.__name__} has no fields")
instance = model_cls.model_construct()
first_field = fields[0]
with pytest.raises(ValidationError):
setattr(instance, first_field, "MUTATED")
+70 -11
View File
@@ -3,13 +3,13 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import patch
import pytest import pytest
import yaml import yaml
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from deerflow.config.app_config import AppConfig from deerflow.config.agents_api_config import AgentsApiConfig, get_agents_api_config, set_agents_api_config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
@@ -333,7 +333,7 @@ class TestMemoryFilePath:
with ( with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))), patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
): ):
storage = FileMemoryStorage() storage = FileMemoryStorage()
path = storage._get_memory_file_path(None) path = storage._get_memory_file_path(None)
@@ -346,7 +346,7 @@ class TestMemoryFilePath:
with ( with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))), patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
): ):
storage = FileMemoryStorage() storage = FileMemoryStorage()
path = storage._get_memory_file_path("code-reviewer") path = storage._get_memory_file_path("code-reviewer")
@@ -358,7 +358,7 @@ class TestMemoryFilePath:
with ( with (
patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)),
patch.object(AppConfig, "current", return_value=MagicMock(memory=MemoryConfig(storage_path=""))), patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")),
): ):
storage = FileMemoryStorage() storage = FileMemoryStorage()
path_global = storage._get_memory_file_path(None) path_global = storage._get_memory_file_path(None)
@@ -389,13 +389,38 @@ def _make_test_app(tmp_path: Path):
@pytest.fixture() @pytest.fixture()
def agent_client(tmp_path): def agent_client(tmp_path):
"""TestClient with agents router, using tmp_path as base_dir.""" """TestClient with agents router, using tmp_path as base_dir."""
paths_instance = _make_paths(tmp_path) import app.gateway.routers.agents as agents_router
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch("app.gateway.routers.agents.get_paths", return_value=paths_instance): paths_instance = _make_paths(tmp_path)
app = _make_test_app(tmp_path) previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
with TestClient(app) as client:
client._tmp_path = tmp_path # type: ignore[attr-defined] with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
yield client set_agents_api_config(AgentsApiConfig(enabled=True))
try:
app = _make_test_app(tmp_path)
with TestClient(app) as client:
client._tmp_path = tmp_path # type: ignore[attr-defined]
yield client
finally:
set_agents_api_config(previous_config)
@pytest.fixture()
def disabled_agent_client(tmp_path):
"""TestClient with agents router while the management API is disabled."""
import app.gateway.routers.agents as agents_router
paths_instance = _make_paths(tmp_path)
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
set_agents_api_config(AgentsApiConfig(enabled=False))
try:
app = _make_test_app(tmp_path)
with TestClient(app) as client:
yield client
finally:
set_agents_api_config(previous_config)
class TestAgentsAPI: class TestAgentsAPI:
@@ -561,3 +586,37 @@ class TestUserProfileAPI:
response = agent_client.put("/api/user-profile", json={"content": ""}) response = agent_client.put("/api/user-profile", json={"content": ""})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["content"] is None assert response.json()["content"] is None
class TestAgentsApiDisabled:
def test_agents_list_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents")
assert response.status_code == 403
assert "agents_api.enabled=true" in response.json()["detail"]
def test_agent_get_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents/example-agent")
assert response.status_code == 403
def test_agent_name_check_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents/check", params={"name": "example-agent"})
assert response.status_code == 403
def test_agent_create_returns_403(self, disabled_agent_client):
response = disabled_agent_client.post("/api/agents", json={"name": "example-agent", "soul": "blocked"})
assert response.status_code == 403
def test_agent_update_returns_403(self, disabled_agent_client):
response = disabled_agent_client.put("/api/agents/example-agent", json={"description": "blocked"})
assert response.status_code == 403
def test_agent_delete_returns_403(self, disabled_agent_client):
response = disabled_agent_client.delete("/api/agents/example-agent")
assert response.status_code == 403
def test_user_profile_routes_return_403(self, disabled_agent_client):
get_response = disabled_agent_client.get("/api/user-profile")
put_response = disabled_agent_client.put("/api/user-profile", json={"content": "blocked"})
assert get_response.status_code == 403
assert put_response.status_code == 403
@@ -119,6 +119,31 @@ class TestBuildPatchedMessagesPatching:
assert "interrupted" in tool_msg.content.lower() assert "interrupted" in tool_msg.content.lower()
assert tool_msg.name == "bash" assert tool_msg.name == "bash"
def test_raw_provider_tool_calls_are_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [
AIMessage(
content="",
tool_calls=[],
additional_kwargs={
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "bash", "arguments": '{"command":"ls"}'},
}
]
},
)
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert len(patched) == 2
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert patched[1].name == "bash"
assert patched[1].status == "error"
class TestWrapModelCall: class TestWrapModelCall:
def test_no_patch_passthrough(self): def test_no_patch_passthrough(self):
-86
View File
@@ -1,86 +0,0 @@
"""Tests for DeerFlowContext and resolve_context()."""
from dataclasses import FrozenInstanceError
from unittest.mock import MagicMock, patch
import pytest
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext, resolve_context
from deerflow.config.sandbox_config import SandboxConfig
def _make_config(**overrides) -> AppConfig:
defaults = {"sandbox": SandboxConfig(use="test")}
defaults.update(overrides)
return AppConfig(**defaults)
class TestDeerFlowContext:
def test_frozen(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
with pytest.raises(FrozenInstanceError):
ctx.app_config = _make_config()
def test_fields(self):
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1", agent_name="test-agent")
assert ctx.thread_id == "t1"
assert ctx.agent_name == "test-agent"
assert ctx.app_config is config
def test_agent_name_default(self):
ctx = DeerFlowContext(app_config=_make_config(), thread_id="t1")
assert ctx.agent_name is None
def test_thread_id_required(self):
with pytest.raises(TypeError):
DeerFlowContext(app_config=_make_config()) # type: ignore[call-arg]
class TestResolveContext:
def test_returns_typed_context_directly(self):
"""Gateway/Client path: runtime.context is DeerFlowContext → return as-is."""
config = _make_config()
ctx = DeerFlowContext(app_config=config, thread_id="t1")
runtime = MagicMock()
runtime.context = ctx
assert resolve_context(runtime) is ctx
def test_fallback_from_configurable(self):
"""LangGraph Server path: runtime.context is None → construct from ContextVar + configurable."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {"thread_id": "t2", "agent_name": "ag"}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == "t2"
assert ctx.agent_name == "ag"
assert ctx.app_config is config
def test_fallback_empty_configurable(self):
"""LangGraph Server path with no thread_id in configurable → empty string."""
runtime = MagicMock()
runtime.context = None
config = _make_config()
with (
patch.object(AppConfig, "current", return_value=config),
patch("langgraph.config.get_config", return_value={"configurable": {}}),
):
ctx = resolve_context(runtime)
assert ctx.thread_id == ""
assert ctx.agent_name is None
def test_fallback_from_dict_context(self):
"""Legacy path: runtime.context is a dict → extract from dict directly."""
runtime = MagicMock()
runtime.context = {"thread_id": "old-dict", "agent_name": "from-dict"}
config = _make_config()
with patch.object(AppConfig, "current", return_value=config):
ctx = resolve_context(runtime)
assert ctx.thread_id == "old-dict"
assert ctx.agent_name == "from-dict"
assert ctx.app_config is config
+4 -6
View File
@@ -5,13 +5,11 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from deerflow.config.app_config import AppConfig
@pytest.fixture @pytest.fixture
def mock_app_config(): def mock_app_config():
"""Mock the app config to return tool configurations.""" """Mock the app config to return tool configurations."""
with patch.object(AppConfig, "current") as mock_config: with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
tool_config = MagicMock() tool_config = MagicMock()
tool_config.model_extra = { tool_config.model_extra = {
"max_results": 5, "max_results": 5,
@@ -69,7 +67,7 @@ class TestWebSearchTool:
def test_search_with_custom_config(self, mock_exa_client): def test_search_with_custom_config(self, mock_exa_client):
"""Test search respects custom configuration values.""" """Test search respects custom configuration values."""
with patch.object(AppConfig, "current") as mock_config: with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
tool_config = MagicMock() tool_config = MagicMock()
tool_config.model_extra = { tool_config.model_extra = {
"max_results": 10, "max_results": 10,
@@ -197,7 +195,7 @@ class TestWebFetchTool:
def test_fetch_reads_web_fetch_config(self, mock_exa_client): def test_fetch_reads_web_fetch_config(self, mock_exa_client):
"""Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'.""" """Test that web_fetch_tool reads 'web_fetch' config, not 'web_search'."""
with patch.object(AppConfig, "current") as mock_config: with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
tool_config = MagicMock() tool_config = MagicMock()
tool_config.model_extra = {"api_key": "exa-fetch-key"} tool_config.model_extra = {"api_key": "exa-fetch-key"}
mock_config.return_value.get_tool_config.return_value = tool_config mock_config.return_value.get_tool_config.return_value = tool_config
@@ -217,7 +215,7 @@ class TestWebFetchTool:
def test_fetch_uses_independent_api_key(self, mock_exa_client): def test_fetch_uses_independent_api_key(self, mock_exa_client):
"""Test mixed-provider config: web_fetch uses its own api_key, not web_search's.""" """Test mixed-provider config: web_fetch uses its own api_key, not web_search's."""
with patch.object(AppConfig, "current") as mock_config: with patch("deerflow.community.exa.tools.get_app_config") as mock_config:
with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls: with patch("deerflow.community.exa.tools.Exa") as mock_exa_cls:
mock_exa_cls.return_value = mock_exa_client mock_exa_cls.return_value = mock_exa_client
fetch_config = MagicMock() fetch_config = MagicMock()

Some files were not shown because too many files have changed in this diff Show More