Compare commits

...

79 Commits

Author SHA1 Message Date
Willem Jiang 7052978a43 fix the lint errors 2026-04-26 11:16:22 +08:00
Willem Jiang d9f7f658be Apply suggestions from code review
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-26 11:12:42 +08:00
Willem Jiang a55de566b9 refactor(backend): consolidate thread_id resolution into shared get_thread_id() utility (#2522)
Extract duplicated thread_id fallback logic from 11 files into a single
  deerflow.utils.runtime.get_thread_id() function with a documented 3-level
  cascade (runtime.context → runtime.config → get_config()).

  The module docstring also clarifies the __pregel_runtime injection pattern used in
  gateway mode.
2026-04-26 10:52:37 +08:00
Willem Jiang 9dc25987e0 fix(channles):update the logger for the channel config (#2524)
* fix(channles):update the logger for the channel config

* fix(channels): normalize credential values and add tests for disabled-but-configured warning

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/dfc0a566-aa59-49f9-a74d-610292fb0a63

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* fix the backend lint error

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-04-26 10:09:55 +08:00
Willem Jiang 8a044142cb feat(dev): add pre-commit hooks for ruff, eslint, and prettier (#2525)
* feat(dev): add pre-commit hooks for ruff, eslint, and prettier

* fix: use local uv-based ruff hooks and uv run for pre-commit install

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/a1e34cc5-0d4b-4400-9e6a-e687d964ff1e

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-04-26 09:40:17 +08:00
ming1523 410f0c48b5 fix(channels): accept single slack allowed user (#2481)
* fix(channels): accept single slack allowed user

* docs: address Slack allowed_users review notes

* ci: rerun backend unit tests

* docs: clarify Slack allowed_users config

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 19:40:52 +08:00
Octopus 1f59e945af fix: cap prompt caching breakpoints at 4 to prevent API 400 errors (#2449)
* fix: cap prompt caching breakpoints at 4 to prevent API 400 errors (fixes #2448)

The previous _apply_prompt_caching() attached cache_control to every text
block in the system prompt, every content block in the last N messages, and
the last tool definition. In multi-turn conversations with structured content
blocks this easily exceeded the 4-breakpoint hard limit enforced by both the
Anthropic API and AWS Bedrock, producing a 400 Bad Request (or a silent
"No generations found in stream" when streaming).

Fix: collect all candidate blocks in document order, then apply cache_control
only to the last MAX_CACHE_BREAKPOINTS (4) of them. Later breakpoints cover a
larger prefix and therefore yield better cache hit rates, making this the
optimal placement strategy as well as the safe one.

Adds 13 unit tests covering the budget cap, edge cases, and correct
last-candidate placement.

* docs: clarify _apply_prompt_caching docstring includes tool definitions

Per Copilot review: the implementation also caches the last tool definition
(see the candidates list at lines 202-205), so the docstring summary should
explicitly mention tools alongside system and recent messages.

* Fix the lint error

* style: fix ruff format check for test_claude_provider_prompt_caching.py

Add the missing blank line before the 'Edge cases' section comment so
that ruff format --check passes in CI.

---------

Co-authored-by: octo-patch <octo-patch@github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 19:40:06 +08:00
IECspace f394c0d8c8 feat(mcp): support custom tool interceptors via extensions_config.json (#2451)
* feat(mcp): support custom tool interceptors via extensions_config.json

Add a generic extension point for registering custom MCP tool
interceptors through `extensions_config.json`. This allows downstream
projects to inject per-request header manipulation, auth context
propagation, or other cross-cutting concerns without modifying
DeerFlow source code.

Interceptors are declared as Python callable paths in a new
`mcpInterceptors` array field and loaded via the existing
`resolve_variable` reflection mechanism:

```json
{
  "mcpInterceptors": [
    "my_package.mcp.auth:build_auth_interceptor"
  ]
}
```

Each entry must resolve to a no-arg builder function that returns an
async interceptor compatible with `MultiServerMCPClient`'s
`tool_interceptors` interface.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* test(mcp): add unit tests for custom tool interceptors

Cover all branches of the mcpInterceptors loading logic:

- valid interceptor loaded and appended to tool_interceptors
- multiple interceptors loaded in declaration order
- builder returning None is skipped
- resolve_variable ImportError logged and skipped
- builder raising exception logged and skipped
- absent mcpInterceptors field is safe (no-op)
- custom interceptors coexist with OAuth interceptor

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix(mcp): validate mcpInterceptors type and fix lint warnings

Address review feedback:

1. Validate mcpInterceptors config value before iterating:
   - Accept a single string and normalize to [string]
   - Ignore None silently
   - Log warning and skip for non-list/non-string types

2. Fix ruff F841 lint errors in tests:
   - Rename _make_mock_env to _make_patches, embed mock_client
   - Remove unused `as mock_cls` bindings where not needed
   - Extract _get_interceptors() helper to reduce repetition

3. Add two new test cases for type validation:
   - test_mcp_interceptors_single_string_is_normalized
   - test_mcp_interceptors_invalid_type_logs_warning

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(mcp): validate interceptor return type and fix import mock path

Address review feedback:

1. Validate builder return type with callable() check:
   - callable interceptor → append to tool_interceptors
   - None → silently skip (builder opted out)
   - non-callable → log warning with type name and skip

2. Fix test mock path: resolve_variable is a top-level import in
   tools.py, so mock deerflow.mcp.tools.resolve_variable instead of
   deerflow.reflection.resolve_variable to correctly intercept calls.

3. Add test_custom_interceptor_non_callable_return_logs_warning to
   cover the new non-callable validation branch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* docs(mcp): add mcpInterceptors example and documentation

- Add mcpInterceptors field to extensions_config.example.json
- Add "Custom Tool Interceptors" section to MCP_SERVER.md with
  configuration format, example interceptor code, and edge case
  behavior notes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: IECspace <IECspace@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-25 09:18:13 +08:00
orbisai0security 950821cb9b fix: use subprocess instead of os.system in local_backend.py (#2494)
* fix: use subprocess instead of os.system in local_backend.py

The sandbox backend and skill evaluation scripts use subprocess

* fixing the failing test

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 08:59:31 +08:00
pyp0327 2bb1a2dfa2 feat(models): Provider for MindIE model engine (#2483)
* feat(models): 适配 MindIE引擎的模型

* test: add unit tests for MindIEChatModel adapter and fix PR review comments

* chore: update uv.lock with pytest-asyncio

* build: add pytest-asyncio to test dependencies

* fix: address PR review comments (lazy import, cache clients, safe newline escape, strict xml regex)

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-25 08:59:03 +08:00
DanielWalnut b970993425 fix: read lead agent options from context (#2515)
* fix: read lead agent options from context

* fix: validate runtime context config
2026-04-24 22:46:51 +08:00
DanielWalnut ec8a8cae38 fix: gate deferred MCP tool execution (#2513)
* fix: gate deferred MCP tool execution

* style: format deferred tool middleware

* fix: address deferred tool review feedback
2026-04-24 22:45:41 +08:00
DanielWalnut d78ed5c8f2 fix: inherit subagent skill allowlists (#2514) 2026-04-24 21:24:42 +08:00
Nan Gao f9ff3a698d fix(middleware): avoid rescuing non-skill tool outputs during summarization (#2458)
* fix(middelware): narrow skill rescue to skill-related tool outputs

* fix(summarization): address skill rescue review feedback

* fix: wire summarization skill rescue config

* fix: remove dead skill tool helper

* fix(lint): fix format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 21:19:46 +08:00
Admire c2332bb790 fix memory settings layout overflow (#2420)
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 20:29:55 +08:00
He Wang 3a61126824 fix: keep debug.py interactive terminal free from background log noise (#2466)
* fix(debug): keep terminal clean by redirecting all logs to file

- Redirect all logs to debug.log file to prevent background task logs
  from interfering with interactive terminal prompts
- Honor AppConfig.log_level setting instead of hard-coding to INFO
- Make logging setup idempotent by clearing pre-existing handlers
- Defer deerflow imports until after logging is configured to ensure
  import-time side effects are captured in debug.log
- Display active log level in startup banner
- Add prompt_toolkit installation tip for enhanced readline support

Made-with: Cursor

* attaching the file handler before importing/calling get_app_config()

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-24 17:09:41 +08:00
Airene Fang 11f557a2c6 feat(trace):Add run_name to the trace info for system agents. (#2492)
* feat(trace): Add `run_name` to the trace info for suggestions and memory.

before(in langsmith):
CodexChatModel
CodexChatModel
lead_agent
after:
suggest_agent
memory_agent
lead_agent

feat(trace): Add `run_name` to the trace info for suggestions and memory.

before(in langsmith):
CodexChatModel
CodexChatModel
lead_agent
after:
suggest_agent
memory_agent
lead_agent

* feat(trace): Add `run_name` to the trace info for system agents.

before(in langsmith):
CodexChatModel
CodexChatModel
CodexChatModel
CodexChatModel
lead_agent
after:
suggest_agent
title_agent
security_agent
memory_agent
lead_agent

* chore(code format):code format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 17:06:55 +08:00
d 🔹 e8572b9d0c fix(jina): log transient failures at WARNING without traceback (#2484) (#2485)
The exception handler in JinaClient.crawl used logger.exception, which
emits an ERROR-level record with the full httpx/httpcore/anyio traceback
for every transient network failure (timeout, connection refused). Other
search/crawl providers in the project log the same class of recoverable
failures as a single line. One offline/slow-network session could produce
dozens of multi-frame ERROR stack traces, drowning out real problems.

Switch to logger.warning with a concise message that includes the
exception type and its str, matching the style used elsewhere for
recoverable transient failures (aio_sandbox, ddg, etc.). The exception
type now also surfaces into the returned "Error: ..." string so callers
retain diagnostic signal.

Adds a regression test that asserts the log record is WARNING, carries
no exc_info, and includes the exception class name.

Co-authored-by: voidborne-d <voidborne-d@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-24 16:00:14 +08:00
Willem Jiang 80a7446fd6 fix(backend): fix the unit test error in backend 2026-04-24 14:56:03 +08:00
Willem Jiang cd12821134 fix(backend): Updated the uv.lock with new added dependency 2026-04-24 14:55:13 +08:00
Xinmin Zeng 30d619de08 feat(subagents): support per-subagent skill loading and custom subagent types (#2253)
* feat(subagents): support per-subagent skill loading and custom subagent types (#2230)

Add per-subagent skill configuration and custom subagent type registration,
aligned with Codex's role-based config layering and per-session skill injection.

Backend:
- SubagentConfig gains `skills` field (None=all, []=none, list=whitelist)
- New CustomSubagentConfig for user-defined subagent types in config.yaml
- SubagentsAppConfig gains `custom_agents` section and `get_skills_for()`
- Registry resolves custom agents with three-layer config precedence
- SubagentExecutor loads skills per-session as conversation items (Codex pattern)
- task_tool no longer appends skills to system_prompt
- Lead agent system prompt dynamically lists all registered subagent types
- setup_agent tool accepts optional skills parameter
- Gateway agents API transparently passes skills in CRUD operations

Frontend:
- Agent/CreateAgentRequest/UpdateAgentRequest types include skills field
- Agent card displays skills as badges alongside tool_groups

Config:
- config.example.yaml documents custom_agents and per-agent skills override

Tests:
- 40 new tests covering all skill config, custom agents, and registry logic
- Existing tests updated for new get_skills_prompt_section signature

Closes #2230

* fix: address review feedback on skills PR

- Remove stale get_skills_prompt_section monkeypatches from test_task_tool_core_logic.py
  (task_tool no longer imports this function after skill injection moved to executor)
- Add key prefixes (tg:/sk:) to agent-card badges to prevent React key collisions
  between tool_groups and skills

* fix(ci): resolve lint and test failures

- Format agent-card.tsx with prettier (lint-frontend)
- Remove stale "Skills Appendix" system_prompt assertion — skills are now
  loaded per-session by SubagentExecutor, not appended to system_prompt

* fix(ci): sort imports in test_subagent_skills_config.py (ruff I001)

* fix(ci): use nullish coalescing in agent-card badge condition (eslint)

* fix: address review feedback on skills PR

- Use model_fields_set in AgentUpdateRequest to distinguish "field omitted"
  from "explicitly set to null" — fixes skills=None ambiguity where None
  means "inherit all" but was treated as "don't change"
- Move lazy import of get_subagent_config outside loop in
  _build_available_subagents_description to avoid repeated import overhead

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-23 23:59:47 +08:00
JerryChaox 4e72410154 fix(gateway): bound lifespan shutdown hooks to prevent worker hang under uvicorn reload (#2331)
* fix(gateway): bound lifespan shutdown hooks to prevent worker hang

Gateway worker can hang indefinitely in `uvicorn --reload` mode with
the listening socket still bound — all /api/* requests return 504,
and SIGKILL is the only recovery.

Root cause (py-spy dump from a reproduction showed 16+ stacked frames
of signal_handler -> Event.set -> threading.Lock.__enter__ on the
main thread): CPython's `threading.Event` uses `Condition(Lock())`
where the inner Lock is non-reentrant. uvicorn's BaseReload signal
handler calls `should_exit.set()` directly from signal context; if a
second signal (SIGTERM/SIGHUP from the reload supervisor, or
watchfiles-triggered reload) arrives while the first handler holds
the Lock, the reentrant call deadlocks on itself.

The reload supervisor keeps sending those signals only when the
worker fails to exit promptly. DeerFlow's lifespan currently awaits
`stop_channel_service()` with no timeout; if a channel's `stop()`
stalls (e.g. Feishu/Slack WebSocket waiting for an ack), the worker
can't exit, the supervisor keeps signaling, and the deadlock becomes
reachable.

This is a defense-in-depth fix — it does not repair the upstream
uvicorn/CPython issue, but it ensures DeerFlow's lifespan exits
within a bounded window so the supervisor has no reason to keep
firing signals. No behavior change on the happy path.

Wraps the shutdown hook in `asyncio.wait_for(timeout=5.0)` and logs
a warning on timeout before proceeding to worker exit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Update backend/app/gateway/app.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* style: apply make format (ruff) to test assertions

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-23 19:41:26 +08:00
He Wang c42ae3af79 feat: add optional prompt-toolkit support to debug.py (#2461)
* feat: add optional prompt-toolkit support to debug.py

Use PromptSession.prompt_async() for arrow-key navigation and input
history when prompt-toolkit is available, falling back to plain input()
with a helpful install tip otherwise.

Made-with: Cursor

* fix: handle EOFError gracefully in debug.py

Catch EOFError alongside KeyboardInterrupt so that Ctrl-D exits
cleanly instead of printing a traceback.

Made-with: Cursor
2026-04-23 17:49:18 +08:00
dependabot[bot] bd35cd39aa chore(deps): bump uuid from 13.0.0 to 14.0.0 in /frontend (#2467)
Bumps [uuid](https://github.com/uuidjs/uuid) from 13.0.0 to 14.0.0.
- [Release notes](https://github.com/uuidjs/uuid/releases)
- [Changelog](https://github.com/uuidjs/uuid/blob/main/CHANGELOG.md)
- [Commits](https://github.com/uuidjs/uuid/compare/v13.0.0...v14.0.0)

---
updated-dependencies:
- dependency-name: uuid
  dependency-version: 14.0.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-23 14:47:15 +08:00
d 🔹 b90f219bd1 fix(skills): validate bundled SKILL.md front-matter in CI (fixes #2443) (#2457)
* fix(skills): validate bundled SKILL.md front-matter in CI (fixes #2443)

Adds a parametrized backend test that runs `_validate_skill_frontmatter`
against every bundled SKILL.md under `skills/public/`, so a broken
front-matter fails CI with a per-skill error message instead of
surfacing as a runtime gateway-load warning.

The new test caught two pre-existing breakages on `main` and fixes them:

* `bootstrap/SKILL.md`: the unquoted description had a second `:` mid-line
  ("Also trigger for updates: ..."), which YAML parses as a nested mapping
  ("mapping values are not allowed here"). Rewrites the description as a
  folded scalar (`>-`), which preserves the original wording (including the
  embedded colon, double quotes, and apostrophes) without further escaping.
  This complements PR #2436 (single-file colon→hyphen patch) with a more
  general convention that survives future edits.

* `chart-visualization/SKILL.md`: used `dependency:` which is not in
  `ALLOWED_FRONTMATTER_PROPERTIES`. Renamed to `compatibility:`, the
  documented field for "Required tools, dependencies" per skill-creator.
  No code reads `dependency` (verified by grep across backend/).

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-23 14:06:14 +08:00
dependabot[bot] 96d00f6073 chore(deps): bump dompurify from 3.3.1 to 3.4.1 in /frontend (#2462)
Bumps [dompurify](https://github.com/cure53/DOMPurify) from 3.3.1 to 3.4.1.
- [Release notes](https://github.com/cure53/DOMPurify/releases)
- [Commits](https://github.com/cure53/DOMPurify/compare/3.3.1...3.4.1)

---
updated-dependencies:
- dependency-name: dompurify
  dependency-version: 3.4.1
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-23 12:18:59 +08:00
He Wang c43c803f66 fix: remove mismatched context param in debug.py to suppress Pydantic warning (#2446)
* fix: remove mismatched context param in debug.py to suppress Pydantic warning

The ainvoke call passed context={"thread_id": ...} but the agent graph
has no context_schema (ContextT defaults to None), causing a
PydanticSerializationUnexpectedValue warning on every invocation.

Align with the production run_agent path by injecting context via
Runtime into configurable["__pregel_runtime"] instead.

Closes #2445

Made-with: Cursor

* refactor: derive runtime thread_id from config to avoid duplication

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Made-with: Cursor

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-23 09:56:57 +08:00
dependabot[bot] dbd777fe62 chore(deps): bump python-dotenv from 1.2.1 to 1.2.2 in /backend (#2440)
Bumps [python-dotenv](https://github.com/theskumar/python-dotenv) from 1.2.1 to 1.2.2.
- [Release notes](https://github.com/theskumar/python-dotenv/releases)
- [Changelog](https://github.com/theskumar/python-dotenv/blob/main/CHANGELOG.md)
- [Commits](https://github.com/theskumar/python-dotenv/compare/v1.2.1...v1.2.2)

---
updated-dependencies:
- dependency-name: python-dotenv
  dependency-version: 1.2.2
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-22 16:48:09 +08:00
dependabot[bot] 1ca2621285 chore(deps): bump lxml from 6.0.2 to 6.1.0 in /backend (#2427)
Bumps [lxml](https://github.com/lxml/lxml) from 6.0.2 to 6.1.0.
- [Release notes](https://github.com/lxml/lxml/releases)
- [Changelog](https://github.com/lxml/lxml/blob/master/CHANGES.txt)
- [Commits](https://github.com/lxml/lxml/compare/lxml-6.0.2...lxml-6.1.0)

---
updated-dependencies:
- dependency-name: lxml
  dependency-version: 6.1.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-22 16:14:11 +08:00
Shawn Jasper 5ba1dacf25 fix: rename present_file to present_files in docs and prompts (#2393)
The tool is registered as `present_files` (plural) in present_file_tool.py,
but four references in documentation and prompt strings incorrectly used the
singular form `present_file`. This could cause confusion and potentially
lead to incorrect tool invocations.

Changed files:
- backend/docs/GUARDRAILS.md
- backend/docs/ARCHITECTURE.md
- backend/packages/harness/deerflow/agents/lead_agent/prompt.py (2 occurrences)
2026-04-21 16:10:14 +08:00
Reuben Bowlby 085c13edc7 fix: remove unnecessary f-string prefixes and unused import (#2352)
- Remove f-string prefix on 7 strings with no placeholders (F541)
  in analyze.py, aggregate_benchmark.py, run_loop.py, generate_review.py
- Remove unused `os` import in quick_validate.py (F401)

Found by ruff via HUMMBL Arbiter (https://hummbl.io/audit).
2026-04-21 09:53:18 +08:00
Copilot ef04174194 Fix invalid HTML nesting in reasoning trigger during complex task rendering (#2382)
* Initial plan

* fix(frontend): avoid invalid paragraph nesting in reasoning trigger

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/4c9eb0c2-ff29-4629-a61c-4e33d736d918

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* test(frontend): strengthen reasoning trigger DOM nesting assertion

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/4c9eb0c2-ff29-4629-a61c-4e33d736d918

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>
2026-04-21 09:41:28 +08:00
Ansel 6dce26a52e fix: resolve tool duplication and skill parser YAML inconsistencies (#1803) (#2107)
* Refactor tests for SKILL.md parser

Updated tests for SKILL.md parser to handle quoted names and descriptions correctly. Added new tests for parsing plain and single-quoted names, and ensured multi-line descriptions are processed properly.

* Implement tool name validation and deduplication

Add tool name mismatch warning and deduplication logic

* Refactor skill file parsing and error handling

* Add tests for tool name deduplication

Added tests for tool name deduplication in get_available_tools(). Ensured that duplicates are not returned, the first occurrence is kept, and warnings are logged for skipped duplicates.

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Update minimal config to include tools list

* Update test for nonexistent skill file

Ensure the test for nonexistent files checks for None.

* Refactor tool loading and add skill management support

Refactor tool loading logic to include skill management tools based on configuration and clean up comments.

* Enhance code comments for tool loading logic

Added comments to clarify the purpose of various code sections related to tool loading and configuration.

* Fix assertion for duplicate tool name warning

* Fix indentation issues in tools.py

* Fix the lint error of test_tool_deduplication

* Fix the lint error of tools.py

* Fix the lint error

* Fix the lint error

* make format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-20 20:25:03 +08:00
imhaoran fc94e90f6c fix(setup-agent): prevent data loss when setup fails on existing agen… (#2254)
* fix(setup-agent): prevent data loss when setup fails on existing agent directory

Record whether the agent directory pre-existed before mkdir, and only
run shutil.rmtree cleanup when the directory was newly created during
this call. Previously, any failure would delete the entire directory
including pre-existing SOUL.md and config.yaml.

* fix: address PR review — init variables before try, remove unused result

* style: fix ruff I001 import block formatting in test file

* style: add missing blank lines between top-level definitions in test file
2026-04-20 20:17:30 +08:00
Eilen Shin f2013f47aa fix command palette hydration mismatch (#2301)
* fix command palette hydration mismatch

* style: format command dialog description
2026-04-20 11:36:16 +08:00
KiteEater 4be857f64b fix: use Apple Container image pull syntax (#2366) 2026-04-20 08:00:05 +08:00
Admire c99865f53d fix(token-usage): enable stream usage for openai-compatible models (#2217)
* fix(token-usage): enable stream usage for openai-compatible models

* fix(token-usage): narrow stream_usage default to ChatOpenAI
2026-04-19 22:42:55 +08:00
YYMa 05f1da03e5 fix(script): use portable locale for langgraph log pipeline on macOS (#2361) 2026-04-19 22:41:00 +08:00
Xun a62ca5dd47 fix: Catch httpx.ReadError in the error handling (#2309)
* fix: Catch httpx.ReadError in the error handling

* fix
2026-04-19 22:30:22 +08:00
Nan Gao f514e35a36 fix(backend): make clarification messages idempotent (#2350) (#2351) 2026-04-19 22:00:58 +08:00
Xun 7c87dc5bca fix(reasoning): prevent LLM-hallucinated HTML tags from rendering as DOM elements (#2321)
* fix

* add test

* fix
2026-04-19 19:27:34 +08:00
Hinotobi 80e210f5bb [security] fix(uploads): require explicit opt-in for host-side document conversion (#2332)
* fix: disable host-side upload conversion by default

* fix: address PR review comments on upload conversion gate
2026-04-18 22:47:42 +08:00
dependabot[bot] 5656f90792 chore(deps-dev): bump pytest from 9.0.2 to 9.0.3 in /backend (#2349)
Bumps [pytest](https://github.com/pytest-dev/pytest) from 9.0.2 to 9.0.3.
- [Release notes](https://github.com/pytest-dev/pytest/releases)
- [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst)
- [Commits](https://github.com/pytest-dev/pytest/compare/9.0.2...9.0.3)

---
updated-dependencies:
- dependency-name: pytest
  dependency-version: 9.0.3
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-18 22:22:40 +08:00
Shawn Jasper 55474011c9 fix(subagent): inherit parent agent's tool_groups in task_tool (#2305)
* fix(subagent): inherit parent agent's tool_groups in task_tool

When a custom agent defines tool_groups (e.g. [file:read, file:write, bash]),
the restriction is correctly applied to the lead agent. However, when the lead
agent delegates work to a subagent via the task tool, get_available_tools() is
called without the groups parameter, causing the subagent to receive ALL tools
(including web_search, web_fetch, image_search, etc.) regardless of the parent
agent's configuration.

This fix propagates tool_groups through run metadata so that task_tool passes
the same group filter when building the subagent's tool set.

Changes:
- agent.py: include tool_groups in run metadata
- task_tool.py: read tool_groups from metadata and pass to get_available_tools()

* fix: initialize metadata before conditional block and update tests for tool_groups propagation

- Initialize metadata = {} before the 'if runtime is not None' block to
  avoid Ruff F821 (possibly-undefined variable) and simplify the
  parent_tool_groups expression.
- Update existing test assertion to expect groups=None in
  get_available_tools call signature.
- Add 3 new test cases:
  - test_task_tool_propagates_tool_groups_to_subagent
  - test_task_tool_no_tool_groups_passes_none
  - test_task_tool_runtime_none_passes_groups_none
2026-04-18 22:17:37 +08:00
imhaoran 24fe5fbd8c fix(mcp): prevent RuntimeError from escaping except block in get_cach… (#2252)
* fix(mcp): prevent RuntimeError from escaping except block in get_cached_mcp_tools

When `asyncio.get_event_loop()` raises RuntimeError and the fallback
`asyncio.run()` also fails, the exception escapes unhandled because
Python does not route exceptions raised inside an `except` block to
sibling `except` clauses. Wrap the fallback call in its own try/except
so failures are logged and the function returns [] as intended.

* fix: use logger.exception to preserve stack traces on MCP init failure
2026-04-18 21:07:30 +08:00
Willem Jiang be4663505a chroe(script): disable the color log of langgraph 2026-04-18 20:03:05 +08:00
dependabot[bot] aa6098e6a4 chore(deps): bump langsmith from 0.6.4 to 0.7.31 in /backend (#2291)
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.6.4 to 0.7.31.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.6.4...v0.7.31)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.7.31
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-18 19:54:21 +08:00
Airene Fang 1221448029 fix(scripts): Cloud Provider Reports Security Issue(aliyun could) (#2323)
ATT&CK矩阵ID:T1059.004
数据来源:进程启动触发检测
告警原因:该进程的命令行显示出反弹shelI的特征
命令行:timeout 1 bash -c exec 3<>/dev/tcp/127.0.0.1/2024
进程路径:/usr/bin/timeout
进程链:-[337650] /usr/sbin/sshd -D
-[397971] /usr/sbin/sshd -D -R
-[397977]-bash
-[398903] make dev
-[398920] bash ./scripts/serve.sh --dev
-[399037]bash ./scripts/wait-for-port.sh 2024 60 LangGraph
2026-04-18 19:33:32 +08:00
Jason 3b91df2b18 fix(frontend): add catch-all API rewrite for gateway routes (#2335)
When NEXT_PUBLIC_BACKEND_BASE_URL is unset, the frontend proxies API
requests to the gateway. Only /api/agents and /api/skills had rewrite
rules, causing 404s for /api/models, /api/threads, /api/memory,
/api/mcp, /api/suggestions, /api/runs, etc.

Add a catch-all /api/:path* rewrite that proxies all remaining gateway
API routes. The existing /api/langgraph rewrite takes priority because
it is pushed to the array first (Next.js checks rewrites in order).

Fixes #2327

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-18 11:35:19 +08:00
Shawn Jasper ca1b7d5f48 fix(sandbox): add missing path masking in ls_tool output (#2317)
ls_tool was the only file-system tool that did not call
mask_local_paths_in_output() before returning its result, causing host
absolute paths (e.g. /Users/.../backend/.deer-flow/knowledge-base/...)
to leak to the LLM instead of the expected virtual paths
(/mnt/knowledge-base/...).

This patch:
- Adds the mask_local_paths_in_output() call to ls_tool, consistent
  with bash_tool, glob_tool and grep_tool.
- Initialises thread_data = None before the is_local_sandbox branch
  (same pattern as glob_tool) so the variable is always in scope.
- Adds three new tests covering user-data path masking, skills path
  masking and the empty-directory edge case.
2026-04-18 08:46:59 +08:00
yangzheli c6b0423558 feat(frontend): add Playwright E2E tests with CI workflow (#2279)
* feat(frontend): add Playwright E2E tests with CI workflow

Add end-to-end testing infrastructure using Playwright (Chromium only).
14 tests across 5 spec files cover landing page, chat workspace,
thread history, sidebar navigation, and agent chat — all with mocked
LangGraph/Backend APIs via network interception (zero backend dependency).

New files:
- playwright.config.ts — Chromium, 30s timeout, auto-start Next.js
- tests/e2e/utils/mock-api.ts — shared API mocks & SSE stream helpers
- tests/e2e/{landing,chat,thread-history,sidebar,agent-chat}.spec.ts
- .github/workflows/e2e-tests.yml — push main + PR trigger, paths filter

Updated: package.json, Makefile, .gitignore, CONTRIBUTING.md,
frontend/CLAUDE.md, frontend/AGENTS.md, frontend/README.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: apply Copilot suggestions

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-18 08:21:08 +08:00
DanielWalnut 898f4e8ac2 fix: Memory update system has cache corruption, data loss, and thread-safety bugs (#2251)
* fix(memory): cache corruption, thread-safety, and caller mutation bugs

Bug 1 (updater.py): deep-copy current_memory before passing to
_apply_updates() so a subsequent save() failure cannot leave a
partially-mutated object in the storage cache.

Bug 3 (storage.py): add _cache_lock (threading.Lock) to
FileMemoryStorage and acquire it around every read/write of
_memory_cache, fixing concurrent-access races between the background
timer thread and HTTP reload calls.

Bug 4 (storage.py): replace in-place mutation
  memory_data["lastUpdated"] = ...
with a shallow copy
  memory_data = {**memory_data, "lastUpdated": ...}
so save() no longer silently modifies the caller's dict.

Regression tests added for all three bugs in test_memory_storage.py
and test_memory_updater.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: format test_memory_updater.py with ruff

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: remove stale bug-number labels from code comments and docstrings

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-17 12:00:31 +08:00
dependabot[bot] 259a6844bf chore(deps): bump python-multipart from 0.0.22 to 0.0.26 in /backend (#2282)
Bumps [python-multipart](https://github.com/Kludex/python-multipart) from 0.0.22 to 0.0.26.
- [Release notes](https://github.com/Kludex/python-multipart/releases)
- [Changelog](https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Kludex/python-multipart/compare/0.0.22...0.0.26)

---
updated-dependencies:
- dependency-name: python-multipart
  dependency-version: 0.0.26
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-16 09:07:28 +08:00
d 🔹 a664d2f5c4 fix(checkpointer): create parent directory before opening SQLite in sync provider (#2272)
* fix(checkpointer): create parent directory before opening SQLite in sync provider

The sync checkpointer factory (_sync_checkpointer_cm) opens a SQLite
connection without first ensuring the parent directory exists.  The async
provider and both store providers already call ensure_sqlite_parent_dir(),
but this call was missing from the sync path.

When the deer-flow harness package is used from an external virtualenv
(where the .deer-flow directory is not pre-created), the missing parent
directory causes:

    sqlite3.OperationalError: unable to open database file

Add the missing ensure_sqlite_parent_dir() call in the sync SQLite
branch, consistent with the async provider, and add a regression test.

Closes #2259

* style: fix ruff format + add call-order assertion for ensure_parent_dir

- Fix formatting in test_checkpointer.py (ruff format)
- Add test_sqlite_ensure_parent_dir_before_connect to verify
  ensure_sqlite_parent_dir is called before from_conn_string
  (addresses Copilot review suggestion)

---------

Co-authored-by: voidborne-d <voidborne-d@users.noreply.github.com>
2026-04-16 09:06:38 +08:00
YuJitang 105db00987 feat: show token usage per assistant response (#2270)
* feat: show token usage per assistant response

* fix: align client models response with token usage

* fix: address token usage review feedback

* docs: clarify token usage config example

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-16 08:56:49 +08:00
Nan Gao 0e16a7fe55 fix(frontend): make Suggestion button opaque in dark mode (#2276)
* fix(frontend): make Suggestion button opaque in dark mode

The outline Button variant applies dark:bg-input/30, leaving Suggestion
pills ~70% transparent in dark mode. Scrolled chat content bled through
the buttons, making suggestion text unreadable. Override with
dark:bg-background so it matches the opaque light-mode appearance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix the lint error of commit

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-16 08:55:16 +08:00
Nan Gao 4d3038a7b6 fix(frontend): stop artifact panel from auto-opening on rehydrated write_file (#2278)
After a page refresh, the artifact panel's autoOpen/autoSelect state is
reset to true. Submitting a new question flips thread.isLoading to true,
which message-list passes to every MessageGroup — including historical
ones. The previous response's last write_file step then satisfies the
auto-open condition and re-pops the stale artifact.

Gate the auto-open on the tool call having no result yet, so only a
write_file that is still streaming in the current response can trigger
it; rehydrated tool calls always carry a result and are now skipped.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-16 08:46:47 +08:00
Hinotobi 2176b2bbfc fix: validate bootstrap agent names before filesystem writes (#2274)
* fix: validate bootstrap agent names before filesystem writes

* fix: tighten bootstrap agent-name validation
2026-04-16 08:36:42 +08:00
Wen 8e3591312a test: add unit tests for ViewImageMiddleware (#2256)
* test: add unit tests for ViewImageMiddleware

- Add 33 test cases covering all 7 internal methods plus sync/async
  before_model hooks
- Cover normal path, edge cases (missing keys, empty base64, stale
  ToolMessages before assistant turn), and deduplication logic
- Related to Q2 Roadmap #1669

* test: add unit tests for ViewImageMiddleware

Add 35 test cases covering all internal methods, before_model hooks,
and edge cases (missing attrs, list-content dedup, stale ToolMessages).

Related to #1669
2026-04-15 23:54:30 +08:00
Willem Jiang 242c654075 fix(frontend):lint error of message-list-item.tsx 2026-04-15 23:35:50 +08:00
Willem Jiang 0c21cbf01f fix(frontend): lint error of frontend 2026-04-15 23:27:46 +08:00
Jason 772538ddba fix(frontend): add skills API rewrite rule to prevent HTML fallback (#2241)
Fixes #2203

When NEXT_PUBLIC_BACKEND_BASE_URL is not set, the frontend uses Next.js
rewrites to proxy API calls to the gateway. Skills API routes were missing
from the rewrite config, causing /api/skills to return the SPA HTML instead
of JSON, which produced 'Unexpected token <' errors in the skill settings page.

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-15 23:21:40 +08:00
Jason 35fb3dd65a fix(frontend): resolve /mnt/ links in markdown to artifact API URLs (#2243)
* fix(gateway): forward agent_name and is_bootstrap from context to configurable

The frontend sends agent_name and is_bootstrap via the context field
in run requests, but services.py only forwards a hardcoded whitelist
of keys (_CONTEXT_CONFIGURABLE_KEYS) into the agent's configurable
dict.  Since agent_name was missing, custom agents never received
their name — make_lead_agent always fell back to the default lead
agent, skipping SOUL.md, per-agent config and skill filtering.

Similarly, is_bootstrap was dropped, so the bootstrap creation flow
could never activate the setup_agent tool path.

Add both keys to the whitelist so they reach make_lead_agent.

Fixes #2222

* fix(frontend): resolve /mnt/ links in markdown to artifact API URLs

AI agent messages contain links like /mnt/user-data/outputs/file.pdf
which were rendered as-is in the browser, resulting in 404 errors.
Images already got the correct treatment via MessageImage and
resolveArtifactURL, but anchor tags (<a>) were passed through
unchanged.

Add an 'a' component override in MessageContent_ that rewrites
/mnt/-prefixed hrefs to the artifact API endpoint, matching the
existing image handling pattern.

Fixes #2232

---------

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-15 23:12:21 +08:00
Jason 692f79452d fix(gateway): forward agent_name and is_bootstrap from context to configurable (#2242)
The frontend sends agent_name and is_bootstrap via the context field
in run requests, but services.py only forwards a hardcoded whitelist
of keys (_CONTEXT_CONFIGURABLE_KEYS) into the agent's configurable
dict.  Since agent_name was missing, custom agents never received
their name — make_lead_agent always fell back to the default lead
agent, skipping SOUL.md, per-agent config and skill filtering.

Similarly, is_bootstrap was dropped, so the bootstrap creation flow
could never activate the setup_agent tool path.

Add both keys to the whitelist so they reach make_lead_agent.

Fixes #2222

Co-authored-by: JasonOA888 <JasonOA888@users.noreply.github.com>
2026-04-15 23:11:10 +08:00
DanielWalnut 8760937439 fix(memory): use asyncio.to_thread for blocking file I/O in aupdate_memory (#2220)
* fix(memory): use asyncio.to_thread for blocking file I/O in aupdate_memory

`_finalize_update` performs synchronous blocking operations (os.mkdir,
file open/write/rename/stat) that were called directly from the async
`aupdate_memory` method, causing `BlockingError` from blockbuster when
running under an ASGI server. Wrap the call with `asyncio.to_thread` to
offload all blocking I/O to a thread pool.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): use unique temp filename to prevent concurrent write collision

`file_path.with_suffix(".tmp")` produces a fixed path — concurrent saves
for the same agent (now possible after wrapping _finalize_update in
asyncio.to_thread) would clobber the same temp file. Use a UUID-suffixed
temp file so each write is isolated.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): also offload _prepare_update_prompt to thread pool

FileMemoryStorage.load() inside _prepare_update_prompt performs
synchronous stat() and file read, blocking the event loop just like
_finalize_update did. Wrap _prepare_update_prompt in asyncio.to_thread
for the same reason.

The async path now has no blocking file I/O on the event loop:
  to_thread(_prepare_update_prompt) → await model.ainvoke() → to_thread(_finalize_update)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-14 16:41:54 +08:00
DanielWalnut 4ba3167f48 feat: flush memory before summarization (#2176)
* feat: flush memory before summarization

* fix: keep agent-scoped memory on summarization flush

* fix: harden summarization hook plumbing

* fix: address summarization review feedback

* style: format memory middleware
2026-04-14 15:01:06 +08:00
Octopus e4f896e90d fix(todo-middleware): prevent premature agent exit with incomplete todos (#2135)
* fix(todo-middleware): prevent premature agent exit with incomplete todos

When plan mode is active (is_plan_mode=True), the agent occasionally
exits the loop and outputs a final response while todo items are still
incomplete. This happens because the routing edge only checks for
tool_calls, not todo completion state.

Fixes #2112

Add an after_model override to TodoMiddleware with
@hook_config(can_jump_to=["model"]). When the model produces a
response with no tool calls but there are still incomplete todos, the
middleware injects a todo_completion_reminder HumanMessage and returns
jump_to=model to force another model turn. A cap of 2 reminders
prevents infinite loops when the agent cannot make further progress.

Also adds _completion_reminder_count() helper and 14 new unit tests
covering all edge cases of the new after_model / aafter_model logic.

* Remove unnecessary blank line in test file

* Fix runtime argument annotation in before_model

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: octo-patch <octo-patch@github.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-04-14 11:11:26 +08:00
luo jiyin 07fc25d285 feat: switch memory updater to async LLM calls (#2138)
* docs: mark memory updater async migration as completed

- Update TODO.md to mark the replacement of sync model.invoke()
  with async model.ainvoke() in title_middleware and memory updater
  as completed using [x] format

Addresses #2131

* feat: switch memory updater to async LLM calls

- Add async aupdate_memory() method using await model.ainvoke()
- Convert sync update_memory() to use async wrapper
- Add _run_async_update_sync() for nested loop context handling
- Maintain backward compatibility with existing sync API
- Add ThreadPoolExecutor for async execution from sync contexts

Addresses #2131

* test: add tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns

Addresses #2131

* fix: apply ruff formatting to memory updater

- Format multi-line expressions to single line
- Ensure code style consistency with project standards
- Fix lint issues caught by GitHub Actions

* test: add comprehensive tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns
- Ensure backward compatibility with sync API

* fix: satisfy ruff formatting in memory updater test

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-14 11:10:42 +08:00
Nan Gao 55bc09ac33 fix(backend): fix uploads for mounted sandbox providers (#2199)
* fix uploads for mounted sandbox providers

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-14 10:44:31 +08:00
dependabot[bot] c43a45ea40 chore(deps): bump pillow from 12.1.1 to 12.2.0 in /backend (#2206)
Bumps [pillow](https://github.com/python-pillow/Pillow) from 12.1.1 to 12.2.0.
- [Release notes](https://github.com/python-pillow/Pillow/releases)
- [Changelog](https://github.com/python-pillow/Pillow/blob/main/CHANGES.rst)
- [Commits](https://github.com/python-pillow/Pillow/compare/12.1.1...12.2.0)

---
updated-dependencies:
- dependency-name: pillow
  dependency-version: 12.2.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-14 10:35:59 +08:00
Admire 9cf7153b1d fix(check): windows pnpm version detection in check script (#2189)
* fix: resolve Windows pnpm detection in check script

* style: format check script regression test

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: resolve corepack fallback on windows

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-04-14 10:29:44 +08:00
Octopus c91785dd68 fix(title): strip <think> tags from title model responses and assistant context (#1927)
* fix(title): strip <think> tags from title model responses and assistant context

Reasoning models (e.g. minimax M2.7, DeepSeek-R1) emit <think>...</think>
blocks before their actual output. When such a model is used as the title
model (or as the main agent), the raw thinking content leaked into the thread
title stored in state, so the chat list showed the internal monologue instead
of a meaningful title.

Fixes #1884

- Add `_strip_think_tags()` helper using a regex to remove all <think>...</think> blocks
- Apply it in `_parse_title()` so the title model response is always clean
- Apply it to the assistant message in `_build_title_prompt()` so thinking
  content from the first AI turn is not fed back to the title model
- Add four new unit tests covering: stripping in parse, think-only response,
  assistant prompt stripping, and end-to-end async flow with think tags

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-14 09:51:39 +08:00
sqsge 053e18e1a6 fix(skills): avoid blocking custom skill deletion on readonly history writes (#2197) 2026-04-14 09:00:29 +08:00
Hinotobi a7e7c6d667 fix: disable custom-agent management API by default (#2161)
* fix: disable custom-agent management API by default

* style: format agents API hardening files

* fix: address review feedback for agents API hardening

* fix: add missing disabled API coverage
2026-04-14 00:03:38 +08:00
Nan Gao f4c17c66ce fix(middleware): fix present_files thread id fallback (#2181)
* fix present files thread id fallback

* fix: resolve present_files thread id from runtime config
2026-04-13 22:59:13 +08:00
lesliewangwyc-dev 1df389b9d0 fix: wrap blocking readability call with asyncio.to_thread in web_fetch (#2157)
* fix: wrap blocking readability call with asyncio.to_thread in web_fetch

The readability extractor internally spawns a Node.js subprocess via
readabilipy, which blocks the async event loop and causes a
BlockingError when web_fetch is invoked inside LangGraph's async
runtime.

Wrap the synchronous extract_article call with asyncio.to_thread to
offload it to a thread pool, unblocking the event loop.

Note: community/infoquest/tools.py has the same latent issue and
should be addressed in a follow-up PR.

Closes #2152

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* test: verify web_fetch offloads extraction via asyncio.to_thread

Add a regression test that monkeypatches asyncio.to_thread to confirm
readability extraction is offloaded to a worker thread, preventing
future refactors from reintroducing the blocking call.

Addresses Copilot review feedback on #2157.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-13 21:15:24 +08:00
5db71cb68c fix(middleware): repair dangling tool-call history after loop interru… (#2035)
* fix(middleware): repair dangling tool-call history after loop interruption (#2029)

* docs(backend): fix middleware chain ordering

---------

Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
2026-04-12 19:11:22 +08:00
yangzheli 4efc8d404f feat(frontend): set up Vitest frontend testing infrastructure with CI workflow (#2147)
* feat: set up Vitest frontend testing infrastructure with CI workflow

Migrate existing 4 frontend test files from Node.js native test runner
(node:test + node:assert/strict) to Vitest, reorganize test directory
structure under tests/unit/ mirroring src/ layout, and add a dedicated
CI workflow for frontend unit tests.

- Add vitest as devDependency, remove tsx
- Create vitest.config.ts with @/ path alias
- Migrate tests to Vitest API (test/expect/vi)
- Rename .mjs test files to .ts
- Move tests from src/ to tests/unit/ (mirrors src/ layout)
- Add frontend/Makefile `test` target
- Add .github/workflows/frontend-unit-tests.yml (parallel to backend)
- Update CONTRIBUTING.md, README.md, AGENTS.md, CLAUDE.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: fix the lint error

* style: fix the lint error

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-04-12 18:00:43 +08:00
Jin 4d4ddb3d3f feat(llm): introduce lightweight circuit breaker to prevent rate-limit bans and resource exhaustion (#2095) 2026-04-12 17:48:40 +08:00
180 changed files with 9722 additions and 1432 deletions
+63
View File
@@ -0,0 +1,63 @@
name: E2E Tests
on:
push:
branches: [ 'main' ]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- 'frontend/**'
- '.github/workflows/e2e-tests.yml'
concurrency:
group: e2e-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
e2e-tests:
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }}
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Install Playwright Chromium
working-directory: frontend
run: npx playwright install chromium --with-deps
- name: Run E2E tests
working-directory: frontend
run: pnpm exec playwright test
env:
SKIP_ENV_VALIDATION: '1'
- name: Upload Playwright report
uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report
path: frontend/playwright-report/
retention-days: 7
+43
View File
@@ -0,0 +1,43 @@
name: Frontend Unit Tests
on:
push:
branches: [ 'main' ]
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
concurrency:
group: frontend-unit-tests-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
permissions:
contents: read
jobs:
frontend-unit-tests:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Enable Corepack
run: corepack enable
- name: Use pinned pnpm version
run: corepack prepare pnpm@10.26.2 --activate
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Run unit tests of frontend
working-directory: frontend
run: make test
+3
View File
@@ -40,6 +40,7 @@ coverage/
skills/custom/* skills/custom/*
logs/ logs/
log/ log/
debug.log
# Local git hooks (keep only on this machine, do not push) # Local git hooks (keep only on this machine, do not push)
.githooks/ .githooks/
@@ -55,5 +56,7 @@ web/
backend/Dockerfile.langgraph backend/Dockerfile.langgraph
config.yaml.bak config.yaml.bak
.playwright-mcp .playwright-mcp
/frontend/test-results/
/frontend/playwright-report/
.gstack/ .gstack/
.worktrees .worktrees
+33
View File
@@ -0,0 +1,33 @@
repos:
# Backend: ruff lint + format via uv (uses the same ruff version as backend deps)
- repo: local
hooks:
- id: ruff
name: ruff lint
entry: bash -c 'cd backend && uv run ruff check --fix "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
- id: ruff-format
name: ruff format
entry: bash -c 'cd backend && uv run ruff format "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
# Frontend: eslint + prettier (must run from frontend/ for node_modules resolution)
- repo: local
hooks:
- id: frontend-eslint
name: eslint (frontend)
entry: bash -c 'cd frontend && npx eslint --fix "${@/#frontend\//}"' --
language: system
types_or: [javascript, tsx, ts]
files: ^frontend/
- id: frontend-prettier
name: prettier (frontend)
entry: bash -c 'cd frontend && npx prettier --write "${@/#frontend\//}"' --
language: system
files: ^frontend/
types_or: [javascript, tsx, ts, json, css]
+12 -7
View File
@@ -166,7 +166,7 @@ Required tools:
1. **Configure the application** (same as Docker setup above) 1. **Configure the application** (same as Docker setup above)
2. **Install dependencies**: 2. **Install dependencies** (this also sets up pre-commit hooks):
```bash ```bash
make install make install
``` ```
@@ -298,19 +298,24 @@ Nginx (port 2026) ← Unified entry point
```bash ```bash
# Backend tests # Backend tests
cd backend cd backend
uv run pytest make test
# Frontend checks # Frontend unit tests
cd frontend cd frontend
pnpm check make test
# Frontend E2E tests (requires Chromium; builds and auto-starts the Next.js production server)
cd frontend
make test-e2e
``` ```
### PR Regression Checks ### PR Regression Checks
Every pull request runs the backend regression workflow at [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml), including: Every pull request triggers the following CI workflows:
- `tests/test_provisioner_kubeconfig.py` - **Backend unit tests** — [.github/workflows/backend-unit-tests.yml](.github/workflows/backend-unit-tests.yml)
- `tests/test_docker_sandbox_mode_detection.py` - **Frontend unit tests** — [.github/workflows/frontend-unit-tests.yml](.github/workflows/frontend-unit-tests.yml)
- **Frontend E2E tests** — [.github/workflows/e2e-tests.yml](.github/workflows/e2e-tests.yml) (triggered only when `frontend/` files change)
## Code Style ## Code Style
+4 -2
View File
@@ -23,7 +23,7 @@ help:
@echo " make config - Generate local config files (aborts if config already exists)" @echo " make config - Generate local config files (aborts if config already exists)"
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml" @echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
@echo " make check - Check if all required tools are installed" @echo " make check - Check if all required tools are installed"
@echo " make install - Install all dependencies (frontend + backend)" @echo " make install - Install all dependencies (frontend + backend + pre-commit hooks)"
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)" @echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
@echo " make dev - Start all services in development mode (with hot-reloading)" @echo " make dev - Start all services in development mode (with hot-reloading)"
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)" @echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
@@ -73,6 +73,8 @@ install:
@cd backend && uv sync @cd backend && uv sync
@echo "Installing frontend dependencies..." @echo "Installing frontend dependencies..."
@cd frontend && pnpm install @cd frontend && pnpm install
@echo "Installing pre-commit hooks..."
@$(BACKEND_UV_RUN) --with pre-commit pre-commit install
@echo "✓ All dependencies installed" @echo "✓ All dependencies installed"
@echo "" @echo ""
@echo "==========================================" @echo "=========================================="
@@ -99,7 +101,7 @@ setup-sandbox:
echo ""; \ echo ""; \
if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \ if command -v container >/dev/null 2>&1 && [ "$$(uname)" = "Darwin" ]; then \
echo "Detected Apple Container on macOS, pulling image..."; \ echo "Detected Apple Container on macOS, pulling image..."; \
container pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \ container image pull "$$IMAGE" || echo "⚠ Apple Container pull failed, will try Docker"; \
fi; \ fi; \
if command -v docker >/dev/null 2>&1; then \ if command -v docker >/dev/null 2>&1; then \
echo "Pulling image using Docker..."; \ echo "Pulling image using Docker..."; \
+3 -1
View File
@@ -264,7 +264,7 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
2. **Install dependencies**: 2. **Install dependencies**:
```bash ```bash
make install # Install backend + frontend dependencies make install # Install backend + frontend dependencies + pre-commit hooks
``` ```
3. **(Optional) Pre-pull sandbox image**: 3. **(Optional) Pre-pull sandbox image**:
@@ -658,6 +658,8 @@ This is the difference between a chatbot with tool access and an agent with an a
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window. **Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
### Long-Term Memory ### Long-Term Memory
Most agents forget everything the moment a conversation ends. DeerFlow remembers. Most agents forget everything the moment a conversation ends. DeerFlow remembers.
+16 -10
View File
@@ -156,20 +156,26 @@ from deerflow.config import get_app_config
### Middleware Chain ### Middleware Chain
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`: Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory 1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation 2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state 3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption) 4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider. 5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) 6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) 7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled) 11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last) 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
### Configuration System ### Configuration System
+20 -1
View File
@@ -23,6 +23,16 @@ _CHANNEL_REGISTRY: dict[str, str] = {
"wecom": "app.channels.wecom:WeComChannel", "wecom": "app.channels.wecom:WeComChannel",
} }
# Keys that indicate a user has configured credentials for a channel.
_CHANNEL_CREDENTIAL_KEYS: dict[str, list[str]] = {
"discord": ["bot_token"],
"feishu": ["app_id", "app_secret"],
"slack": ["bot_token", "app_token"],
"telegram": ["bot_token"],
"wecom": ["bot_id", "bot_secret"],
"wechat": ["bot_token"],
}
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL" _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL" _CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
@@ -88,7 +98,16 @@ class ChannelService:
if not isinstance(channel_config, dict): if not isinstance(channel_config, dict):
continue continue
if not channel_config.get("enabled", False): if not channel_config.get("enabled", False):
logger.info("Channel %s is disabled, skipping", name) cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
name,
name,
)
else:
logger.info("Channel %s is disabled, skipping", name)
continue continue
await self._start_channel(name, channel_config) await self._start_channel(name, channel_config)
+20 -2
View File
@@ -16,13 +16,31 @@ logger = logging.getLogger(__name__)
_slack_md_converter = SlackMarkdownConverter() _slack_md_converter = SlackMarkdownConverter()
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
if allowed_users is None:
return set()
if isinstance(allowed_users, str):
values = [allowed_users]
elif isinstance(allowed_users, list | tuple | set):
values = allowed_users
else:
logger.warning(
"Slack allowed_users should be a list of Slack user IDs or a single Slack user ID string; treating %s as one string value",
type(allowed_users).__name__,
)
values = [allowed_users]
return {str(user_id) for user_id in values if str(user_id)}
class SlackChannel(Channel): class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP). """Slack IM channel using Socket Mode (WebSocket, no public IP).
Configuration keys (in ``config.yaml`` under ``channels.slack``): Configuration keys (in ``config.yaml`` under ``channels.slack``):
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...). - ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode. - ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
- ``allowed_users``: (optional) List of allowed Slack user IDs. Empty = allow all. - ``allowed_users``: (optional) List of allowed Slack user IDs, or a
single Slack user ID string as shorthand. Empty = allow all. Other
scalar values are treated as a single string with a warning.
""" """
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None: def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
@@ -30,7 +48,7 @@ class SlackChannel(Channel):
self._socket_client = None self._socket_client = None
self._web_client = None self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])} self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
async def start(self) -> None: async def start(self) -> None:
if self._running: if self._running:
+16 -2
View File
@@ -1,3 +1,4 @@
import asyncio
import logging import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -32,6 +33,11 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Upper bound (seconds) each lifespan shutdown hook is allowed to run.
# Bounds worker exit time so uvicorn's reload supervisor does not keep
# firing signals into a worker that is stuck waiting for shutdown cleanup.
_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5.0
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
@@ -63,11 +69,19 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield yield
# Stop channel service on shutdown # Stop channel service on shutdown (bounded to prevent worker hang)
try: try:
from app.channels.service import stop_channel_service from app.channels.service import stop_channel_service
await stop_channel_service() await asyncio.wait_for(
stop_channel_service(),
timeout=_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
except TimeoutError:
logger.warning(
"Channel service shutdown exceeded %.1fs; proceeding with worker exit.",
_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
except Exception: except Exception:
logger.exception("Failed to stop channel service") logger.exception("Failed to stop channel service")
+42 -4
View File
@@ -8,6 +8,7 @@ import yaml
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
@@ -24,6 +25,7 @@ class AgentResponse(BaseModel):
description: str = Field(default="", description="Agent description") description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override") model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist") tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="SOUL.md content") soul: str | None = Field(default=None, description="SOUL.md content")
@@ -40,6 +42,7 @@ class AgentCreateRequest(BaseModel):
description: str = Field(default="", description="Agent description") description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override") model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist") tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all enabled, []=none)")
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails") soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
@@ -49,6 +52,7 @@ class AgentUpdateRequest(BaseModel):
description: str | None = Field(default=None, description="Updated description") description: str | None = Field(default=None, description="Updated description")
model: str | None = Field(default=None, description="Updated model override") model: str | None = Field(default=None, description="Updated model override")
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist") tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
skills: list[str] | None = Field(default=None, description="Updated skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="Updated SOUL.md content") soul: str | None = Field(default=None, description="Updated SOUL.md content")
@@ -73,6 +77,15 @@ def _normalize_agent_name(name: str) -> str:
return name.lower() return name.lower()
def _require_agents_api_enabled() -> None:
"""Reject access unless the custom-agent management API is explicitly enabled."""
if not get_agents_api_config().enabled:
raise HTTPException(
status_code=403,
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
)
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse: def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse.""" """Convert AgentConfig to AgentResponse."""
soul: str | None = None soul: str | None = None
@@ -84,6 +97,7 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
description=agent_cfg.description, description=agent_cfg.description,
model=agent_cfg.model, model=agent_cfg.model,
tool_groups=agent_cfg.tool_groups, tool_groups=agent_cfg.tool_groups,
skills=agent_cfg.skills,
soul=soul, soul=soul,
) )
@@ -100,6 +114,8 @@ async def list_agents() -> AgentsListResponse:
Returns: Returns:
List of all custom agents with their metadata and soul content. List of all custom agents with their metadata and soul content.
""" """
_require_agents_api_enabled()
try: try:
agents = list_custom_agents() agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents]) return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
@@ -125,6 +141,7 @@ async def check_agent_name(name: str) -> dict:
Raises: Raises:
HTTPException: 422 if the name is invalid. HTTPException: 422 if the name is invalid.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
normalized = _normalize_agent_name(name) normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists() available = not get_paths().agent_dir(normalized).exists()
@@ -149,6 +166,7 @@ async def get_agent(name: str) -> AgentResponse:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
@@ -181,6 +199,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
Raises: Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid. HTTPException: 409 if agent already exists, 422 if name is invalid.
""" """
_require_agents_api_enabled()
_validate_agent_name(request.name) _validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name) normalized_name = _normalize_agent_name(request.name)
@@ -200,6 +219,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
config_data["model"] = request.model config_data["model"] = request.model
if request.tool_groups is not None: if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml" config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f: with open(config_file, "w", encoding="utf-8") as f:
@@ -243,6 +264,7 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
@@ -255,21 +277,32 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
try: try:
# Update config if any config fields changed # Update config if any config fields changed
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups]) # Use model_fields_set to distinguish "field omitted" from "explicitly set to null".
# This is critical for skills where None means "inherit all" (not "don't change").
fields_set = request.model_fields_set
config_changed = bool(fields_set & {"description", "model", "tool_groups", "skills"})
if config_changed: if config_changed:
updated: dict = { updated: dict = {
"name": agent_cfg.name, "name": agent_cfg.name,
"description": request.description if request.description is not None else agent_cfg.description, "description": request.description if "description" in fields_set else agent_cfg.description,
} }
new_model = request.model if request.model is not None else agent_cfg.model new_model = request.model if "model" in fields_set else agent_cfg.model
if new_model is not None: if new_model is not None:
updated["model"] = new_model updated["model"] = new_model
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups new_tool_groups = request.tool_groups if "tool_groups" in fields_set else agent_cfg.tool_groups
if new_tool_groups is not None: if new_tool_groups is not None:
updated["tool_groups"] = new_tool_groups updated["tool_groups"] = new_tool_groups
# skills: None = inherit all, [] = no skills, ["a","b"] = whitelist
if "skills" in fields_set:
new_skills = request.skills
else:
new_skills = agent_cfg.skills
if new_skills is not None:
updated["skills"] = new_skills
config_file = agent_dir / "config.yaml" config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f: with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True) yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
@@ -315,6 +348,8 @@ async def get_user_profile() -> UserProfileResponse:
Returns: Returns:
UserProfileResponse with content=None if USER.md does not exist yet. UserProfileResponse with content=None if USER.md does not exist yet.
""" """
_require_agents_api_enabled()
try: try:
user_md_path = get_paths().user_md_file user_md_path = get_paths().user_md_file
if not user_md_path.exists(): if not user_md_path.exists():
@@ -341,6 +376,8 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
Returns: Returns:
UserProfileResponse with the saved content. UserProfileResponse with the saved content.
""" """
_require_agents_api_enabled()
try: try:
paths = get_paths() paths = get_paths()
paths.base_dir.mkdir(parents=True, exist_ok=True) paths.base_dir.mkdir(parents=True, exist_ok=True)
@@ -367,6 +404,7 @@ async def delete_agent(name: str) -> None:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled()
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
+22 -5
View File
@@ -17,10 +17,17 @@ class ModelResponse(BaseModel):
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort") supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
class TokenUsageResponse(BaseModel):
"""Token usage display configuration."""
enabled: bool = Field(default=False, description="Whether token usage display is enabled")
class ModelsListResponse(BaseModel): class ModelsListResponse(BaseModel):
"""Response model for listing all models.""" """Response model for listing all models."""
models: list[ModelResponse] models: list[ModelResponse]
token_usage: TokenUsageResponse
@router.get( @router.get(
@@ -36,7 +43,7 @@ async def list_models() -> ModelsListResponse:
excluding sensitive fields like API keys and internal configuration. excluding sensitive fields like API keys and internal configuration.
Returns: Returns:
A list of all configured models with their metadata. A list of all configured models with their metadata and token usage display settings.
Example Response: Example Response:
```json ```json
@@ -44,17 +51,24 @@ async def list_models() -> ModelsListResponse:
"models": [ "models": [
{ {
"name": "gpt-4", "name": "gpt-4",
"model": "gpt-4",
"display_name": "GPT-4", "display_name": "GPT-4",
"description": "OpenAI GPT-4 model", "description": "OpenAI GPT-4 model",
"supports_thinking": false "supports_thinking": false,
"supports_reasoning_effort": false
}, },
{ {
"name": "claude-3-opus", "name": "claude-3-opus",
"model": "claude-3-opus",
"display_name": "Claude 3 Opus", "display_name": "Claude 3 Opus",
"description": "Anthropic Claude 3 Opus model", "description": "Anthropic Claude 3 Opus model",
"supports_thinking": true "supports_thinking": true,
"supports_reasoning_effort": false
} }
] ],
"token_usage": {
"enabled": true
}
} }
``` ```
""" """
@@ -70,7 +84,10 @@ async def list_models() -> ModelsListResponse:
) )
for model in config.models for model in config.models
] ]
return ModelsListResponse(models=models) return ModelsListResponse(
models=models,
token_usage=TokenUsageResponse(enabled=config.token_usage.enabled),
)
@router.get( @router.get(
+18 -12
View File
@@ -1,3 +1,4 @@
import errno
import json import json
import logging import logging
import shutil import shutil
@@ -201,18 +202,23 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
ensure_custom_skill_is_editable(skill_name) ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name) skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name) prev_content = read_custom_skill_content(skill_name)
append_history( try:
skill_name, append_history(
{ skill_name,
"action": "human_delete", {
"author": "human", "action": "human_delete",
"thread_id": None, "author": "human",
"file_path": "SKILL.md", "thread_id": None,
"prev_content": prev_content, "file_path": "SKILL.md",
"new_content": None, "prev_content": prev_content,
"scanner": {"decision": "allow", "reason": "Deletion requested."}, "new_content": None,
}, "scanner": {"decision": "allow", "reason": "Deletion requested."},
) },
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
shutil.rmtree(skill_dir) shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async()
return {"success": True} return {"success": True}
+1 -1
View File
@@ -121,7 +121,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
try: try:
model = create_chat_model(name=request.model_name, thinking_enabled=False) model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)]) response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content) raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or [] suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()] cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
+39 -6
View File
@@ -7,8 +7,9 @@ import stat
from fastapi import APIRouter, File, HTTPException, UploadFile from fastapi import APIRouter, File, HTTPException, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from deerflow.config.app_config import get_app_config
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.uploads.manager import ( from deerflow.uploads.manager import (
PathTraversalError, PathTraversalError,
delete_file_safe, delete_file_safe,
@@ -53,6 +54,34 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
os.chmod(file_path, writable_mode, **chmod_kwargs) os.chmod(file_path, writable_mode, **chmod_kwargs)
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
except Exception:
return False
@router.post("", response_model=UploadResponse) @router.post("", response_model=UploadResponse)
async def upload_files( async def upload_files(
thread_id: str, thread_id: str,
@@ -70,8 +99,12 @@ async def upload_files(
uploaded_files = [] uploaded_files = []
sandbox_provider = get_sandbox_provider() sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id) sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = sandbox_provider.get(sandbox_id) sandbox = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
for file in files: for file in files:
if not file.filename: if not file.filename:
@@ -90,7 +123,7 @@ async def upload_files(
virtual_path = upload_virtual_path(safe_filename) virtual_path = upload_virtual_path(safe_filename)
if sandbox_id != "local": if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(file_path) _make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content) sandbox.update_file(virtual_path, content)
@@ -105,12 +138,12 @@ async def upload_files(
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}") logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower() file_ext = file_path.suffix.lower()
if file_ext in CONVERTIBLE_EXTENSIONS: if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path) md_path = await convert_file_to_markdown(file_path)
if md_path: if md_path:
md_virtual_path = upload_virtual_path(md_path.name) md_virtual_path = upload_virtual_path(md_path.name)
if sandbox_id != "local": if sync_to_sandbox and sandbox is not None:
_make_file_sandbox_writable(md_path) _make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes()) sandbox.update_file(md_virtual_path, md_path.read_bytes())
+34 -15
View File
@@ -12,6 +12,7 @@ import json
import logging import logging
import re import re
import time import time
from collections.abc import Mapping
from typing import Any from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -101,9 +102,10 @@ def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config. """Resolve the agent factory callable from config.
Custom agents are implemented as ``lead_agent`` + an ``agent_name`` Custom agents are implemented as ``lead_agent`` + an ``agent_name``
injected into ``configurable`` — see :func:`build_run_config`. All injected into ``configurable`` or ``context`` — see
``assistant_id`` values therefore map to the same factory; the routing :func:`build_run_config`. All ``assistant_id`` values therefore map to the
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``. same factory; the routing happens inside ``make_lead_agent`` when it reads
``cfg["agent_name"]``.
""" """
from deerflow.agents.lead_agent.agent import make_lead_agent from deerflow.agents.lead_agent.agent import make_lead_agent
@@ -120,10 +122,12 @@ def build_run_config(
"""Build a RunnableConfig dict for the agent. """Build a RunnableConfig dict for the agent.
When *assistant_id* refers to a custom agent (anything other than When *assistant_id* refers to a custom agent (anything other than
``"lead_agent"`` / ``None``), the name is forwarded as ``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to whichever runtime options container is active: ``context`` for
load the matching ``agents/<name>/SOUL.md`` and per-agent config — LangGraph >= 0.6.0 requests, otherwise ``configurable``.
without it the agent silently runs as the default lead agent. ``make_lead_agent`` reads this key to load the matching
``agents/<name>/SOUL.md`` and per-agent config — without it the agent
silently runs as the default lead agent.
This mirrors the channel manager's ``_resolve_run_params`` logic so that This mirrors the channel manager's ``_resolve_run_params`` logic so that
the LangGraph Platform-compatible HTTP API and the IM channel path behave the LangGraph Platform-compatible HTTP API and the IM channel path behave
@@ -142,7 +146,14 @@ def build_run_config(
thread_id, thread_id,
list(request_config.get("configurable", {}).keys()), list(request_config.get("configurable", {}).keys()),
) )
config["context"] = request_config["context"] context_value = request_config["context"]
if context_value is None:
context = {}
elif isinstance(context_value, Mapping):
context = dict(context_value)
else:
raise ValueError("request config 'context' must be a mapping or null.")
config["context"] = context
else: else:
configurable = {"thread_id": thread_id} configurable = {"thread_id": thread_id}
configurable.update(request_config.get("configurable", {})) configurable.update(request_config.get("configurable", {}))
@@ -154,13 +165,19 @@ def build_run_config(
config["configurable"] = {"thread_id": thread_id} config["configurable"] = {"thread_id": thread_id}
# Inject custom agent name when the caller specified a non-default assistant. # Inject custom agent name when the caller specified a non-default assistant.
# Honour an explicit configurable["agent_name"] in the request if already set. # Honour an explicit agent_name in the active runtime options container.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config: if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
if "agent_name" not in config["configurable"]: normalized = assistant_id.strip().lower().replace("_", "-")
normalized = assistant_id.strip().lower().replace("_", "-") if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized): raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.") if "configurable" in config:
config["configurable"]["agent_name"] = normalized target = config["configurable"]
elif "context" in config:
target = config["context"]
else:
target = config.setdefault("configurable", {})
if target is not None and "agent_name" not in target:
target["agent_name"] = normalized
if metadata: if metadata:
config.setdefault("metadata", {}).update(metadata) config.setdefault("metadata", {}).update(metadata)
return config return config
@@ -298,6 +315,8 @@ async def start_run(
"is_plan_mode", "is_plan_mode",
"subagent_enabled", "subagent_enabled",
"max_concurrent_subagents", "max_concurrent_subagents",
"agent_name",
"is_bootstrap",
} }
configurable = config.setdefault("configurable", {}) configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS: for key in _CONTEXT_CONFIGURABLE_KEYS:
+78 -13
View File
@@ -19,24 +19,78 @@ import asyncio
import logging import logging
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from deerflow.agents import make_lead_agent try:
from prompt_toolkit import PromptSession
from prompt_toolkit.history import InMemoryHistory
_HAS_PROMPT_TOOLKIT = True
except ImportError:
_HAS_PROMPT_TOOLKIT = False
load_dotenv() load_dotenv()
logging.basicConfig( _LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
level=logging.INFO, _LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
) def _logging_level_from_config(name: str) -> int:
"""Map ``config.yaml`` ``log_level`` string to a ``logging`` level constant."""
mapping = logging.getLevelNamesMapping()
return mapping.get((name or "info").strip().upper(), logging.INFO)
def _setup_logging(log_level: str) -> None:
"""Send application logs to ``debug.log`` at *log_level*; do not print them on the console.
Idempotent: any pre-existing handlers on the root logger (e.g. installed by
``logging.basicConfig`` in transitively imported modules) are removed so the
debug session output only lands in ``debug.log``.
"""
level = _logging_level_from_config(log_level)
root = logging.root
for h in list(root.handlers):
root.removeHandler(h)
h.close()
root.setLevel(level)
file_handler = logging.FileHandler("debug.log", mode="a", encoding="utf-8")
file_handler.setLevel(level)
file_handler.setFormatter(logging.Formatter(_LOG_FMT, datefmt=_LOG_DATEFMT))
root.addHandler(file_handler)
def _update_logging_level(log_level: str) -> None:
"""Update the root logger and existing handlers to *log_level*."""
level = _logging_level_from_config(log_level)
root = logging.root
root.setLevel(level)
for handler in root.handlers:
handler.setLevel(level)
async def main(): async def main():
# Install file logging first so warnings emitted while loading config do not
# leak onto the interactive terminal via Python's lastResort handler.
_setup_logging("info")
from deerflow.config import get_app_config
app_config = get_app_config()
_update_logging_level(app_config.log_level)
# Delay the rest of the deerflow imports until *after* logging is installed
# so that any import-time side effects (e.g. deerflow.agents starts a
# background skill-loader thread on import) emit logs to debug.log instead
# of leaking onto the interactive terminal via Python's lastResort handler.
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.agents import make_lead_agent
from deerflow.mcp import initialize_mcp_tools
# Initialize MCP tools at startup # Initialize MCP tools at startup
try: try:
from deerflow.mcp import initialize_mcp_tools
await initialize_mcp_tools() await initialize_mcp_tools()
except Exception as e: except Exception as e:
print(f"Warning: Failed to initialize MCP tools: {e}") print(f"Warning: Failed to initialize MCP tools: {e}")
@@ -52,16 +106,27 @@ async def main():
} }
} }
runtime = Runtime(context={"thread_id": config["configurable"]["thread_id"]})
config["configurable"]["__pregel_runtime"] = runtime
agent = make_lead_agent(config) agent = make_lead_agent(config)
session = PromptSession(history=InMemoryHistory()) if _HAS_PROMPT_TOOLKIT else None
print("=" * 50) print("=" * 50)
print("Lead Agent Debug Mode") print("Lead Agent Debug Mode")
print("Type 'quit' or 'exit' to stop") print("Type 'quit' or 'exit' to stop")
print(f"Logs: debug.log (log_level={app_config.log_level})")
if not _HAS_PROMPT_TOOLKIT:
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
print("=" * 50) print("=" * 50)
while True: while True:
try: try:
user_input = input("\nYou: ").strip() if session:
user_input = (await session.prompt_async("\nYou: ")).strip()
else:
user_input = input("\nYou: ").strip()
if not user_input: if not user_input:
continue continue
if user_input.lower() in ("quit", "exit"): if user_input.lower() in ("quit", "exit"):
@@ -70,15 +135,15 @@ async def main():
# Invoke the agent # Invoke the agent
state = {"messages": [HumanMessage(content=user_input)]} state = {"messages": [HumanMessage(content=user_input)]}
result = await agent.ainvoke(state, config=config, context={"thread_id": "debug-thread-001"}) result = await agent.ainvoke(state, config=config)
# Print the response # Print the response
if result.get("messages"): if result.get("messages"):
last_message = result["messages"][-1] last_message = result["messages"][-1]
print(f"\nAgent: {last_message.content}") print(f"\nAgent: {last_message.content}")
except KeyboardInterrupt: except (KeyboardInterrupt, EOFError):
print("\nInterrupted. Goodbye!") print("\nGoodbye!")
break break
except Exception as e: except Exception as e:
print(f"\nError: {e}") print(f"\nError: {e}")
+1 -1
View File
@@ -199,7 +199,7 @@ class ThreadState(AgentState):
│ Built-in Tools │ │ Configured Tools │ │ MCP Tools │ │ Built-in Tools │ │ Configured Tools │ │ MCP Tools │
│ (packages/harness/deerflow/tools/) │ │ (config.yaml) │ │ (extensions.json) │ │ (packages/harness/deerflow/tools/) │ │ (config.yaml) │ │ (extensions.json) │
├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤
│ - present_file │ │ - web_search │ │ - github │ │ - present_files │ │ - web_search │ │ - github │
│ - ask_clarification │ │ - web_fetch │ │ - filesystem │ │ - ask_clarification │ │ - web_fetch │ │ - filesystem │
│ - view_image │ │ - bash │ │ - postgres │ │ - view_image │ │ - bash │ │ - postgres │
│ │ │ - read_file │ │ - brave-search │ │ │ │ - read_file │ │ - brave-search │
+6 -3
View File
@@ -2,12 +2,12 @@
## 概述 ## 概述
DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并自动将 Office 文档和 PDF 转换为 Markdown 格式。 DeerFlow 后端提供了完整的文件上传功能,支持多文件上传,并可选地将 Office 文档和 PDF 转换为 Markdown 格式。
## 功能特性 ## 功能特性
- ✅ 支持多文件同时上传 - ✅ 支持多文件同时上传
-自动转换文档为 MarkdownPDF、PPT、Excel、Word -可选地转换文档为 MarkdownPDF、PPT、Excel、Word
- ✅ 文件存储在线程隔离的目录中 - ✅ 文件存储在线程隔离的目录中
- ✅ Agent 自动感知已上传的文件 - ✅ Agent 自动感知已上传的文件
- ✅ 支持文件列表查询和删除 - ✅ 支持文件列表查询和删除
@@ -86,7 +86,7 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
## 支持的文档格式 ## 支持的文档格式
以下格式会自动转换为 Markdown: 以下格式在显式启用 `uploads.auto_convert_documents: true`会自动转换为 Markdown
- PDF (`.pdf`) - PDF (`.pdf`)
- PowerPoint (`.ppt`, `.pptx`) - PowerPoint (`.ppt`, `.pptx`)
- Excel (`.xls`, `.xlsx`) - Excel (`.xls`, `.xlsx`)
@@ -94,6 +94,8 @@ DELETE /api/threads/{thread_id}/uploads/{filename}
转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。 转换后的 Markdown 文件会保存在同一目录下,文件名为原文件名 + `.md` 扩展名。
默认情况下,自动转换是关闭的,以避免在网关主机上对不受信任的 Office/PDF 上传执行解析。只有在受信任部署中明确接受此风险时,才应将 `uploads.auto_convert_documents` 设置为 `true`
## Agent 集成 ## Agent 集成
### 自动文件列举 ### 自动文件列举
@@ -207,6 +209,7 @@ backend/.deer-flow/threads/
- 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size` - 最大文件大小:100MB(可在 nginx.conf 中配置 `client_max_body_size`
- 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击 - 文件名安全性:系统会自动验证文件路径,防止目录遍历攻击
- 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问 - 线程隔离:每个线程的上传文件相互隔离,无法跨线程访问
- 自动文档转换默认关闭;如需启用,需在 `config.yaml` 中显式设置 `uploads.auto_convert_documents: true`
## 技术实现 ## 技术实现
+1 -1
View File
@@ -296,7 +296,7 @@ These are the tool names your provider will see in `request.tool_name`:
| `web_search` | Web search query | | `web_search` | Web search query |
| `web_fetch` | Fetch URL content | | `web_fetch` | Fetch URL content |
| `image_search` | Image search | | `image_search` | Image search |
| `present_file` | Present file to user | | `present_files` | Present file to user |
| `view_image` | Display image | | `view_image` | Display image |
| `ask_clarification` | Ask user a question | | `ask_clarification` | Ask user a question |
| `task` | Delegate to subagent | | `task` | Delegate to subagent |
+35
View File
@@ -45,6 +45,41 @@ Example:
} }
``` ```
## Custom Tool Interceptors
You can register custom interceptors that run before every MCP tool call. This is useful for injecting per-request headers (e.g., user auth tokens from the LangGraph execution context), logging, or metrics.
Declare interceptors in `extensions_config.json` using the `mcpInterceptors` field:
```json
{
"mcpInterceptors": [
"my_package.mcp.auth:build_auth_interceptor"
],
"mcpServers": { ... }
}
```
Each entry is a Python import path in `module:variable` format (resolved via `resolve_variable`). The variable must be a **no-arg builder function** that returns an async interceptor compatible with `MultiServerMCPClient`s `tool_interceptors` interface, or `None` to skip.
Example interceptor that injects auth headers from LangGraph metadata:
```python
def build_auth_interceptor():
async def interceptor(request, handler):
from langgraph.config import get_config
metadata = get_config().get("metadata", {})
headers = dict(request.headers or {})
if token := metadata.get("auth_token"):
headers["X-Auth-Token"] = token
return await handler(request.override(headers=headers))
return interceptor
```
- A single string value is accepted and normalized to a one-element list.
- Invalid paths or builder failures are logged as warnings without blocking other interceptors.
- The builder return value must be `callable`; non-callable values are skipped with a warning.
## How It Works ## How It Works
MCP servers expose tools that are automatically discovered and integrated into DeerFlows agent system at runtime. Once enabled, these tools become available to agents without additional code changes. MCP servers expose tools that are automatically discovered and integrated into DeerFlows agent system at runtime. Once enabled, these tools become available to agents without additional code changes.
+1 -1
View File
@@ -24,7 +24,7 @@
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario) - [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py` - [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search) - Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater - [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O - Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker) - For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
+28
View File
@@ -41,6 +41,13 @@ summarization:
# Custom summary prompt (optional) # Custom summary prompt (optional)
summary_prompt: null summary_prompt: null
# Tool names treated as skill file reads for skill rescue
skill_file_read_tool_names:
- read_file
- read
- view
- cat
``` ```
### Configuration Options ### Configuration Options
@@ -125,6 +132,26 @@ keep:
- **Default**: `null` (uses LangChain's default prompt) - **Default**: `null` (uses LangChain's default prompt)
- **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context. - **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context.
#### `preserve_recent_skill_count`
- **Type**: Integer (≥ 0)
- **Default**: `5`
- **Description**: Number of most-recently-loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`, e.g. `/mnt/skills/...`) that are rescued from summarization. Prevents the agent from losing skill instructions after compression. Set to `0` to disable skill rescue entirely.
#### `preserve_recent_skill_tokens`
- **Type**: Integer (≥ 0)
- **Default**: `25000`
- **Description**: Total token budget reserved for rescued skill reads. Once this budget is exhausted, older skill bundles are allowed to be summarized.
#### `preserve_recent_skill_tokens_per_skill`
- **Type**: Integer (≥ 0)
- **Default**: `5000`
- **Description**: Per-skill token cap. Any individual skill read whose tool result exceeds this size is not rescued (it falls through to the summarizer like ordinary content).
#### `skill_file_read_tool_names`
- **Type**: List of strings
- **Default**: `["read_file", "read", "view", "cat"]`
- **Description**: Tool names treated as skill file reads during summarization rescue. A tool call is only eligible for skill rescue when its name appears in this list and its target path is under `skills.container_path`.
**Default Prompt Behavior:** **Default Prompt Behavior:**
The default LangChain prompt instructs the model to: The default LangChain prompt instructs the model to:
- Extract highest quality/most relevant context - Extract highest quality/most relevant context
@@ -147,6 +174,7 @@ The default LangChain prompt instructs the model to:
- A single summary message is added - A single summary message is added
- Recent messages are preserved - Recent messages are preserved
6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together 6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together
7. **Skill Rescue**: Before the summary is generated, the most recently loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`) are lifted out of the summarization set and prepended to the preserved tail. Selection walks newest-first under three budgets: `preserve_recent_skill_count`, `preserve_recent_skill_tokens`, and `preserve_recent_skill_tokens_per_skill`. The triggering AIMessage and all of its paired ToolMessages move together so tool_call ↔ tool_result pairing stays intact.
### Token Counting ### Token Counting
@@ -27,7 +27,7 @@ from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -67,6 +67,7 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
raise ImportError(SQLITE_INSTALL) from exc raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver: with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup() saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str) logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
@@ -1,28 +1,40 @@
import logging import logging
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _get_runtime_config(config: RunnableConfig) -> dict:
"""Merge legacy configurable options with LangGraph runtime context."""
cfg = dict(config.get("configurable", {}) or {})
context = config.get("context", {}) or {}
if isinstance(context, dict):
cfg.update(context)
return cfg
def _resolve_model_name(requested_model_name: str | None = None) -> str: def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured.""" """Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config() app_config = get_app_config()
@@ -38,7 +50,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name return default_model_name
def _create_summarization_middleware() -> SummarizationMiddleware | None: def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config.""" """Create and configure the summarization middleware from config."""
config = get_summarization_config() config = get_summarization_config()
@@ -77,7 +89,28 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
if config.summary_prompt is not None: if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt kwargs["summary_prompt"] = config.summary_prompt
return SummarizationMiddleware(**kwargs) hooks: list[BeforeSummarizationHook] = []
if get_memory_config().enabled:
hooks.append(memory_flush_hook)
# The logic below relies on two assumptions holding true: this factory is
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup.
try:
skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
except Exception:
logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills"
return DeerFlowSummarizationMiddleware(
**kwargs,
skills_container_path=skills_container_path,
skill_file_read_tool_names=config.skill_file_read_tool_names,
before_summarization=hooks,
preserve_recent_skill_count=config.preserve_recent_skill_count,
preserve_recent_skill_tokens=config.preserve_recent_skill_tokens,
preserve_recent_skill_tokens_per_skill=config.preserve_recent_skill_tokens_per_skill,
)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None: def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
@@ -224,7 +257,8 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(summarization_middleware) middlewares.append(summarization_middleware)
# Add TodoList middleware if plan mode is enabled # Add TodoList middleware if plan mode is enabled
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False) cfg = _get_runtime_config(config)
is_plan_mode = cfg.get("is_plan_mode", False)
todo_list_middleware = _create_todo_list_middleware(is_plan_mode) todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
if todo_list_middleware is not None: if todo_list_middleware is not None:
middlewares.append(todo_list_middleware) middlewares.append(todo_list_middleware)
@@ -253,9 +287,9 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(DeferredToolFilterMiddleware()) middlewares.append(DeferredToolFilterMiddleware())
# Add SubagentLimitMiddleware to truncate excess parallel task calls # Add SubagentLimitMiddleware to truncate excess parallel task calls
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False) subagent_enabled = cfg.get("subagent_enabled", False)
if subagent_enabled: if subagent_enabled:
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3) max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents)) middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops # LoopDetectionMiddleware — detect and break repetitive tool call loops
@@ -275,7 +309,7 @@ def make_lead_agent(config: RunnableConfig):
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent from deerflow.tools.builtins import setup_agent
cfg = config.get("configurable", {}) cfg = _get_runtime_config(config)
thinking_enabled = cfg.get("thinking_enabled", True) thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None) reasoning_effort = cfg.get("reasoning_effort", None)
@@ -284,7 +318,7 @@ def make_lead_agent(config: RunnableConfig):
subagent_enabled = cfg.get("subagent_enabled", False) subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
is_bootstrap = cfg.get("is_bootstrap", False) is_bootstrap = cfg.get("is_bootstrap", False)
agent_name = cfg.get("agent_name") agent_name = validate_agent_name(cfg.get("agent_name"))
agent_config = load_agent_config(agent_name) if not is_bootstrap else None agent_config = load_agent_config(agent_name) if not is_bootstrap else None
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default # Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
@@ -325,6 +359,8 @@ def make_lead_agent(config: RunnableConfig):
"reasoning_effort": reasoning_effort, "reasoning_effort": reasoning_effort,
"is_plan_mode": is_plan_mode, "is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled, "subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
} }
) )
@@ -164,6 +164,36 @@ Skip simple one-off tasks.
""" """
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
"""Dynamically build subagent type descriptions from registry.
Mirrors Codex's pattern where agent_type_description is dynamically generated
from all registered roles, so the LLM knows about every available type.
"""
# Built-in descriptions (kept for backward compatibility with existing prompt quality)
builtin_descriptions = {
"general-purpose": "For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.",
"bash": (
"For command execution (git, build, test, deploy operations)" if bash_available else "Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
),
}
# Lazy import moved outside loop to avoid repeated import overhead
from deerflow.subagents.registry import get_subagent_config
lines = []
for name in available_names:
if name in builtin_descriptions:
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
else:
config = get_subagent_config(name)
if config is not None:
desc = config.description.split("\n")[0].strip() # First line only for brevity
lines.append(f"- **{name}**: {desc}")
return "\n".join(lines)
def _build_subagent_section(max_concurrent: int) -> str: def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit. """Build the subagent system prompt section with dynamic concurrency limit.
@@ -174,13 +204,12 @@ def _build_subagent_section(max_concurrent: int) -> str:
Formatted subagent section string. Formatted subagent section string.
""" """
n = max_concurrent n = max_concurrent
bash_available = "bash" in get_available_subagent_names() available_names = get_available_subagent_names()
available_subagents = ( bash_available = "bash" in available_names
"- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
if bash_available # Dynamically build subagent type descriptions from registry (aligned with Codex's
else "- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n" # agent_type_description pattern where all registered roles are listed in the tool spec).
"- **bash**: Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access." available_subagents = _build_available_subagents_description(available_names, bash_available)
)
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc." direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
direct_execution_example = ( direct_execution_example = (
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()' '# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
@@ -420,7 +449,7 @@ You: "Deploying to staging..." [proceed]
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks - Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md` - When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough - Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool - Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_files` tool
{acp_section} {acp_section}
</working_directory> </working_directory>
@@ -648,7 +677,7 @@ def _build_acp_section() -> str:
"- ACP agents (e.g. codex, claude_code) run in their own independent workspace — NOT in `/mnt/user-data/`\n" "- ACP agents (e.g. codex, claude_code) run in their own independent workspace — NOT in `/mnt/user-data/`\n"
"- When writing prompts for ACP agents, describe the task only — do NOT reference `/mnt/user-data` paths\n" "- When writing prompts for ACP agents, describe the task only — do NOT reference `/mnt/user-data` paths\n"
"- ACP agent results are accessible at `/mnt/acp-workspace/` (read-only) — use `ls`, `read_file`, or `bash cp` to retrieve output files\n" "- ACP agent results are accessible at `/mnt/acp-workspace/` (read-only) — use `ls`, `read_file`, or `bash cp` to retrieve output files\n"
"- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_file`" "- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_files`"
) )
@@ -0,0 +1,109 @@
"""Shared helpers for turning conversations into memory update inputs."""
from __future__ import annotations
import re
from copy import copy
from typing import Any
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
def extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Keep only user inputs and final assistant responses for memory updates."""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = extract_message_text(msg)
if "<uploaded_files>" in content_str:
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
skip_next_ai = True
continue
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns."""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = extract_message_text(msg).strip()
if content and any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
@@ -61,48 +61,88 @@ class MemoryUpdateQueue:
return return
with self._lock: with self._lock:
existing_context = next( self._enqueue_locked(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id, thread_id=thread_id,
messages=messages, messages=messages,
agent_name=agent_name, agent_name=agent_name,
correction_detected=merged_correction_detected, correction_detected=correction_detected,
reinforcement_detected=merged_reinforcement_detected, reinforcement_detected=reinforcement_detected,
) )
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
# Reset or start the debounce timer
self._reset_timer() self._reset_timer()
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue)) logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
def add_nowait(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation and start processing immediately in the background."""
config = get_memory_config()
if not config.enabled:
return
with self._lock:
self._enqueue_locked(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
self._schedule_timer(0)
logger.info("Memory update queued for immediate processing on thread %s, queue size: %d", thread_id, len(self._queue))
def _enqueue_locked(
self,
*,
thread_id: str,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
def _reset_timer(self) -> None: def _reset_timer(self) -> None:
"""Reset the debounce timer.""" """Reset the debounce timer."""
config = get_memory_config() config = get_memory_config()
self._schedule_timer(config.debounce_seconds)
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _schedule_timer(self, delay_seconds: float) -> None:
"""Schedule queue processing after the provided delay."""
# Cancel existing timer if any # Cancel existing timer if any
if self._timer is not None: if self._timer is not None:
self._timer.cancel() self._timer.cancel()
# Start new timer
self._timer = threading.Timer( self._timer = threading.Timer(
config.debounce_seconds, delay_seconds,
self._process_queue, self._process_queue,
) )
self._timer.daemon = True self._timer.daemon = True
self._timer.start() self._timer.start()
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _process_queue(self) -> None: def _process_queue(self) -> None:
"""Process all queued conversation contexts.""" """Process all queued conversation contexts."""
# Import here to avoid circular dependency # Import here to avoid circular dependency
@@ -110,8 +150,8 @@ class MemoryUpdateQueue:
with self._lock: with self._lock:
if self._processing: if self._processing:
# Already processing, reschedule # Preserve immediate flush semantics even if another worker is active.
self._reset_timer() self._schedule_timer(0)
return return
if not self._queue: if not self._queue:
@@ -164,6 +204,13 @@ class MemoryUpdateQueue:
self._process_queue() self._process_queue()
def flush_nowait(self) -> None:
"""Start queue processing immediately in a background thread."""
with self._lock:
# Daemon thread: queued messages may be lost if the process exits
# before _process_queue completes. Acceptable for best-effort memory updates.
self._schedule_timer(0)
def clear(self) -> None: def clear(self) -> None:
"""Clear the queue without processing. """Clear the queue without processing.
@@ -4,6 +4,7 @@ import abc
import json import json
import logging import logging
import threading import threading
import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -66,6 +67,8 @@ class FileMemoryStorage(MemoryStorage):
# Per-agent memory cache: keyed by agent_name (None = global) # Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime) # Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {} self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
def _validate_agent_name(self, agent_name: str) -> None: def _validate_agent_name(self, agent_name: str) -> None:
"""Validate that the agent name is safe to use in filesystem paths. """Validate that the agent name is safe to use in filesystem paths.
@@ -114,14 +117,17 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
current_mtime = None current_mtime = None
cached = self._memory_cache.get(agent_name) with self._cache_lock:
cached = self._memory_cache.get(agent_name)
if cached is not None and cached[1] == current_mtime:
return cached[0]
if cached is None or cached[1] != current_mtime: memory_data = self._load_memory_from_file(agent_name)
memory_data = self._load_memory_from_file(agent_name)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, current_mtime) self._memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data
return cached[0] return memory_data
def reload(self, agent_name: str | None = None) -> dict[str, Any]: def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation.""" """Reload memory data from file, forcing cache invalidation."""
@@ -133,7 +139,8 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
self._memory_cache[agent_name] = (memory_data, mtime) with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
return memory_data return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool: def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
@@ -142,9 +149,12 @@ class FileMemoryStorage(MemoryStorage):
try: try:
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
memory_data["lastUpdated"] = utc_now_iso_z() # Shallow-copy before adding lastUpdated so the caller's dict is not
# mutated as a side-effect, and the cache reference is not silently
# updated before the file write succeeds.
memory_data = {**memory_data, "lastUpdated": utc_now_iso_z()}
temp_path = file_path.with_suffix(".tmp") temp_path = file_path.with_suffix(f".{uuid.uuid4().hex}.tmp")
with open(temp_path, "w", encoding="utf-8") as f: with open(temp_path, "w", encoding="utf-8") as f:
json.dump(memory_data, f, indent=2, ensure_ascii=False) json.dump(memory_data, f, indent=2, ensure_ascii=False)
@@ -155,7 +165,8 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
self._memory_cache[agent_name] = (memory_data, mtime) with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path) logger.info("Memory saved to %s", file_path)
return True return True
except OSError as e: except OSError as e:
@@ -0,0 +1,31 @@
"""Hooks fired before summarization removes messages from state."""
from __future__ import annotations
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
def memory_flush_hook(event: SummarizationEvent) -> None:
"""Flush messages about to be summarized into the memory queue."""
if not get_memory_config().enabled or not event.thread_id:
return
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
user_messages = [message for message in filtered_messages if getattr(message, "type", None) == "human"]
assistant_messages = [message for message in filtered_messages if getattr(message, "type", None) == "ai"]
if not user_messages or not assistant_messages:
return
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
agent_name=event.agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -1,10 +1,15 @@
"""Memory updater for reading, writing, and updating memory data.""" """Memory updater for reading, writing, and updating memory data."""
import asyncio
import atexit
import concurrent.futures
import copy
import json import json
import logging import logging
import math import math
import re import re
import uuid import uuid
from collections.abc import Awaitable
from typing import Any from typing import Any
from deerflow.agents.memory.prompt import ( from deerflow.agents.memory.prompt import (
@@ -21,6 +26,12 @@ from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="memory-updater-sync",
)
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
def _create_empty_memory() -> dict[str, Any]: def _create_empty_memory() -> dict[str, Any]:
"""Backward-compatible wrapper around the storage-layer empty-memory factory.""" """Backward-compatible wrapper around the storage-layer empty-memory factory."""
@@ -206,6 +217,39 @@ def _extract_text(content: Any) -> str:
return str(content) return str(content)
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
"""Run an async memory update from sync code, including nested-loop contexts."""
handed_off = False
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
handed_off = True
return future.result()
handed_off = True
return asyncio.run(coro)
except Exception:
if not handed_off:
close = getattr(coro, "close", None)
if callable(close):
try:
close()
except Exception:
logger.debug(
"Failed to close un-awaited memory update coroutine",
exc_info=True,
)
logger.exception("Failed to run async memory update from sync context")
return False
# Matches sentences that describe a file-upload *event* rather than general # Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts # file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export". # such as "User works with CSV files" or "prefers PDF export".
@@ -269,6 +313,117 @@ class MemoryUpdater:
model_name = self._model_name or config.model_name model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False) return create_chat_model(name=model_name, thinking_enabled=False)
def _build_correction_hint(
self,
correction_detected: bool,
reinforcement_detected: bool,
) -> str:
"""Build optional prompt hints for correction and reinforcement signals."""
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
return correction_hint
def _prepare_update_prompt(
self,
messages: list[Any],
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
correction_hint = self._build_correction_hint(
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
return current_memory, prompt
def _finalize_update(
self,
current_memory: dict[str, Any],
response_content: Any,
thread_id: str | None,
agent_name: str | None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Deep-copy before in-place mutation so a subsequent save() failure
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
async def aupdate_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
try:
prepared = await asyncio.to_thread(
self._prepare_update_prompt,
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
if prepared is None:
return False
current_memory, prompt = prepared
model = self._get_model()
response = await model.ainvoke(prompt, config={"run_name": "memory_agent"})
return await asyncio.to_thread(
self._finalize_update,
current_memory=current_memory,
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def update_memory( def update_memory(
self, self,
messages: list[Any], messages: list[Any],
@@ -277,7 +432,7 @@ class MemoryUpdater:
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
) -> bool: ) -> bool:
"""Update memory based on conversation messages. """Synchronously update memory via the async updater path.
Args: Args:
messages: List of conversation messages. messages: List of conversation messages.
@@ -289,78 +444,15 @@ class MemoryUpdater:
Returns: Returns:
True if update was successful, False otherwise. True if update was successful, False otherwise.
""" """
config = get_memory_config() return _run_async_update_sync(
if not config.enabled: self.aupdate_memory(
return False messages=messages,
thread_id=thread_id,
if not messages: agent_name=agent_name,
return False correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
try:
# Get current memory
current_memory = get_memory_data(agent_name)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
) )
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def _apply_updates( def _apply_updates(
self, self,
@@ -3,6 +3,7 @@
import json import json
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from hashlib import sha256
from typing import override from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
@@ -36,6 +37,13 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
state_schema = ClarificationMiddlewareState state_schema = ClarificationMiddlewareState
def _stable_message_id(self, tool_call_id: str, formatted_message: str) -> str:
"""Build a deterministic message ID so retried clarification calls replace, not append."""
if tool_call_id:
return f"clarification:{tool_call_id}"
digest = sha256(formatted_message.encode("utf-8")).hexdigest()[:16]
return f"clarification:{digest}"
def _is_chinese(self, text: str) -> bool: def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters. """Check if text contains Chinese characters.
@@ -131,6 +139,7 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
# Create a ToolMessage with the formatted question # Create a ToolMessage with the formatted question
# This will be added to the message history # This will be added to the message history
tool_message = ToolMessage( tool_message = ToolMessage(
id=self._stable_message_id(tool_call_id, formatted_message),
content=formatted_message, content=formatted_message,
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
name="ask_clarification", name="ask_clarification",
@@ -13,6 +13,7 @@ at the correct positions (immediately after each dangling AIMessage), not append
to the end of the message list as before_model + add_messages reducer would do. to the end of the message list as before_model + add_messages reducer would do.
""" """
import json
import logging import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import override from typing import override
@@ -33,6 +34,44 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
offending AIMessage so the LLM receives a well-formed conversation. offending AIMessage so the LLM receives a well-formed conversation.
""" """
@staticmethod
def _message_tool_calls(msg) -> list[dict]:
"""Return normalized tool calls from structured fields or raw provider payloads."""
tool_calls = getattr(msg, "tool_calls", None) or []
if tool_calls:
return list(tool_calls)
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
normalized: list[dict] = []
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
}
)
return normalized
def _build_patched_messages(self, messages: list) -> list | None: def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions. """Return a new message list with patches inserted at the correct positions.
@@ -51,7 +90,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for msg in messages: for msg in messages:
if getattr(msg, "type", None) != "ai": if getattr(msg, "type", None) != "ai":
continue continue
for tc in getattr(msg, "tool_calls", None) or []: for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids: if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True needs_patch = True
@@ -70,7 +109,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
patched.append(msg) patched.append(msg)
if getattr(msg, "type", None) != "ai": if getattr(msg, "type", None) != "ai":
continue continue
for tc in getattr(msg, "tool_calls", None) or []: for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append( patched.append(
@@ -16,6 +16,9 @@ from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,7 +38,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
if not registry: if not registry:
return request return request
deferred_names = {e.name for e in registry.entries} deferred_names = registry.deferred_names
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names] active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools): if len(active_tools) < len(request.tools):
@@ -43,6 +46,28 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
return request.override(tools=active_tools) return request.override(tools=active_tools)
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return None
tool_name = str(request.tool_call.get("name") or "")
if not tool_name:
return None
if not registry.contains(tool_name):
return None
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
return ToolMessage(
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override @override
def wrap_model_call( def wrap_model_call(
self, self,
@@ -51,6 +76,17 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
) -> ModelCallResult: ) -> ModelCallResult:
return handler(self._filter_tools(request)) return handler(self._filter_tools(request))
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return handler(request)
@override @override
async def awrap_model_call( async def awrap_model_call(
self, self,
@@ -58,3 +94,14 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult: ) -> ModelCallResult:
return await handler(self._filter_tools(request)) return await handler(self._filter_tools(request))
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return await handler(request)
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import threading
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
@@ -19,6 +20,8 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504} _RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
@@ -67,6 +70,80 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000 retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000 retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Load Circuit Breaker configs from app config if available, fall back to defaults
try:
app_config = get_app_config()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
# Gracefully fall back to class defaults in test environments
pass
# Circuit Breaker state
self._circuit_lock = threading.Lock()
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _check_circuit(self) -> bool:
"""Returns True if circuit is OPEN (fast fail), False otherwise."""
with self._circuit_lock:
now = time.time()
if self._circuit_state == "open":
if now < self._circuit_open_until:
return True
self._circuit_state = "half_open"
self._circuit_probe_in_flight = False
if self._circuit_state == "half_open":
if self._circuit_probe_in_flight:
return True
self._circuit_probe_in_flight = True
return False
return False
def _record_success(self) -> None:
with self._circuit_lock:
if self._circuit_state != "closed" or self._circuit_failure_count > 0:
logger.info("Circuit breaker reset (Closed). LLM service recovered.")
self._circuit_failure_count = 0
self._circuit_open_until = 0.0
self._circuit_state = "closed"
self._circuit_probe_in_flight = False
def _record_failure(self) -> None:
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker probe failed (Open). Will probe again after %ds.",
self.circuit_recovery_timeout_sec,
)
return
self._circuit_failure_count += 1
if self._circuit_failure_count >= self.circuit_failure_threshold:
self._circuit_open_until = time.time() + self.circuit_recovery_timeout_sec
if self._circuit_state != "open":
self._circuit_state = "open"
self._circuit_probe_in_flight = False
logger.error(
"Circuit breaker tripped (Open). Threshold reached (%d). Will probe after %ds.",
self.circuit_failure_threshold,
self.circuit_recovery_timeout_sec,
)
def _classify_error(self, exc: BaseException) -> tuple[bool, str]: def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc) detail = _extract_error_detail(exc)
lowered = detail.lower() lowered = detail.lower()
@@ -83,6 +160,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"APITimeoutError", "APITimeoutError",
"APIConnectionError", "APIConnectionError",
"InternalServerError", "InternalServerError",
"ReadError", # httpx.ReadError: connection dropped mid-stream
"RemoteProtocolError", # httpx: server closed connection unexpectedly
}: }:
return True, "transient" return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES: if status_code in _RETRIABLE_STATUS_CODES:
@@ -104,6 +183,9 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily" reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s." return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_circuit_breaker_message(self) -> str:
return "The configured LLM provider is currently unavailable due to continuous failures. Circuit breaker is engaged to protect the system. Please wait a moment before trying again."
def _build_user_message(self, exc: BaseException, reason: str) -> str: def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc) detail = _extract_error_detail(exc)
if reason == "quota": if reason == "quota":
@@ -138,12 +220,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse], handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult: ) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1 attempt = 1
while True: while True:
try: try:
return handler(request) response = handler(request)
self._record_success()
return response
except GraphBubbleUp: except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume). # Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise raise
except Exception as exc: except Exception as exc:
retriable, reason = self._classify_error(exc) retriable, reason = self._classify_error(exc)
@@ -166,6 +256,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc), _extract_error_detail(exc),
exc_info=exc, exc_info=exc,
) )
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason)) return AIMessage(content=self._build_user_message(exc, reason))
@override @override
@@ -174,12 +266,20 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult: ) -> ModelCallResult:
if self._check_circuit():
return AIMessage(content=self._build_circuit_breaker_message())
attempt = 1 attempt = 1
while True: while True:
try: try:
return await handler(request) response = await handler(request)
self._record_success()
return response
except GraphBubbleUp: except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume). # Preserve LangGraph control-flow signals (interrupt/pause/resume).
with self._circuit_lock:
if self._circuit_state == "half_open":
self._circuit_probe_in_flight = False
raise raise
except Exception as exc: except Exception as exc:
retriable, reason = self._classify_error(exc) retriable, reason = self._classify_error(exc)
@@ -202,6 +302,8 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
_extract_error_detail(exc), _extract_error_detail(exc),
exc_info=exc, exc_info=exc,
) )
if retriable:
self._record_failure()
return AIMessage(content=self._build_user_message(exc, reason)) return AIMessage(content=self._build_user_message(exc, reason))
@@ -17,6 +17,7 @@ import json
import logging import logging
import threading import threading
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import override from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
@@ -24,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor # Defaults — can be overridden via constructor
@@ -182,10 +185,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
def _get_thread_id(self, runtime: Runtime) -> str: def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking.""" """Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None return get_thread_id(runtime) or "default"
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None: def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit. """Evict least recently used threads if over the limit.
@@ -323,6 +323,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Fallback: coerce unexpected types to str to avoid TypeError # Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}" return str(content) + f"\n\n{text}"
@staticmethod
def _build_hard_stop_update(last_msg, content: str | list) -> dict:
"""Clear tool-call metadata so forced-stop messages serialize as plain assistant text."""
update = {
"tool_calls": [],
"content": content,
}
additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
for key in ("tool_calls", "function_call"):
additional_kwargs.pop(key, None)
update["additional_kwargs"] = additional_kwargs
response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
if response_metadata.get("finish_reason") == "tool_calls":
response_metadata["finish_reason"] = "stop"
update["response_metadata"] = response_metadata
return update
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime) warning, hard_stop = self._track_and_check(state, runtime)
@@ -330,12 +350,8 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# Strip tool_calls from the last AIMessage to force text output # Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", []) messages = state.get("messages", [])
last_msg = messages[-1] last_msg = messages[-1]
stripped_msg = last_msg.model_copy( content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
update={ stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
"tool_calls": [],
"content": self._append_text(last_msg.content, warning),
}
)
return {"messages": [stripped_msg]} return {"messages": [stripped_msg]}
if warning: if warning:
@@ -1,50 +1,19 @@
"""Middleware for memory mechanism.""" """Middleware for memory mechanism."""
import logging import logging
import re from typing import override
from typing import Any, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config from deerflow.config.memory_config import get_memory_config
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState): class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema.""" """Compatible with the `ThreadState` schema."""
@@ -52,125 +21,6 @@ class MemoryMiddlewareState(AgentState):
pass pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
This filters out:
- Tool messages (intermediate tool call results)
- AI messages with tool_calls (intermediate steps, not final responses)
- The <uploaded_files> block injected by UploadsMiddleware into human messages
(file paths are session-scoped and must not persist in long-term memory).
The user's actual question is preserved; only turns whose content is entirely
the upload block (nothing remains after stripping) are dropped along with
their paired assistant response.
Only keeps:
- Human messages (with the ephemeral upload block removed)
- AI messages without tool_calls (final assistant responses), unless the
paired human turn was upload-only and had no real user text.
Args:
messages: List of all conversation messages.
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
# Nothing left — the entire turn was upload bookkeeping;
# skip it and the paired assistant response.
skip_next_ai = True
continue
# Rebuild the message with cleaned content so the user's question
# is still available for memory summarisation.
from copy import copy
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
# Skip tool messages and AI messages with tool_calls
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution. """Middleware that queues conversation for memory update after agent execution.
@@ -207,13 +57,10 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
if not config.enabled: if not config.enabled:
return None return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata # Resolve thread ID from the runtime or configured fallback sources
thread_id = runtime.context.get("thread_id") if runtime.context else None thread_id = get_thread_id(runtime)
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if not thread_id: if not thread_id:
logger.debug("No thread_id in context, skipping memory update") logger.debug("No thread_id could be resolved from runtime/config, skipping memory update")
return None return None
# Get messages from state # Get messages from state
@@ -223,7 +70,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None return None
# Filter to only keep user inputs and final assistant responses # Filter to only keep user inputs and final assistant responses
filtered_messages = _filter_messages_for_memory(messages) filtered_messages = filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation # Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response # At minimum need one user message and one assistant response
@@ -14,6 +14,7 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command from langgraph.types import Command
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -218,15 +219,7 @@ class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _get_thread_id(self, request: ToolCallRequest) -> str | None: def _get_thread_id(self, request: ToolCallRequest) -> str | None:
runtime = request.runtime # ToolRuntime; may be None-like in tests return get_thread_id(request.runtime)
if runtime is None:
return None
ctx = getattr(runtime, "context", None) or {}
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
if thread_id is None:
cfg = getattr(runtime, "config", None) or {}
thread_id = cfg.get("configurable", {}).get("thread_id")
return thread_id
_AUDIT_COMMAND_LIMIT = 200 _AUDIT_COMMAND_LIMIT = 200
@@ -0,0 +1,337 @@
"""Summarization middleware extensions for DeerFlow."""
from __future__ import annotations
import logging
from collections.abc import Collection
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, RemoveMessage, ToolMessage
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class SummarizationEvent:
"""Context emitted before conversation history is summarized away."""
messages_to_summarize: tuple[AnyMessage, ...]
preserved_messages: tuple[AnyMessage, ...]
thread_id: str | None
agent_name: str | None
runtime: Runtime
@runtime_checkable
class BeforeSummarizationHook(Protocol):
"""Hook invoked before summarization removes messages from state."""
def __call__(self, event: SummarizationEvent) -> None: ...
def _resolve_agent_name(runtime: Runtime) -> str | None:
"""Resolve the current agent name from runtime context or LangGraph config."""
agent_name = runtime.context.get("agent_name") if runtime.context else None
if agent_name is None:
try:
config_data = get_config()
except RuntimeError:
return None
agent_name = config_data.get("configurable", {}).get("agent_name")
return agent_name
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
"""Best-effort extraction of a file path argument from a read_file-like tool call."""
args = tool_call.get("args") or {}
if not isinstance(args, dict):
return None
for key in ("path", "file_path", "filepath"):
value = args.get(key)
if isinstance(value, str) and value:
return value
return None
def _clone_ai_message(
message: AIMessage,
tool_calls: list[dict[str, Any]],
*,
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
return message.model_copy(update=update)
@dataclass
class _SkillBundle:
"""Skill-related tool calls and tool results associated with one AIMessage."""
ai_index: int
skill_tool_indices: tuple[int, ...]
skill_tool_call_ids: frozenset[str]
skill_tool_tokens: int
skill_key: str
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""Summarization middleware with pre-compression hook dispatch and skill rescue."""
def __init__(
self,
*args,
skills_container_path: str | None = None,
skill_file_read_tool_names: Collection[str] | None = None,
before_summarization: list[BeforeSummarizationHook] | None = None,
preserve_recent_skill_count: int = 5,
preserve_recent_skill_tokens: int = 25_000,
preserve_recent_skill_tokens_per_skill: int = 5_000,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._skills_container_path = skills_container_path or "/mnt/skills"
self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
self._before_summarization_hooks = before_summarization or []
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime)
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return await self._amaybe_summarize(state, runtime)
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _partition_with_skill_rescue(
self,
messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Partition like the parent, then rescue recently-loaded skill bundles."""
to_summarize, preserved = self._partition_messages(messages, cutoff_index)
if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
return to_summarize, preserved
try:
bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
except Exception:
logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
return to_summarize, preserved
if not bundles:
return to_summarize, preserved
rescue_bundles = self._select_bundles_to_rescue(bundles)
if not rescue_bundles:
return to_summarize, preserved
bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
rescued: list[AnyMessage] = []
remaining: list[AnyMessage] = []
for i, msg in enumerate(to_summarize):
bundle = bundles_by_ai_index.get(i)
if bundle is not None and isinstance(msg, AIMessage):
rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
if rescued_tool_calls:
rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
if remaining_tool_calls or msg.content:
remaining.append(_clone_ai_message(msg, remaining_tool_calls))
continue
if i in rescue_tool_indices:
rescued.append(msg)
continue
remaining.append(msg)
return remaining, rescued + preserved
def _find_skill_bundles(
self,
messages: list[AnyMessage],
skills_root: str,
) -> list[_SkillBundle]:
"""Locate AIMessage + paired ToolMessage groups that load skill files."""
bundles: list[_SkillBundle] = []
n = len(messages)
i = 0
while i < n:
msg = messages[i]
if not (isinstance(msg, AIMessage) and msg.tool_calls):
i += 1
continue
tool_calls = list(msg.tool_calls)
skill_paths_by_id: dict[str, str] = {}
for tc in tool_calls:
if self._is_skill_tool_call(tc, skills_root):
tc_id = tc.get("id")
path = _tool_call_path(tc)
if tc_id and path:
skill_paths_by_id[tc_id] = path
if not skill_paths_by_id:
i += 1
continue
skill_tool_tokens = 0
skill_key_parts: list[str] = []
skill_tool_indices: list[int] = []
matched_skill_call_ids: set[str] = set()
j = i + 1
while j < n and isinstance(messages[j], ToolMessage):
j += 1
for k in range(i + 1, j):
tool_msg = messages[k]
if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
skill_tool_tokens += self.token_counter([tool_msg])
skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
skill_tool_indices.append(k)
matched_skill_call_ids.add(tool_msg.tool_call_id)
if not skill_tool_indices:
i = j
continue
bundles.append(
_SkillBundle(
ai_index=i,
skill_tool_indices=tuple(skill_tool_indices),
skill_tool_call_ids=frozenset(matched_skill_call_ids),
skill_tool_tokens=skill_tool_tokens,
skill_key="|".join(sorted(skill_key_parts)),
)
)
i = j
return bundles
def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
"""Pick bundles to keep, walking newest-first under count/token budgets."""
selected: list[_SkillBundle] = []
if not bundles:
return selected
seen_skill_keys: set[str] = set()
total_tokens = 0
kept = 0
for bundle in reversed(bundles):
if kept >= self._preserve_recent_skill_count:
break
if bundle.skill_key in seen_skill_keys:
continue
if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
continue
if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
continue
selected.append(bundle)
total_tokens += bundle.skill_tool_tokens
kept += 1
seen_skill_keys.add(bundle.skill_key)
selected.reverse()
return selected
def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
"""Return True when ``tool_call`` reads a file under the configured skills root."""
name = tool_call.get("name") or ""
if name not in self._skill_file_read_tool_names:
return False
path = _tool_call_path(tool_call)
if not path:
return False
normalized_root = skills_root.rstrip("/")
return path == normalized_root or path.startswith(normalized_root + "/")
def _fire_hooks(
self,
messages_to_summarize: list[AnyMessage],
preserved_messages: list[AnyMessage],
runtime: Runtime,
) -> None:
if not self._before_summarization_hooks:
return
event = SummarizationEvent(
messages_to_summarize=tuple(messages_to_summarize),
preserved_messages=tuple(preserved_messages),
thread_id=get_thread_id(runtime),
agent_name=_resolve_agent_name(runtime),
runtime=runtime,
)
for hook in self._before_summarization_hooks:
try:
hook(event)
except Exception:
hook_name = getattr(hook, "__name__", None) or type(hook).__name__
logger.exception("before_summarization hook %s failed", hook_name)
@@ -3,11 +3,11 @@ from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -75,11 +75,7 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
@override @override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None: def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
context = runtime.context or {} thread_id = get_thread_id(runtime)
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id is None: if thread_id is None:
raise ValueError("Thread ID is required in runtime context or config.configurable") raise ValueError("Thread ID is required in runtime context or config.configurable")
@@ -1,6 +1,7 @@
"""Middleware for automatic thread title generation.""" """Middleware for automatic thread title generation."""
import logging import logging
import re
from typing import NotRequired, override from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
@@ -77,7 +78,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "") assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content) user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._normalize_content(assistant_msg_content) assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
prompt = config.prompt_template.format( prompt = config.prompt_template.format(
max_words=config.max_words, max_words=config.max_words,
@@ -86,10 +87,15 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
) )
return prompt, user_msg return prompt, user_msg
def _strip_think_tags(self, text: str) -> str:
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
def _parse_title(self, content: object) -> str: def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string.""" """Normalize model output into a clean title string."""
config = get_title_config() config = get_title_config()
title_content = self._normalize_content(content) title_content = self._normalize_content(content)
title_content = self._strip_think_tags(title_content)
title = title_content.strip().strip('"').strip("'") title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title return title[: config.max_chars] if len(title) > config.max_chars else title
@@ -121,7 +127,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False)
else: else:
model = create_chat_model(thinking_enabled=False) model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt) response = await model.ainvoke(prompt, config={"run_name": "title_agent"})
title = self._parse_title(response.content) title = self._parse_title(response.content)
if title: if title:
return {"title": title} return {"title": title}
@@ -1,9 +1,14 @@
"""Middleware that extends TodoListMiddleware with context-loss detection. """Middleware that extends TodoListMiddleware with context-loss detection and premature-exit prevention.
When the message history is truncated (e.g., by SummarizationMiddleware), the When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list. reminder message so the model still knows about the outstanding todo list.
Additionally, this middleware prevents the agent from exiting the loop while
there are still incomplete todo items. When the model produces a final response
(no tool calls) but todos are not yet complete, the middleware injects a reminder
and jumps back to the model node to force continued engagement.
""" """
from __future__ import annotations from __future__ import annotations
@@ -12,6 +17,7 @@ from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.types import hook_config
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
@@ -34,6 +40,11 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
return False return False
def _completion_reminder_count(messages: list[Any]) -> int:
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
def _format_todos(todos: list[Todo]) -> str: def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string.""" """Format a list of Todo items into a human-readable string."""
lines: list[str] = [] lines: list[str] = []
@@ -57,7 +68,7 @@ class TodoMiddleware(TodoListMiddleware):
def before_model( def before_model(
self, self,
state: PlanningState, state: PlanningState,
runtime: Runtime, # noqa: ARG002 runtime: Runtime,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window.""" """Inject a todo-list reminder when write_todos has left the context window."""
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment] todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
@@ -98,3 +109,71 @@ class TodoMiddleware(TodoListMiddleware):
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Async version of before_model.""" """Async version of before_model."""
return self.before_model(state, runtime) return self.before_model(state, runtime)
# Maximum number of completion reminders before allowing the agent to exit.
# This prevents infinite loops when the agent cannot make further progress.
_MAX_COMPLETION_REMINDERS = 2
@hook_config(can_jump_to=["model"])
@override
def after_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Prevent premature agent exit when todo items are still incomplete.
In addition to the base class check for parallel ``write_todos`` calls,
this override intercepts model responses that have no tool calls while
there are still incomplete todo items. It injects a reminder
``HumanMessage`` and jumps back to the model node so the agent
continues working through the todo list.
A retry cap of ``_MAX_COMPLETION_REMINDERS`` (default 2) prevents
infinite loops when the agent cannot make further progress.
"""
# 1. Preserve base class logic (parallel write_todos detection).
base_result = super().after_model(state, runtime)
if base_result is not None:
return base_result
# 2. Only intervene when the agent wants to exit (no tool calls).
messages = state.get("messages") or []
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
if not last_ai or last_ai.tool_calls:
return None
# 3. Allow exit when all todos are completed or there are no todos.
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
if not todos or all(t.get("status") == "completed" for t in todos):
return None
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
return None
# 5. Inject a reminder and force the agent back to the model.
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
reminder = HumanMessage(
name="todo_completion_reminder",
content=(
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
),
)
return {"jump_to": "model", "messages": [reminder]}
@override
@hook_config(can_jump_to=["model"])
async def aafter_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of after_model."""
return self.after_model(state, runtime)
@@ -11,6 +11,7 @@ from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline from deerflow.utils.file_conversion import extract_outline
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -213,14 +214,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None return None
# Resolve uploads directory for existence checks # Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id") thread_id = get_thread_id(runtime)
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files # Get newly uploaded files from the current message's additional_kwargs.files
+6 -1
View File
@@ -722,6 +722,10 @@ class DeerFlowClient:
Dict with "models" key containing list of model info dicts, Dict with "models" key containing list of model info dicts,
matching the Gateway API ``ModelsListResponse`` schema. matching the Gateway API ``ModelsListResponse`` schema.
""" """
token_usage_enabled = getattr(getattr(self._app_config, "token_usage", None), "enabled", False)
if not isinstance(token_usage_enabled, bool):
token_usage_enabled = False
return { return {
"models": [ "models": [
{ {
@@ -733,7 +737,8 @@ class DeerFlowClient:
"supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False), "supports_reasoning_effort": getattr(model, "supports_reasoning_effort", False),
} }
for model in self._app_config.models for model in self._app_config.models
] ],
"token_usage": {"enabled": token_usage_enabled},
} }
def list_skills(self, enabled_only: bool = False) -> dict: def list_skills(self, enabled_only: bool = False) -> dict:
@@ -119,6 +119,16 @@ class AioSandboxProvider(SandboxProvider):
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0: if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
self._start_idle_checker() self._start_idle_checker()
@property
def uses_thread_data_mounts(self) -> bool:
"""Whether thread workspace/uploads/outputs are visible via mounts.
Local container backends bind-mount the thread data directories, so files
written by the gateway are already visible when the sandbox starts.
Remote backends may require explicit file sync.
"""
return isinstance(self._backend, LocalContainerBackend)
# ── Factory methods ────────────────────────────────────────────────── # ── Factory methods ──────────────────────────────────────────────────
def _create_backend(self) -> SandboxBackend: def _create_backend(self) -> SandboxBackend:
@@ -38,6 +38,6 @@ class JinaClient:
return response.text return response.text
except Exception as e: except Exception as e:
error_message = f"Request to Jina API failed: {str(e)}" error_message = f"Request to Jina API failed: {type(e).__name__}: {e}"
logger.exception(error_message) logger.warning(error_message)
return f"Error: {error_message}" return f"Error: {error_message}"
@@ -1,3 +1,5 @@
import asyncio
from langchain.tools import tool from langchain.tools import tool
from deerflow.community.jina_ai.jina_client import JinaClient from deerflow.community.jina_ai.jina_client import JinaClient
@@ -26,5 +28,5 @@ async def web_fetch_tool(url: str) -> str:
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout) html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"): if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content return html_content
article = readability_extractor.extract_article(html_content) article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
return article.to_markdown()[:4096] return article.to_markdown()[:4096]
@@ -0,0 +1,32 @@
"""Configuration for the custom agents management API."""
from pydantic import BaseModel, Field
class AgentsApiConfig(BaseModel):
"""Configuration for custom-agent and user-profile management routes."""
enabled: bool = Field(
default=False,
description=("Whether to expose the custom-agent management API over HTTP. When disabled, the gateway rejects read/write access to custom agent SOUL.md, config, and USER.md prompt-management routes."),
)
_agents_api_config: AgentsApiConfig = AgentsApiConfig()
def get_agents_api_config() -> AgentsApiConfig:
"""Get the current agents API configuration."""
return _agents_api_config
def set_agents_api_config(config: AgentsApiConfig) -> None:
"""Set the agents API configuration."""
global _agents_api_config
_agents_api_config = config
def load_agents_api_config_from_dict(config_dict: dict) -> None:
"""Load agents API configuration from a dictionary."""
global _agents_api_config
_agents_api_config = AgentsApiConfig(**config_dict)
@@ -15,6 +15,17 @@ SOUL_FILENAME = "SOUL.md"
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$") AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
def validate_agent_name(name: str | None) -> str | None:
"""Validate a custom agent name before using it in filesystem paths."""
if name is None:
return None
if not isinstance(name, str):
raise ValueError("Invalid agent name. Expected a string or None.")
if not AGENT_NAME_PATTERN.fullmatch(name):
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
return name
class AgentConfig(BaseModel): class AgentConfig(BaseModel):
"""Configuration for a custom agent.""" """Configuration for a custom agent."""
@@ -46,8 +57,7 @@ def load_agent_config(name: str | None) -> AgentConfig | None:
if name is None: if name is None:
return None return None
if not AGENT_NAME_PATTERN.match(name): name = validate_agent_name(name)
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
agent_dir = get_paths().agent_dir(name) agent_dir = get_paths().agent_dir(name)
config_file = agent_dir / "config.yaml" config_file = agent_dir / "config.yaml"
@@ -9,6 +9,7 @@ from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
@@ -30,6 +31,13 @@ load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CircuitBreakerConfig(BaseModel):
"""Configuration for the LLM Circuit Breaker."""
failure_threshold: int = Field(default=5, description="Number of consecutive failures before tripping the circuit")
recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit")
def _default_config_candidates() -> tuple[Path, ...]: def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd.""" """Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4] backend_dir = Path(__file__).resolve().parents[4]
@@ -53,8 +61,10 @@ class AppConfig(BaseModel):
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration") title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration") summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration") memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
agents_api: AgentsApiConfig = Field(default_factory=AgentsApiConfig, description="Custom-agent management API configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False) model_config = ConfigDict(extra="allow", frozen=False)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@@ -117,6 +127,10 @@ class AppConfig(BaseModel):
if "memory" in config_data: if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"]) load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present # Load subagents config if present
if "subagents" in config_data: if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"]) load_subagents_config_from_dict(config_data["subagents"])
@@ -129,6 +143,10 @@ class AppConfig(BaseModel):
if "guardrails" in config_data: if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"]) load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present # Load checkpointer config if present
if "checkpointer" in config_data: if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"]) load_checkpointer_config_from_dict(config_data["checkpointer"])
@@ -25,6 +25,47 @@ class SubagentOverrideConfig(BaseModel):
min_length=1, min_length=1,
description="Model name for this subagent (None = inherit from parent agent)", description="Model name for this subagent (None = inherit from parent agent)",
) )
skills: list[str] | None = Field(
default=None,
description="Skill names whitelist for this subagent (None = inherit all enabled skills, [] = no skills)",
)
class CustomSubagentConfig(BaseModel):
"""User-defined subagent type declared in config.yaml."""
description: str = Field(
description="When the lead agent should delegate to this subagent",
)
system_prompt: str = Field(
description="System prompt that guides the subagent's behavior",
)
tools: list[str] | None = Field(
default=None,
description="Tool names whitelist (None = inherit all tools from parent)",
)
disallowed_tools: list[str] | None = Field(
default_factory=lambda: ["task", "ask_clarification", "present_files"],
description="Tool names to deny",
)
skills: list[str] | None = Field(
default=None,
description="Skill names whitelist (None = inherit all enabled skills, [] = no skills)",
)
model: str = Field(
default="inherit",
description="Model to use - 'inherit' uses parent's model",
)
max_turns: int = Field(
default=50,
ge=1,
description="Maximum number of agent turns before stopping",
)
timeout_seconds: int = Field(
default=900,
ge=1,
description="Maximum execution time in seconds",
)
class SubagentsAppConfig(BaseModel): class SubagentsAppConfig(BaseModel):
@@ -44,6 +85,10 @@ class SubagentsAppConfig(BaseModel):
default_factory=dict, default_factory=dict,
description="Per-agent configuration overrides keyed by agent name", description="Per-agent configuration overrides keyed by agent name",
) )
custom_agents: dict[str, CustomSubagentConfig] = Field(
default_factory=dict,
description="User-defined subagent types keyed by agent name",
)
def get_timeout_for(self, agent_name: str) -> int: def get_timeout_for(self, agent_name: str) -> int:
"""Get the effective timeout for a specific agent. """Get the effective timeout for a specific agent.
@@ -82,6 +127,20 @@ class SubagentsAppConfig(BaseModel):
return self.max_turns return self.max_turns
return builtin_default return builtin_default
def get_skills_for(self, agent_name: str) -> list[str] | None:
"""Get the skills override for a specific agent.
Args:
agent_name: The name of the subagent.
Returns:
Skill names whitelist if overridden, None otherwise (subagent will inherit all enabled skills).
"""
override = self.agents.get(agent_name)
if override is not None and override.skills is not None:
return override.skills
return None
_subagents_config: SubagentsAppConfig = SubagentsAppConfig() _subagents_config: SubagentsAppConfig = SubagentsAppConfig()
@@ -105,15 +164,20 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
parts.append(f"max_turns={override.max_turns}") parts.append(f"max_turns={override.max_turns}")
if override.model is not None: if override.model is not None:
parts.append(f"model={override.model}") parts.append(f"model={override.model}")
if override.skills is not None:
parts.append(f"skills={override.skills}")
if parts: if parts:
overrides_summary[name] = ", ".join(parts) overrides_summary[name] = ", ".join(parts)
if overrides_summary: custom_agents_names = list(_subagents_config.custom_agents.keys())
if overrides_summary or custom_agents_names:
logger.info( logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s", "Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s, custom_agents=%s",
_subagents_config.timeout_seconds, _subagents_config.timeout_seconds,
_subagents_config.max_turns, _subagents_config.max_turns,
overrides_summary, overrides_summary or "none",
custom_agents_names or "none",
) )
else: else:
logger.info( logger.info(
@@ -51,6 +51,25 @@ class SummarizationConfig(BaseModel):
default=None, default=None,
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.", description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
) )
preserve_recent_skill_count: int = Field(
default=5,
ge=0,
description="Number of most-recently-loaded skill files to exclude from summarization. Set to 0 to disable skill preservation.",
)
preserve_recent_skill_tokens: int = Field(
default=25000,
ge=0,
description="Total token budget reserved for recently-loaded skill files that must be preserved across summarization.",
)
preserve_recent_skill_tokens_per_skill: int = Field(
default=5000,
ge=0,
description="Per-skill token cap when preserving skill files across summarization. Skill reads above this size are not rescued.",
)
skill_file_read_tool_names: list[str] = Field(
default_factory=lambda: ["read_file", "read", "view", "cat"],
description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.",
)
# Global configuration instance # Global configuration instance
@@ -118,9 +118,13 @@ def get_cached_mcp_tools() -> list[BaseTool]:
loop.run_until_complete(initialize_mcp_tools()) loop.run_until_complete(initialize_mcp_tools())
except RuntimeError: except RuntimeError:
# No event loop exists, create one # No event loop exists, create one
asyncio.run(initialize_mcp_tools()) try:
except Exception as e: asyncio.run(initialize_mcp_tools())
logger.error(f"Failed to lazy-initialize MCP tools: {e}") except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
return []
except Exception:
logger.exception("Failed to lazy-initialize MCP tools")
return [] return []
return _mcp_tools_cache or [] return _mcp_tools_cache or []
@@ -12,6 +12,7 @@ from langchain_core.tools import BaseTool
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.reflection import resolve_variable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -95,6 +96,27 @@ async def get_mcp_tools() -> list[BaseTool]:
if oauth_interceptor is not None: if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor) tool_interceptors.append(oauth_interceptor)
# Load custom interceptors declared in extensions_config.json
# Format: "mcpInterceptors": ["pkg.module:builder_func", ...]
raw_interceptor_paths = (extensions_config.model_extra or {}).get("mcpInterceptors")
if isinstance(raw_interceptor_paths, str):
raw_interceptor_paths = [raw_interceptor_paths]
elif not isinstance(raw_interceptor_paths, list):
if raw_interceptor_paths is not None:
logger.warning(f"mcpInterceptors must be a list of strings, got {type(raw_interceptor_paths).__name__}; skipping")
raw_interceptor_paths = []
for interceptor_path in raw_interceptor_paths:
try:
builder = resolve_variable(interceptor_path)
interceptor = builder()
if callable(interceptor):
tool_interceptors.append(interceptor)
logger.info(f"Loaded MCP interceptor: {interceptor_path}")
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True) client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
# Get all tools from all servers # Get all tools from all servers
@@ -190,23 +190,33 @@ class ClaudeChatModel(ChatAnthropic):
) )
def _apply_prompt_caching(self, payload: dict) -> None: def _apply_prompt_caching(self, payload: dict) -> None:
"""Apply ephemeral cache_control to system and recent messages.""" """Apply ephemeral cache_control to system, recent messages, and last tool definition.
# Cache system messages
Uses a budget of MAX_CACHE_BREAKPOINTS (4) breakpoints — the hard limit
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
placed on the *last* eligible blocks because later breakpoints cover a
larger prefix and yield better cache hit rates.
"""
MAX_CACHE_BREAKPOINTS = 4
# Collect candidate blocks in document order:
# 1. system text blocks
# 2. content blocks of the last prompt_cache_size messages
# 3. the last tool definition
candidates: list[dict] = []
# 1. System blocks
system = payload.get("system") system = payload.get("system")
if system and isinstance(system, list): if system and isinstance(system, list):
for block in system: for block in system:
if isinstance(block, dict) and block.get("type") == "text": if isinstance(block, dict) and block.get("type") == "text":
block["cache_control"] = {"type": "ephemeral"} candidates.append(block)
elif system and isinstance(system, str): elif system and isinstance(system, str):
payload["system"] = [ new_block: dict = {"type": "text", "text": system}
{ payload["system"] = [new_block]
"type": "text", candidates.append(new_block)
"text": system,
"cache_control": {"type": "ephemeral"},
}
]
# Cache recent messages # 2. Recent message blocks
messages = payload.get("messages", []) messages = payload.get("messages", [])
cache_start = max(0, len(messages) - self.prompt_cache_size) cache_start = max(0, len(messages) - self.prompt_cache_size)
for i in range(cache_start, len(messages)): for i in range(cache_start, len(messages)):
@@ -217,20 +227,21 @@ class ClaudeChatModel(ChatAnthropic):
if isinstance(content, list): if isinstance(content, list):
for block in content: for block in content:
if isinstance(block, dict): if isinstance(block, dict):
block["cache_control"] = {"type": "ephemeral"} candidates.append(block)
elif isinstance(content, str) and content: elif isinstance(content, str) and content:
msg["content"] = [ new_block = {"type": "text", "text": content}
{ msg["content"] = [new_block]
"type": "text", candidates.append(new_block)
"text": content,
"cache_control": {"type": "ephemeral"},
}
]
# Cache the last tool definition # 3. Last tool definition
tools = payload.get("tools", []) tools = payload.get("tools", [])
if tools and isinstance(tools[-1], dict): if tools and isinstance(tools[-1], dict):
tools[-1]["cache_control"] = {"type": "ephemeral"} candidates.append(tools[-1])
# Apply cache_control only to the last MAX_CACHE_BREAKPOINTS candidates
# to stay within the API limit.
for block in candidates[-MAX_CACHE_BREAKPOINTS:]:
block["cache_control"] = {"type": "ephemeral"}
def _apply_thinking_budget(self, payload: dict) -> None: def _apply_thinking_budget(self, payload: dict) -> None:
"""Auto-allocate thinking budget (80% of max_tokens).""" """Auto-allocate thinking budget (80% of max_tokens)."""
@@ -30,6 +30,22 @@ def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
return disable_kwargs return disable_kwargs
def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_config: dict) -> None:
"""Enable stream usage for OpenAI-compatible models unless explicitly configured.
LangChain only auto-enables ``stream_usage`` for OpenAI models when no custom
base URL or client is configured. DeerFlow frequently uses OpenAI-compatible
gateways, so token usage tracking would otherwise stay empty and the
TokenUsageMiddleware would have nothing to log.
"""
if model_use_path != "langchain_openai:ChatOpenAI":
return
if "stream_usage" in model_settings_from_config:
return
if "base_url" in model_settings_from_config or "openai_api_base" in model_settings_from_config:
model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel: def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config. """Create a chat model instance from the config.
@@ -97,6 +113,8 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
kwargs.pop("reasoning_effort", None) kwargs.pop("reasoning_effort", None)
model_settings_from_config.pop("reasoning_effort", None) model_settings_from_config.pop("reasoning_effort", None)
_enable_stream_usage_by_default(model_config.use, model_settings_from_config)
# For Codex Responses API models: map thinking mode to reasoning_effort # For Codex Responses API models: map thinking mode to reasoning_effort
from deerflow.models.openai_codex_provider import CodexChatModel from deerflow.models.openai_codex_provider import CodexChatModel
@@ -113,6 +131,12 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config: elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium" model_settings_from_config["reasoning_effort"] = "medium"
# For MindIE models: enforce conservative retry defaults.
# Timeout normalization is handled inside MindIEChatModel itself.
if getattr(model_class, "__name__", "") == "MindIEChatModel":
# Enforce max_retries constraint to prevent cascading timeouts.
model_settings_from_config["max_retries"] = model_settings_from_config.get("max_retries", 1)
model_instance = model_class(**{**model_settings_from_config, **kwargs}) model_instance = model_class(**{**model_settings_from_config, **kwargs})
callbacks = build_tracing_callbacks() callbacks = build_tracing_callbacks()
@@ -0,0 +1,237 @@
import ast
import json
import re
import uuid
from collections.abc import Iterator
import httpx
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
def _fix_messages(messages: list) -> list:
"""Sanitize incoming messages for MindIE compatibility.
MindIE's chat template may fail to parse LangChain's native tool_calls
or ToolMessage roles, resulting in 0-token generation errors. This function
flattens multi-modal list contents into strings and converts tool-related
messages into raw text with XML tags expected by the underlying model.
"""
fixed = []
for msg in messages:
# Flatten content if it's a list of blocks
if isinstance(msg.content, list):
parts = []
for block in msg.content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") == "text":
parts.append(block.get("text", ""))
text = "".join(parts)
else:
text = msg.content or ""
# Convert AIMessage with tool_calls to raw XML text format
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", []):
xml_parts = []
for tool in msg.tool_calls:
args_xml = " ".join(f"<parameter={k}>{json.dumps(v, ensure_ascii=False)}</parameter>" for k, v in tool.get("args", {}).items())
xml_parts.append(f"<tool_call> <function={tool['name']}> {args_xml} </function> </tool_call>")
full_text = f"{text}\n" + "\n".join(xml_parts) if text else "\n".join(xml_parts)
fixed.append(AIMessage(content=full_text.strip() or " "))
continue
# Wrap tool execution results in XML tags and convert to HumanMessage
if isinstance(msg, ToolMessage):
tool_result_text = f"<tool_response>\n{text}\n</tool_response>"
fixed.append(HumanMessage(content=tool_result_text))
continue
# Fallback to prevent completely empty message content
if not text.strip():
text = " "
fixed.append(msg.model_copy(update={"content": text}))
return fixed
def _parse_xml_tool_call_to_dict(content: str) -> tuple[str, list[dict]]:
"""Parse XML-style tool calls from model output into LangChain dicts.
Args:
content: The raw text output from the model.
Returns:
A tuple containing the cleaned text (with XML blocks removed) and
a list of tool call dictionaries formatted for LangChain.
"""
if not isinstance(content, str) or "<tool_call>" not in content:
return content, []
tool_calls = []
clean_parts: list[str] = []
cursor = 0
for start, end, inner_content in _iter_tool_call_blocks(content):
clean_parts.append(content[cursor:start])
cursor = end
func_match = re.search(r"<function=([^>]+)>", inner_content)
if not func_match:
continue
function_name = func_match.group(1).strip()
args = {}
param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
for param_match in param_pattern.finditer(inner_content):
key = param_match.group(1).strip()
raw_value = param_match.group(2).strip()
# Attempt to deserialize string values into native Python types
# to satisfy downstream Pydantic validation.
parsed_value = raw_value
if raw_value.startswith(("[", "{")) or raw_value in ("true", "false", "null") or raw_value.isdigit():
try:
parsed_value = json.loads(raw_value)
except json.JSONDecodeError:
try:
parsed_value = ast.literal_eval(raw_value)
except (ValueError, SyntaxError):
pass
args[key] = parsed_value
tool_calls.append({"name": function_name, "args": args, "id": f"call_{uuid.uuid4().hex[:10]}"})
clean_parts.append(content[cursor:])
return "".join(clean_parts).strip(), tool_calls
def _iter_tool_call_blocks(content: str) -> Iterator[tuple[int, int, str]]:
"""Iterate `<tool_call>...</tool_call>` blocks and tolerate nesting."""
token_pattern = re.compile(r"</?tool_call>")
depth = 0
block_start = -1
for match in token_pattern.finditer(content):
token = match.group(0)
if token == "<tool_call>":
if depth == 0:
block_start = match.start()
depth += 1
continue
if depth == 0:
continue
depth -= 1
if depth == 0 and block_start != -1:
block_end = match.end()
inner_start = block_start + len("<tool_call>")
inner_end = match.start()
yield block_start, block_end, content[inner_start:inner_end]
block_start = -1
def _decode_escaped_newlines_outside_fences(content: str) -> str:
"""Decode literal `\\n` outside fenced code blocks."""
if "\\n" not in content:
return content
parts = re.split(r"(```[\s\S]*?```)", content)
for idx, part in enumerate(parts):
if part.startswith("```"):
continue
parts[idx] = part.replace("\\n", "\n")
return "".join(parts)
class MindIEChatModel(ChatOpenAI):
"""Chat model adapter for MindIE engine.
Addresses compatibility issues including:
- Flattening multimodal list contents to strings.
- Intercepting and parsing hardcoded XML tool calls into LangChain standard.
- Handling stream=True dropping choices when tools are present by falling back
to non-streaming generation and yielding simulated chunks.
- Fixing over-escaped newline characters from gateway responses.
"""
def __init__(self, **kwargs):
"""Normalize timeout kwargs without creating long-lived clients."""
connect_timeout = kwargs.pop("connect_timeout", 30.0)
read_timeout = kwargs.pop("read_timeout", 900.0)
write_timeout = kwargs.pop("write_timeout", 60.0)
pool_timeout = kwargs.pop("pool_timeout", 30.0)
kwargs.setdefault(
"timeout",
httpx.Timeout(
connect=connect_timeout,
read=read_timeout,
write=write_timeout,
pool=pool_timeout,
),
)
super().__init__(**kwargs)
def _patch_result_with_tools(self, result: ChatResult) -> ChatResult:
"""Apply post-generation fixes to the model result."""
for gen in result.generations:
msg = gen.message
if isinstance(msg.content, str):
# Keep escaped newlines inside fenced code blocks untouched.
msg.content = _decode_escaped_newlines_outside_fences(msg.content)
if "<tool_call>" in msg.content:
clean_content, extracted_tools = _parse_xml_tool_call_to_dict(msg.content)
if extracted_tools:
msg.content = clean_content
if getattr(msg, "tool_calls", None) is None:
msg.tool_calls = []
msg.tool_calls.extend(extracted_tools)
return result
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
result = super()._generate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs)
return self._patch_result_with_tools(result)
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
result = await super()._agenerate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs)
return self._patch_result_with_tools(result)
async def _astream(self, messages, stop=None, run_manager=None, **kwargs):
# Route standard queries to native streaming for lower TTFB
if not kwargs.get("tools"):
async for chunk in super()._astream(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs):
if isinstance(chunk.message.content, str):
chunk.message.content = _decode_escaped_newlines_outside_fences(chunk.message.content)
yield chunk
return
# Fallback for tool-enabled requests:
# MindIE currently drops choices when stream=True and tools are present.
# We await the full generation and yield chunks to simulate streaming.
result = await self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs)
for gen in result.generations:
msg = gen.message
content = msg.content
standard_tool_calls = getattr(msg, "tool_calls", [])
# Yield text in chunks to allow downstream UI/Markdown parsers to render smoothly
if isinstance(content, str) and content:
chunk_size = 15
for i in range(0, len(content), chunk_size):
chunk_text = content[i : i + chunk_size]
chunk_msg = AIMessageChunk(content=chunk_text, id=msg.id, response_metadata=msg.response_metadata if i == 0 else {})
yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info if i == 0 else None)
if standard_tool_calls:
yield ChatGenerationChunk(message=AIMessageChunk(content="", id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", [])))
else:
chunk_msg = AIMessageChunk(content=content, id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", []))
yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info)
@@ -288,10 +288,10 @@ class LocalSandbox(Sandbox):
timeout=600, timeout=600,
) )
else: else:
args = [shell, "-c", resolved_command]
result = subprocess.run( result = subprocess.run(
resolved_command, args,
executable=shell, shell=False,
shell=True,
capture_output=True, capture_output=True,
text=True, text=True,
timeout=600, timeout=600,
@@ -11,6 +11,8 @@ _singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider): class LocalSandboxProvider(SandboxProvider):
uses_thread_data_mounts = True
def __init__(self): def __init__(self):
"""Initialize the local sandbox provider with path mappings.""" """Initialize the local sandbox provider with path mappings."""
self._path_mappings = self._setup_path_mappings() self._path_mappings = self._setup_path_mappings()
@@ -7,6 +7,7 @@ from langgraph.runtime import Runtime
from deerflow.agents.thread_state import SandboxState, ThreadDataState from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.sandbox import get_sandbox_provider from deerflow.sandbox import get_sandbox_provider
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -56,7 +57,7 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
# Eager initialization (original behavior) # Eager initialization (original behavior)
if "sandbox" not in state or state["sandbox"] is None: if "sandbox" not in state or state["sandbox"] is None:
thread_id = (runtime.context or {}).get("thread_id") thread_id = get_thread_id(runtime)
if thread_id is None: if thread_id is None:
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id) sandbox_id = self._acquire_sandbox(thread_id)
@@ -8,6 +8,8 @@ from deerflow.sandbox.sandbox import Sandbox
class SandboxProvider(ABC): class SandboxProvider(ABC):
"""Abstract base class for sandbox providers""" """Abstract base class for sandbox providers"""
uses_thread_data_mounts: bool = False
@abstractmethod @abstractmethod
def acquire(self, thread_id: str | None = None) -> str: def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID. """Acquire a sandbox environment and return its ID.
@@ -19,6 +19,7 @@ from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.utils.runtime import get_thread_id
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)") _ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE) _FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
@@ -851,11 +852,9 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
# Sandbox was released, fall through to acquire new one # Sandbox was released, fall through to acquire new one
# Lazy acquisition: get thread_id and acquire sandbox # Lazy acquisition: get thread_id and acquire sandbox
thread_id = runtime.context.get("thread_id") if runtime.context else None thread_id = get_thread_id(runtime)
if thread_id is None: if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None raise SandboxRuntimeError("Thread ID not available in runtime context, runtime config, or LangGraph config")
if thread_id is None:
raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider() provider = get_sandbox_provider()
sandbox_id = provider.acquire(thread_id) sandbox_id = provider.acquire(thread_id)
@@ -1047,6 +1046,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
requested_path = path requested_path = path
thread_data = None
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True) validate_local_tool_path(path, thread_data, read_only=True)
@@ -1061,6 +1061,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
if not children: if not children:
return "(empty)" return "(empty)"
output = "\n".join(children) output = "\n".join(children)
if thread_data is not None:
output = mask_local_paths_in_output(output, thread_data)
try: try:
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
@@ -2,21 +2,24 @@ import logging
import re import re
from pathlib import Path from pathlib import Path
import yaml
from .types import Skill from .types import Skill
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None = None) -> Skill | None: def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None = None) -> Skill | None:
""" """Parse a SKILL.md file and extract metadata.
Parse a SKILL.md file and extract metadata.
Args: Args:
skill_file: Path to the SKILL.md file skill_file: Path to the SKILL.md file.
category: Category of the skill ('public' or 'custom') category: Category of the skill ('public' or 'custom').
relative_path: Relative path from the category root to the skill
directory. Defaults to the skill directory name when omitted.
Returns: Returns:
Skill object if parsing succeeds, None otherwise Skill object if parsing succeeds, None otherwise.
""" """
if not skill_file.exists() or skill_file.name != "SKILL.md": if not skill_file.exists() or skill_file.name != "SKILL.md":
return None return None
@@ -24,90 +27,42 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
try: try:
content = skill_file.read_text(encoding="utf-8") content = skill_file.read_text(encoding="utf-8")
# Extract YAML front matter # Extract YAML front-matter block between leading ``---`` fences.
# Pattern: ---\nkey: value\n---
front_matter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL) front_matter_match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
if not front_matter_match: if not front_matter_match:
return None return None
front_matter = front_matter_match.group(1) front_matter_text = front_matter_match.group(1)
# Parse YAML front matter with basic multiline string support try:
metadata = {} metadata = yaml.safe_load(front_matter_text)
lines = front_matter.split("\n") except yaml.YAMLError as exc:
current_key = None logger.error("Invalid YAML front-matter in %s: %s", skill_file, exc)
current_value = [] return None
is_multiline = False
multiline_style = None
indent_level = None
for line in lines: if not isinstance(metadata, dict):
if is_multiline: logger.error("Front-matter in %s is not a YAML mapping", skill_file)
if not line.strip(): return None
current_value.append("")
continue
current_indent = len(line) - len(line.lstrip()) # Extract required fields. Both must be non-empty strings.
if indent_level is None:
if current_indent > 0:
indent_level = current_indent
current_value.append(line[indent_level:])
continue
elif current_indent >= indent_level:
current_value.append(line[indent_level:])
continue
# If we reach here, it's either a new key or the end of multiline
if current_key and is_multiline:
if multiline_style == "|":
metadata[current_key] = "\n".join(current_value).rstrip()
else:
text = "\n".join(current_value).rstrip()
# Replace single newlines with spaces for folded blocks
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
current_key = None
current_value = []
is_multiline = False
multiline_style = None
indent_level = None
if not line.strip():
continue
if ":" in line:
# Handle nested dicts simply by ignoring indentation for now,
# or just extracting top-level keys
key, value = line.split(":", 1)
key = key.strip()
value = value.strip()
if value in (">", "|"):
current_key = key
is_multiline = True
multiline_style = value
current_value = []
indent_level = None
else:
metadata[key] = value
if current_key and is_multiline:
if multiline_style == "|":
metadata[current_key] = "\n".join(current_value).rstrip()
else:
text = "\n".join(current_value).rstrip()
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
# Extract required fields
name = metadata.get("name") name = metadata.get("name")
description = metadata.get("description") description = metadata.get("description")
if not name or not isinstance(name, str):
return None
if not description or not isinstance(description, str):
return None
# Normalise: strip surrounding whitespace that YAML may preserve.
name = name.strip()
description = description.strip()
if not name or not description: if not name or not description:
return None return None
license_text = metadata.get("license") license_text = metadata.get("license")
if license_text is not None:
license_text = str(license_text).strip() or None
return Skill( return Skill(
name=name, name=name,
@@ -117,9 +72,9 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
skill_file=skill_file, skill_file=skill_file,
relative_path=relative_path or Path(skill_file.parent.name), relative_path=relative_path or Path(skill_file.parent.name),
category=category, category=category,
enabled=True, # Default to enabled, actual state comes from config file enabled=True, # Actual state comes from the extensions config file.
) )
except Exception as e: except Exception:
logger.error("Error parsing skill file %s: %s", skill_file, e) logger.exception("Unexpected error parsing skill file %s", skill_file)
return None return None
@@ -54,7 +54,8 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
[ [
{"role": "system", "content": rubric}, {"role": "system", "content": rubric},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
] ],
config={"run_name": "security_agent"},
) )
parsed = _extract_json_object(str(getattr(response, "content", "") or "")) parsed = _extract_json_object(str(getattr(response, "content", "") or ""))
if parsed and parsed.get("decision") in {"allow", "warn", "block"}: if parsed and parsed.get("decision") in {"allow", "warn", "block"}:
@@ -13,6 +13,8 @@ class SubagentConfig:
system_prompt: The system prompt that guides the subagent's behavior. system_prompt: The system prompt that guides the subagent's behavior.
tools: Optional list of tool names to allow. If None, inherits all tools. tools: Optional list of tool names to allow. If None, inherits all tools.
disallowed_tools: Optional list of tool names to deny. disallowed_tools: Optional list of tool names to deny.
skills: Optional list of skill names to load. If None, inherits all enabled skills.
If an empty list, no skills are loaded.
model: Model to use - 'inherit' uses parent's model. model: Model to use - 'inherit' uses parent's model.
max_turns: Maximum number of agent turns before stopping. max_turns: Maximum number of agent turns before stopping.
timeout_seconds: Maximum execution time in seconds (default: 900 = 15 minutes). timeout_seconds: Maximum execution time in seconds (default: 900 = 15 minutes).
@@ -23,6 +25,7 @@ class SubagentConfig:
system_prompt: str system_prompt: str
tools: list[str] | None = None tools: list[str] | None = None
disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"]) disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"])
skills: list[str] | None = None
model: str = "inherit" model: str = "inherit"
max_turns: int = 50 max_turns: int = 50
timeout_seconds: int = 900 timeout_seconds: int = 900
@@ -13,7 +13,7 @@ from typing import Any
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
@@ -184,7 +184,63 @@ class SubagentExecutor:
state_schema=ThreadState, state_schema=ThreadState,
) )
def _build_initial_state(self, task: str) -> dict[str, Any]: async def _load_skill_messages(self) -> list[SystemMessage]:
"""Load skill content as conversation items based on config.skills.
Aligned with Codex's pattern: each subagent loads its own skills
per-session and injects them as conversation items (developer messages),
not as system prompt text. The config.skills whitelist controls which
skills are loaded:
- None: load all enabled skills
- []: no skills
- ["skill-a", "skill-b"]: only these skills
Returns:
List of SystemMessages containing skill content.
"""
if self.config.skills is not None and len(self.config.skills) == 0:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} skills=[] — skipping skill loading")
return []
try:
from deerflow.skills.loader import load_skills
# Use asyncio.to_thread to avoid blocking the event loop (LangGraph ASGI requirement)
all_skills = await asyncio.to_thread(load_skills, enabled_only=True)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
except Exception:
logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
return []
if not all_skills:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} no enabled skills found")
return []
# Filter by config.skills whitelist
if self.config.skills is not None:
allowed = set(self.config.skills)
skills = [s for s in all_skills if s.name in allowed]
else:
skills = all_skills
if not skills:
return []
# Read each skill's SKILL.md content and create conversation items
messages = []
for skill in skills:
try:
content = await asyncio.to_thread(skill.skill_file.read_text, encoding="utf-8")
content = content.strip()
if content:
messages.append(SystemMessage(content=f'<skill name="{skill.name}">\n{content}\n</skill>'))
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded skill: {skill.name}")
except Exception:
logger.debug(f"[trace={self.trace_id}] Failed to read skill {skill.name}", exc_info=True)
return messages
async def _build_initial_state(self, task: str) -> dict[str, Any]:
"""Build the initial state for agent execution. """Build the initial state for agent execution.
Args: Args:
@@ -193,8 +249,17 @@ class SubagentExecutor:
Returns: Returns:
Initial state dictionary. Initial state dictionary.
""" """
# Load skills as conversation items (Codex pattern)
skill_messages = await self._load_skill_messages()
messages: list = []
# Skill content injected as developer/system messages before the task
messages.extend(skill_messages)
# Then the actual task
messages.append(HumanMessage(content=task))
state: dict[str, Any] = { state: dict[str, Any] = {
"messages": [HumanMessage(content=task)], "messages": messages,
} }
# Pass through sandbox and thread data from parent # Pass through sandbox and thread data from parent
@@ -230,7 +295,7 @@ class SubagentExecutor:
try: try:
agent = self._create_agent() agent = self._create_agent()
state = self._build_initial_state(task) state = await self._build_initial_state(task)
# Build config with thread_id for sandbox access and recursion limit # Build config with thread_id for sandbox access and recursion limit
run_config: RunnableConfig = { run_config: RunnableConfig = {
@@ -10,53 +10,100 @@ from deerflow.subagents.config import SubagentConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _build_custom_subagent_config(name: str) -> SubagentConfig | None:
"""Build a SubagentConfig from config.yaml custom_agents section.
Args:
name: The name of the custom subagent.
Returns:
SubagentConfig if found in custom_agents, None otherwise.
"""
from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config()
custom = app_config.custom_agents.get(name)
if custom is None:
return None
return SubagentConfig(
name=name,
description=custom.description,
system_prompt=custom.system_prompt,
tools=custom.tools,
disallowed_tools=custom.disallowed_tools,
skills=custom.skills,
model=custom.model,
max_turns=custom.max_turns,
timeout_seconds=custom.timeout_seconds,
)
def get_subagent_config(name: str) -> SubagentConfig | None: def get_subagent_config(name: str) -> SubagentConfig | None:
"""Get a subagent configuration by name, with config.yaml overrides applied. """Get a subagent configuration by name, with config.yaml overrides applied.
Resolution order (mirrors Codex's config layering):
1. Built-in subagents (general-purpose, bash)
2. Custom subagents from config.yaml custom_agents section
3. Per-agent overrides from config.yaml agents section (timeout, max_turns, model, skills)
Args: Args:
name: The name of the subagent. name: The name of the subagent.
Returns: Returns:
SubagentConfig if found (with any config.yaml overrides applied), None otherwise. SubagentConfig if found (with any config.yaml overrides applied), None otherwise.
""" """
# Step 1: Look up built-in, then fall back to custom_agents
config = BUILTIN_SUBAGENTS.get(name) config = BUILTIN_SUBAGENTS.get(name)
if config is None:
config = _build_custom_subagent_config(name)
if config is None: if config is None:
return None return None
# Apply runtime overrides (timeout, max_turns, model) from config.yaml # Step 2: Apply per-agent overrides from config.yaml agents section.
# Only explicit per-agent overrides are applied here. Global defaults
# (timeout_seconds, max_turns at the top level) apply to built-in agents
# but must NOT override custom agents' own values — custom agents define
# their own defaults in the custom_agents section.
# Lazy import to avoid circular deps. # Lazy import to avoid circular deps.
from deerflow.config.subagents_config import get_subagents_app_config from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config() app_config = get_subagents_app_config()
effective_timeout = app_config.get_timeout_for(name) is_builtin = name in BUILTIN_SUBAGENTS
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns) agent_override = app_config.agents.get(name)
overrides = {} overrides = {}
if effective_timeout != config.timeout_seconds:
logger.debug( # Timeout: per-agent override > global default (builtins only) > config's own value
"Subagent '%s': timeout overridden by config.yaml (%ss -> %ss)", if agent_override is not None and agent_override.timeout_seconds is not None:
name, if agent_override.timeout_seconds != config.timeout_seconds:
config.timeout_seconds, logger.debug("Subagent '%s': timeout overridden (%ss -> %ss)", name, config.timeout_seconds, agent_override.timeout_seconds)
effective_timeout, overrides["timeout_seconds"] = agent_override.timeout_seconds
) elif is_builtin and app_config.timeout_seconds != config.timeout_seconds:
overrides["timeout_seconds"] = effective_timeout logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, app_config.timeout_seconds)
if effective_max_turns != config.max_turns: overrides["timeout_seconds"] = app_config.timeout_seconds
logger.debug(
"Subagent '%s': max_turns overridden by config.yaml (%s -> %s)", # Max turns: per-agent override > global default (builtins only) > config's own value
name, if agent_override is not None and agent_override.max_turns is not None:
config.max_turns, if agent_override.max_turns != config.max_turns:
effective_max_turns, logger.debug("Subagent '%s': max_turns overridden (%s -> %s)", name, config.max_turns, agent_override.max_turns)
) overrides["max_turns"] = agent_override.max_turns
overrides["max_turns"] = effective_max_turns elif is_builtin and app_config.max_turns is not None and app_config.max_turns != config.max_turns:
logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, app_config.max_turns)
overrides["max_turns"] = app_config.max_turns
# Model: per-agent override only (no global default for model)
effective_model = app_config.get_model_for(name) effective_model = app_config.get_model_for(name)
if effective_model is not None and effective_model != config.model: if effective_model is not None and effective_model != config.model:
logger.debug( logger.debug("Subagent '%s': model overridden (%s -> %s)", name, config.model, effective_model)
"Subagent '%s': model overridden by config.yaml (%s -> %s)",
name,
config.model,
effective_model,
)
overrides["model"] = effective_model overrides["model"] = effective_model
# Skills: per-agent override only (no global default for skills)
effective_skills = app_config.get_skills_for(name)
if effective_skills is not None and effective_skills != config.skills:
logger.debug("Subagent '%s': skills overridden (%s -> %s)", name, config.skills, effective_skills)
overrides["skills"] = effective_skills
if overrides: if overrides:
config = replace(config, **overrides) config = replace(config, **overrides)
@@ -67,18 +114,33 @@ def list_subagents() -> list[SubagentConfig]:
"""List all available subagent configurations (with config.yaml overrides applied). """List all available subagent configurations (with config.yaml overrides applied).
Returns: Returns:
List of all registered SubagentConfig instances. List of all registered SubagentConfig instances (built-in + custom).
""" """
return [get_subagent_config(name) for name in BUILTIN_SUBAGENTS] configs = []
for name in get_subagent_names():
config = get_subagent_config(name)
if config is not None:
configs.append(config)
return configs
def get_subagent_names() -> list[str]: def get_subagent_names() -> list[str]:
"""Get all available subagent names. """Get all available subagent names (built-in + custom).
Returns: Returns:
List of subagent names. List of subagent names.
""" """
return list(BUILTIN_SUBAGENTS.keys()) names = list(BUILTIN_SUBAGENTS.keys())
# Merge custom_agents from config.yaml
from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config()
for custom_name in app_config.custom_agents:
if custom_name not in names:
names.append(custom_name)
return names
def get_available_subagent_names() -> list[str]: def get_available_subagent_names() -> list[str]:
@@ -87,11 +149,11 @@ def get_available_subagent_names() -> list[str]:
Returns: Returns:
List of subagent names visible to the current sandbox configuration. List of subagent names visible to the current sandbox configuration.
""" """
names = list(BUILTIN_SUBAGENTS.keys()) names = get_subagent_names()
try: try:
host_bash_allowed = is_host_bash_allowed() host_bash_allowed = is_host_bash_allowed()
except Exception: except Exception:
logger.debug("Could not determine host bash availability; exposing all built-in subagents") logger.debug("Could not determine host bash availability; exposing all subagents")
return names return names
if not host_bash_allowed: if not host_bash_allowed:
@@ -8,6 +8,7 @@ from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.utils.runtime import get_thread_id
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs" OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
@@ -33,9 +34,9 @@ def _normalize_presented_filepath(
if runtime.state is None: if runtime.state is None:
raise ValueError("Thread runtime state is not available") raise ValueError("Thread runtime state is not available")
thread_id = runtime.context.get("thread_id") if runtime.context else None thread_id = get_thread_id(runtime)
if not thread_id: if not thread_id:
raise ValueError("Thread ID is not available in runtime context") raise ValueError("Thread ID is not available in runtime context, runtime config, or LangGraph thread-local config")
thread_data = runtime.state.get("thread_data") or {} thread_data = runtime.state.get("thread_data") or {}
outputs_path = thread_data.get("outputs_path") outputs_path = thread_data.get("outputs_path")
@@ -6,6 +6,7 @@ from langchain_core.tools import tool
from langgraph.prebuilt import ToolRuntime from langgraph.prebuilt import ToolRuntime
from langgraph.types import Command from langgraph.types import Command
from deerflow.config.agents_config import validate_agent_name
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,19 +17,25 @@ def setup_agent(
soul: str, soul: str,
description: str, description: str,
runtime: ToolRuntime, runtime: ToolRuntime,
skills: list[str] | None = None,
) -> Command: ) -> Command:
"""Setup the custom DeerFlow agent. """Setup the custom DeerFlow agent.
Args: Args:
soul: Full SOUL.md content defining the agent's personality and behavior. soul: Full SOUL.md content defining the agent's personality and behavior.
description: One-line description of what the agent does. description: One-line description of what the agent does.
skills: Optional list of skill names this agent should use. None means use all enabled skills, empty list means no skills.
""" """
agent_name: str | None = runtime.context.get("agent_name") if runtime.context else None agent_name: str | None = runtime.context.get("agent_name") if runtime.context else None
agent_dir = None
is_new_dir = False
try: try:
agent_name = validate_agent_name(agent_name)
paths = get_paths() paths = get_paths()
agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
is_new_dir = not agent_dir.exists()
agent_dir.mkdir(parents=True, exist_ok=True) agent_dir.mkdir(parents=True, exist_ok=True)
if agent_name: if agent_name:
@@ -36,6 +43,8 @@ def setup_agent(
config_data: dict = {"name": agent_name} config_data: dict = {"name": agent_name}
if description: if description:
config_data["description"] = description config_data["description"] = description
if skills is not None:
config_data["skills"] = skills
config_file = agent_dir / "config.yaml" config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f: with open(config_file, "w", encoding="utf-8") as f:
@@ -55,8 +64,8 @@ def setup_agent(
except Exception as e: except Exception as e:
import shutil import shutil
if agent_name and agent_dir.exists(): if agent_name and is_new_dir and agent_dir is not None and agent_dir.exists():
# Cleanup the custom agent directory only if it was created but an error occurred during setup # Cleanup the custom agent directory only if it was newly created during this call
shutil.rmtree(agent_dir) shutil.rmtree(agent_dir)
logger.error(f"[agent_creator] Failed to create agent '{agent_name}': {e}", exc_info=True) logger.error(f"[agent_creator] Failed to create agent '{agent_name}': {e}", exc_info=True)
return Command(update={"messages": [ToolMessage(content=f"Error: {e}", tool_call_id=runtime.tool_call_id)]}) return Command(update={"messages": [ToolMessage(content=f"Error: {e}", tool_call_id=runtime.tool_call_id)]})
@@ -10,15 +10,26 @@ from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
from langgraph.typing import ContextT from langgraph.typing import ContextT
from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -> list[str] | None:
"""Return the effective subagent skill allowlist under the parent policy."""
if parent is None:
return child
if child is None:
return list(parent)
parent_set = set(parent)
return [skill for skill in child if skill in parent_set]
@tool("task", parse_docstring=True) @tool("task", parse_docstring=True)
async def task_tool( async def task_tool(
runtime: ToolRuntime[ContextT, ThreadState], runtime: ToolRuntime[ContextT, ThreadState],
@@ -35,7 +46,7 @@ async def task_tool(
- Handle complex multi-step tasks autonomously - Handle complex multi-step tasks autonomously
- Execute commands or operations in isolated contexts - Execute commands or operations in isolated contexts
Available subagent types depend on the active sandbox configuration: Built-in subagent types:
- **general-purpose**: A capable agent for complex, multi-step tasks that require - **general-purpose**: A capable agent for complex, multi-step tasks that require
both exploration and action. Use when the task requires complex reasoning, both exploration and action. Use when the task requires complex reasoning,
multiple dependent steps, or would benefit from isolated context. multiple dependent steps, or would benefit from isolated context.
@@ -43,6 +54,11 @@ async def task_tool(
available when host bash is explicitly allowed or when using an isolated shell available when host bash is explicitly allowed or when using an isolated shell
sandbox such as `AioSandboxProvider`. sandbox such as `AioSandboxProvider`.
Additional custom subagent types may be defined in config.yaml under
`subagents.custom_agents`. Each custom type can have its own system prompt,
tools, skills, model, and timeout configuration. If an unknown subagent_type
is provided, the error message will list all available types.
When to use this tool: When to use this tool:
- Complex tasks requiring multiple steps or tools - Complex tasks requiring multiple steps or tools
- Tasks that produce verbose output - Tasks that produce verbose output
@@ -72,29 +88,25 @@ async def task_tool(
# Build config overrides # Build config overrides
overrides: dict = {} overrides: dict = {}
skills_section = get_skills_prompt_section() # Skills are loaded by SubagentExecutor per-session (aligned with Codex's pattern:
if skills_section: # each subagent loads its own skills based on config, injected as conversation items).
overrides["system_prompt"] = config.system_prompt + "\n\n" + skills_section # No longer appended to system_prompt here.
if max_turns is not None: if max_turns is not None:
overrides["max_turns"] = max_turns overrides["max_turns"] = max_turns
if overrides:
config = replace(config, **overrides)
# Extract parent context from runtime # Extract parent context from runtime
sandbox_state = None sandbox_state = None
thread_data = None thread_data = None
thread_id = None thread_id = None
parent_model = None parent_model = None
trace_id = None trace_id = None
metadata: dict = {}
if runtime is not None: if runtime is not None:
sandbox_state = runtime.state.get("sandbox") sandbox_state = runtime.state.get("sandbox")
thread_data = runtime.state.get("thread_data") thread_data = runtime.state.get("thread_data")
thread_id = runtime.context.get("thread_id") if runtime.context else None thread_id = get_thread_id(runtime)
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id")
# Try to get parent model from configurable # Try to get parent model from configurable
metadata = runtime.config.get("metadata", {}) metadata = runtime.config.get("metadata", {})
@@ -103,12 +115,22 @@ async def task_tool(
# Get or generate trace_id for distributed tracing # Get or generate trace_id for distributed tracing
trace_id = metadata.get("trace_id") or str(uuid.uuid4())[:8] trace_id = metadata.get("trace_id") or str(uuid.uuid4())[:8]
parent_available_skills = metadata.get("available_skills")
if parent_available_skills is not None:
overrides["skills"] = _merge_skill_allowlists(list(parent_available_skills), config.skills)
if overrides:
config = replace(config, **overrides)
# Get available tools (excluding task tool to prevent nesting) # Get available tools (excluding task tool to prevent nesting)
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
# Inherit parent agent's tool_groups so subagents respect the same restrictions
parent_tool_groups = metadata.get("tool_groups")
# Subagents should not have subagent tools enabled (prevent recursive nesting) # Subagents should not have subagent tools enabled (prevent recursive nesting)
tools = get_available_tools(model_name=parent_model, subagent_enabled=False) tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False)
# Create executor # Create executor
executor = SubagentExecutor( executor = SubagentExecutor(
@@ -112,6 +112,15 @@ class DeferredToolRegistry:
def entries(self) -> list[DeferredToolEntry]: def entries(self) -> list[DeferredToolEntry]:
return list(self._entries) return list(self._entries)
@property
def deferred_names(self) -> set[str]:
"""Names of tools that are still hidden from model binding."""
return {entry.name for entry in self._entries}
def contains(self, name: str) -> bool:
"""Return whether *name* is still deferred."""
return any(entry.name == name for entry in self._entries)
def __len__(self) -> int: def __len__(self) -> int:
return len(self._entries) return len(self._entries)
@@ -28,6 +28,7 @@ from deerflow.skills.manager import (
validate_skill_name, validate_skill_name,
) )
from deerflow.skills.security_scanner import scan_skill_content from deerflow.skills.security_scanner import scan_skill_content
from deerflow.utils.runtime import get_thread_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,14 +43,6 @@ def _get_lock(name: str) -> asyncio.Lock:
return lock return lock
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
if runtime is None:
return None
if runtime.context and runtime.context.get("thread_id"):
return runtime.context.get("thread_id")
return runtime.config.get("configurable", {}).get("thread_id")
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]: def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:
return { return {
"action": action, "action": action,
@@ -98,7 +91,7 @@ async def _skill_manage_impl(
""" """
name = validate_skill_name(name) name = validate_skill_name(name)
lock = _get_lock(name) lock = _get_lock(name)
thread_id = _get_thread_id(runtime) thread_id = get_thread_id(runtime)
async with lock: async with lock:
if action == "create": if action == "create":
@@ -59,7 +59,22 @@ def get_available_tools(
if not is_host_bash_allowed(config): if not is_host_bash_allowed(config):
tool_configs = [tool for tool in tool_configs if not _is_host_bash_tool(tool)] tool_configs = [tool for tool in tool_configs if not _is_host_bash_tool(tool)]
loaded_tools = [resolve_variable(tool.use, BaseTool) for tool in tool_configs] loaded_tools_raw = [(cfg, resolve_variable(cfg.use, BaseTool)) for cfg in tool_configs]
# Warn when the config ``name`` field and the tool object's ``.name``
# attribute diverge — this mismatch is the root cause of issue #1803 where
# the LLM receives one name in its tool schema but the runtime router
# recognises a different name, producing "not a valid tool" errors.
for cfg, loaded in loaded_tools_raw:
if cfg.name != loaded.name:
logger.warning(
"Tool name mismatch: config name %r does not match tool .name %r (use: %s). The tool's own .name will be used for binding.",
cfg.name,
loaded.name,
cfg.use,
)
loaded_tools = [t for _, t in loaded_tools_raw]
# Conditionally add tools based on config # Conditionally add tools based on config
builtin_tools = BUILTIN_TOOLS.copy() builtin_tools = BUILTIN_TOOLS.copy()
@@ -134,4 +149,20 @@ def get_available_tools(
logger.warning(f"Failed to load ACP tool: {e}") logger.warning(f"Failed to load ACP tool: {e}")
logger.info(f"Total tools loaded: {len(loaded_tools)}, built-in tools: {len(builtin_tools)}, MCP tools: {len(mcp_tools)}, ACP tools: {len(acp_tools)}") logger.info(f"Total tools loaded: {len(loaded_tools)}, built-in tools: {len(builtin_tools)}, MCP tools: {len(mcp_tools)}, ACP tools: {len(acp_tools)}")
return loaded_tools + builtin_tools + mcp_tools + acp_tools
# Deduplicate by tool name — config-loaded tools take priority, followed by
# built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to
# receive ambiguous or concatenated function schemas (issue #1803).
all_tools = loaded_tools + builtin_tools + mcp_tools + acp_tools
seen_names: set[str] = set()
unique_tools: list[BaseTool] = []
for t in all_tools:
if t.name not in seen_names:
unique_tools.append(t)
seen_names.add(t.name)
else:
logger.warning(
"Duplicate tool name %r detected and skipped — check your config.yaml and MCP server registrations (issue #1803).",
t.name,
)
return unique_tools
@@ -19,6 +19,8 @@ import logging
import re import re
from pathlib import Path from pathlib import Path
from deerflow.config.app_config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# File extensions that should be converted to markdown # File extensions that should be converted to markdown
@@ -286,6 +288,15 @@ def extract_outline(md_path: Path) -> list[dict]:
return outline return outline
def _get_uploads_config_value(key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _get_pdf_converter() -> str: def _get_pdf_converter() -> str:
"""Read pdf_converter setting from app config, defaulting to 'auto'. """Read pdf_converter setting from app config, defaulting to 'auto'.
@@ -294,16 +305,11 @@ def _get_pdf_converter() -> str:
fall through to unexpected behaviour. fall through to unexpected behaviour.
""" """
try: try:
from deerflow.config.app_config import get_app_config raw = str(_get_uploads_config_value("pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
cfg = get_app_config() logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
uploads_cfg = getattr(cfg, "uploads", None) return "auto"
if uploads_cfg is not None: return raw
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
return "auto"
return raw
except Exception: except Exception:
pass pass
return "auto" return "auto"
@@ -0,0 +1,90 @@
"""Runtime utilities for thread_id resolution and context access.
Thread ID Resolution Strategy
=============================
DeerFlow resolves the current ``thread_id`` from a three-level cascade:
1. **runtime.context["thread_id"]** -- Set by ``worker.py`` (gateway mode)
or by LangGraph Server (standard mode) when constructing the Runtime.
2. **runtime.config["configurable"]["thread_id"]** -- Available on
``ToolRuntime`` instances passed to tools via the ``@tool`` decorator.
Not available on ``Runtime`` instances received by middlewares.
3. **get_config()["configurable"]["thread_id"]** -- LangGraph's thread-local
config, available when executing inside a graph's runnable context.
About ``__pregel_runtime``
===========================
In gateway mode (``run_agent()`` in ``worker.py``), the agent graph does not
run inside the LangGraph Server. The server normally injects a ``Runtime``
object automatically. Since we run the graph ourselves, we must inject the
Runtime manually via ``config["configurable"]["__pregel_runtime"]``. This is
the standard mechanism provided by LangGraph's Pregel engine for injecting
runtime context into graph nodes. It is not a private/internal hack -- it is
the documented way to pass Runtime when running a graph outside the server.
Duck Typing
===========
Both ``langgraph.runtime.Runtime`` (middlewares) and
``langchain.tools.ToolRuntime`` (tools) expose a ``.context`` attribute (a
dict or None). ``ToolRuntime`` additionally exposes ``.config``. The
function below uses ``getattr`` with safe defaults so it works with either
type, with ``SimpleNamespace`` in tests, or with ``None``.
"""
from __future__ import annotations
from typing import Any
def get_thread_id(runtime: Any | None) -> str | None:
"""Resolve the current thread_id from a runtime object.
Follows a three-level fallback chain:
1. ``runtime.context.get("thread_id")`` -- if context is a non-empty dict.
2. ``runtime.config.get("configurable", {}).get("thread_id")`` -- if
the runtime has a config dict (ToolRuntime).
3. ``get_config().get("configurable", {}).get("thread_id")`` -- LangGraph's
thread-local config. Wrapped in ``try/except RuntimeError`` because it
raises outside a runnable context (e.g., unit tests).
Args:
runtime: A Runtime, ToolRuntime, SimpleNamespace, or None.
Returns:
The thread_id string, or None if it cannot be resolved.
"""
if runtime is None:
return None
# Level 1: runtime.context["thread_id"]
context = getattr(runtime, "context", None)
if context and isinstance(context, dict):
thread_id = context.get("thread_id")
if thread_id:
return thread_id
# Level 2: runtime.config["configurable"]["thread_id"]
config = getattr(runtime, "config", None)
if config and isinstance(config, dict):
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id:
return thread_id
# Level 3: langgraph.config.get_config() -- only works inside runnable context
try:
from langgraph.config import get_config
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if thread_id:
return thread_id
except RuntimeError:
# Expected when not running inside a LangGraph runnable context (e.g., unit tests).
# In that case, thread_id cannot be resolved from thread-local config, so fall through.
pass
return None
+7 -2
View File
@@ -8,7 +8,7 @@ dependencies = [
"deerflow-harness", "deerflow-harness",
"fastapi>=0.115.0", "fastapi>=0.115.0",
"httpx>=0.28.0", "httpx>=0.28.0",
"python-multipart>=0.0.20", "python-multipart>=0.0.26",
"sse-starlette>=2.1.0", "sse-starlette>=2.1.0",
"uvicorn[standard]>=0.34.0", "uvicorn[standard]>=0.34.0",
"lark-oapi>=1.4.0", "lark-oapi>=1.4.0",
@@ -20,7 +20,12 @@ dependencies = [
] ]
[dependency-groups] [dependency-groups]
dev = ["pytest>=8.0.0", "ruff>=0.14.11"] dev = [
"prompt-toolkit>=3.0.0",
"pytest>=9.0.3",
"pytest-asyncio>=1.3.0",
"ruff>=0.14.11",
]
[tool.uv.workspace] [tool.uv.workspace]
members = ["packages/harness"] members = ["packages/harness"]
+60
View File
@@ -6,6 +6,7 @@ from pathlib import Path
import yaml import yaml
from deerflow.config.agents_api_config import get_agents_api_config
from deerflow.config.app_config import get_app_config, reset_app_config from deerflow.config.app_config import get_app_config, reset_app_config
@@ -28,6 +29,30 @@ def _write_config(path: Path, *, model_name: str, supports_thinking: bool) -> No
) )
def _write_config_with_agents_api(
path: Path,
*,
model_name: str,
supports_thinking: bool,
agents_api: dict | None = None,
) -> None:
config = {
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"models": [
{
"name": model_name,
"use": "langchain_openai:ChatOpenAI",
"model": "gpt-test",
"supports_thinking": supports_thinking,
}
],
}
if agents_api is not None:
config["agents_api"] = agents_api
path.write_text(yaml.safe_dump(config), encoding="utf-8")
def _write_extensions_config(path: Path) -> None: def _write_extensions_config(path: Path) -> None:
path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8") path.write_text(json.dumps({"mcpServers": {}, "skills": {}}), encoding="utf-8")
@@ -79,3 +104,38 @@ def test_get_app_config_reloads_when_config_path_changes(tmp_path, monkeypatch):
assert second is not first assert second is not first
finally: finally:
reset_app_config() reset_app_config()
def test_get_app_config_resets_agents_api_config_when_section_removed(tmp_path, monkeypatch):
config_path = tmp_path / "config.yaml"
extensions_path = tmp_path / "extensions_config.json"
_write_extensions_config(extensions_path)
_write_config_with_agents_api(
config_path,
model_name="first-model",
supports_thinking=False,
agents_api={"enabled": True},
)
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_path))
monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(extensions_path))
reset_app_config()
try:
initial = get_app_config()
assert initial.models[0].name == "first-model"
assert get_agents_api_config().enabled is True
_write_config_with_agents_api(
config_path,
model_name="first-model",
supports_thinking=False,
)
next_mtime = config_path.stat().st_mtime + 5
os.utime(config_path, (next_mtime, next_mtime))
reloaded = get_app_config()
assert reloaded is not initial
assert get_agents_api_config().enabled is False
finally:
reset_app_config()
+133 -5
View File
@@ -2011,6 +2011,65 @@ class TestChannelService:
assert service.manager._langgraph_url == "http://custom-langgraph:2024" assert service.manager._langgraph_url == "http://custom-langgraph:2024"
assert service.manager._gateway_url == "http://custom-gateway:8001" assert service.manager._gateway_url == "http://custom-gateway:8001"
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
"""Warning is emitted when a channel has string credentials but enabled=false."""
import logging
from app.channels.service import ChannelService
async def go():
service = ChannelService(
channels_config={
"wecom": {"enabled": False, "bot_id": "corp123", "bot_secret": "secret"},
}
)
with caplog.at_level(logging.WARNING, logger="app.channels.service"):
await service.start()
await service.stop()
_run(go())
assert any("wecom" in r.message and r.levelno == logging.WARNING for r in caplog.records)
def test_disabled_channel_with_int_creds_emits_warning(self, caplog):
"""Warning is emitted even when YAML-parsed integer credentials are present."""
import logging
from app.channels.service import ChannelService
async def go():
# Simulate YAML parsing a numeric token/ID as an int
service = ChannelService(
channels_config={
"telegram": {"enabled": False, "bot_token": 123456789},
}
)
with caplog.at_level(logging.WARNING, logger="app.channels.service"):
await service.start()
await service.stop()
_run(go())
assert any("telegram" in r.message and r.levelno == logging.WARNING for r in caplog.records)
def test_disabled_channel_without_creds_emits_info(self, caplog):
"""Only an info log (no warning) is emitted when a channel is disabled with no credentials."""
import logging
from app.channels.service import ChannelService
async def go():
service = ChannelService(
channels_config={
"telegram": {"enabled": False},
}
)
with caplog.at_level(logging.DEBUG, logger="app.channels.service"):
await service.start()
await service.stop()
_run(go())
warning_records = [r for r in caplog.records if "telegram" in r.message and r.levelno == logging.WARNING]
assert not warning_records
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Slack send retry tests # Slack send retry tests
@@ -2046,6 +2105,11 @@ class TestSlackSendRetry:
class TestSlackAllowedUsers: class TestSlackAllowedUsers:
@staticmethod
def _submit_coro(coro, loop):
coro.close()
return MagicMock()
def test_numeric_allowed_users_match_string_event_user_id(self): def test_numeric_allowed_users_match_string_event_user_id(self):
from app.channels.slack import SlackChannel from app.channels.slack import SlackChannel
@@ -2067,13 +2131,9 @@ class TestSlackAllowedUsers:
"ts": "1710000000.000100", "ts": "1710000000.000100",
} }
def submit_coro(coro, loop):
coro.close()
return MagicMock()
with patch( with patch(
"app.channels.slack.asyncio.run_coroutine_threadsafe", "app.channels.slack.asyncio.run_coroutine_threadsafe",
side_effect=submit_coro, side_effect=self._submit_coro,
) as submit: ) as submit:
channel._handle_message_event(event) channel._handle_message_event(event)
@@ -2085,6 +2145,74 @@ class TestSlackAllowedUsers:
assert inbound.chat_id == "C123" assert inbound.chat_id == "C123"
assert inbound.text == "hello from slack" assert inbound.text == "hello from slack"
def test_string_allowed_users_match_event_user_id(self):
from app.channels.slack import SlackChannel
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = SlackChannel(
bus=bus,
config={"allowed_users": "U123456"},
)
channel._loop = MagicMock()
channel._loop.is_running.return_value = True
channel._add_reaction = MagicMock()
channel._send_running_reply = MagicMock()
event = {
"user": "U123456",
"text": "hello from slack",
"channel": "C123",
"ts": "1710000000.000100",
}
with patch(
"app.channels.slack.asyncio.run_coroutine_threadsafe",
side_effect=self._submit_coro,
) as submit:
channel._handle_message_event(event)
channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
submit.assert_called_once()
inbound = bus.publish_inbound.call_args.args[0]
assert inbound.user_id == "U123456"
assert inbound.chat_id == "C123"
assert inbound.text == "hello from slack"
def test_scalar_allowed_users_warns_and_matches_stringified_event_user_id(self, caplog):
from app.channels.slack import SlackChannel
bus = MessageBus()
bus.publish_inbound = AsyncMock()
with caplog.at_level("WARNING"):
channel = SlackChannel(
bus=bus,
config={"allowed_users": 123456},
)
channel._loop = MagicMock()
channel._loop.is_running.return_value = True
channel._add_reaction = MagicMock()
channel._send_running_reply = MagicMock()
event = {
"user": "123456",
"text": "hello from slack",
"channel": "C123",
"ts": "1710000000.000100",
}
with patch(
"app.channels.slack.asyncio.run_coroutine_threadsafe",
side_effect=self._submit_coro,
) as submit:
channel._handle_message_event(event)
assert "Slack allowed_users should be a list" in caplog.text
submit.assert_called_once()
inbound = bus.publish_inbound.call_args.args[0]
assert inbound.user_id == "123456"
def test_raises_after_all_retries_exhausted(self): def test_raises_after_all_retries_exhausted(self):
from app.channels.slack import SlackChannel from app.channels.slack import SlackChannel
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
import importlib.util
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[2]
CHECK_SCRIPT_PATH = REPO_ROOT / "scripts" / "check.py"
spec = importlib.util.spec_from_file_location("deerflow_check_script", CHECK_SCRIPT_PATH)
assert spec is not None
assert spec.loader is not None
check_script = importlib.util.module_from_spec(spec)
spec.loader.exec_module(check_script)
def test_find_pnpm_command_prefers_resolved_executable(monkeypatch):
def fake_which(name: str) -> str | None:
if name == "pnpm":
return r"C:\Users\tester\AppData\Roaming\npm\pnpm.CMD"
if name == "pnpm.cmd":
return r"C:\Users\tester\AppData\Roaming\npm\pnpm.cmd"
return None
monkeypatch.setattr(check_script.shutil, "which", fake_which)
assert check_script.find_pnpm_command() == [r"C:\Users\tester\AppData\Roaming\npm\pnpm.CMD"]
def test_find_pnpm_command_falls_back_to_corepack(monkeypatch):
def fake_which(name: str) -> str | None:
if name == "corepack":
return r"C:\Program Files\nodejs\corepack.exe"
return None
monkeypatch.setattr(check_script.shutil, "which", fake_which)
assert check_script.find_pnpm_command() == [
r"C:\Program Files\nodejs\corepack.exe",
"pnpm",
]
def test_find_pnpm_command_falls_back_to_corepack_cmd(monkeypatch):
def fake_which(name: str) -> str | None:
if name == "corepack":
return None
if name == "corepack.cmd":
return r"C:\Program Files\nodejs\corepack.cmd"
return None
monkeypatch.setattr(check_script.shutil, "which", fake_which)
assert check_script.find_pnpm_command() == [
r"C:\Program Files\nodejs\corepack.cmd",
"pnpm",
]
+73
View File
@@ -150,6 +150,79 @@ class TestGetCheckpointer:
mock_saver_cls.from_conn_string.assert_called_once() mock_saver_cls.from_conn_string.assert_called_once()
mock_saver_instance.setup.assert_called_once() mock_saver_instance.setup.assert_called_once()
def test_sqlite_creates_parent_dir(self):
"""Sync SQLite checkpointer should call ensure_sqlite_parent_dir before connecting.
This mirrors the async checkpointer's behaviour and prevents
'sqlite3.OperationalError: unable to open database file' when the
parent directory for the database file does not yet exist (e.g. when
using the harness package from an external virtualenv where the
.deer-flow directory has not been created).
"""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "relative/test.db"})
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(return_value=mock_cm)
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch("deerflow.agents.checkpointer.provider.ensure_sqlite_parent_dir") as mock_ensure,
patch(
"deerflow.agents.checkpointer.provider.resolve_sqlite_conn_str",
return_value="/tmp/resolved/relative/test.db",
),
):
reset_checkpointer()
cp = get_checkpointer()
assert cp is mock_saver_instance
mock_ensure.assert_called_once_with("/tmp/resolved/relative/test.db")
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/relative/test.db")
def test_sqlite_ensure_parent_dir_before_connect(self):
"""ensure_sqlite_parent_dir must be called before from_conn_string."""
load_checkpointer_config_from_dict({"type": "sqlite", "connection_string": "relative/test.db"})
call_order = []
mock_saver_instance = MagicMock()
mock_cm = MagicMock()
mock_cm.__enter__ = MagicMock(return_value=mock_saver_instance)
mock_cm.__exit__ = MagicMock(return_value=False)
mock_saver_cls = MagicMock()
mock_saver_cls.from_conn_string = MagicMock(side_effect=lambda *a, **kw: (call_order.append("connect"), mock_cm)[1])
mock_module = MagicMock()
mock_module.SqliteSaver = mock_saver_cls
def record_ensure(*a, **kw):
call_order.append("ensure")
with (
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite": mock_module}),
patch(
"deerflow.agents.checkpointer.provider.ensure_sqlite_parent_dir",
side_effect=record_ensure,
),
patch(
"deerflow.agents.checkpointer.provider.resolve_sqlite_conn_str",
return_value="/tmp/resolved/relative/test.db",
),
):
reset_checkpointer()
get_checkpointer()
assert call_order == ["ensure", "connect"]
def test_postgres_creates_saver(self): def test_postgres_creates_saver(self):
"""Postgres checkpointer is created when packages are available.""" """Postgres checkpointer is created when packages are available."""
load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"}) load_checkpointer_config_from_dict({"type": "postgres", "connection_string": "postgresql://localhost/db"})
@@ -1,8 +1,10 @@
"""Tests for ClarificationMiddleware, focusing on options type coercion.""" """Tests for ClarificationMiddleware, focusing on options type coercion."""
import json import json
from types import SimpleNamespace
import pytest import pytest
from langgraph.graph.message import add_messages
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
@@ -118,3 +120,60 @@ class TestFormatClarificationMessage:
assert "2. 2" in result assert "2. 2" in result
assert "3. True" in result assert "3. True" in result
assert "4. None" in result assert "4. None" in result
class TestClarificationCommandIdempotency:
"""Clarification tool-call retries should not duplicate messages in state."""
def test_repeated_tool_call_uses_stable_message_id(self, middleware):
request = SimpleNamespace(
tool_call={
"name": "ask_clarification",
"id": "call-clarify-1",
"args": {
"question": "Which environment should I use?",
"clarification_type": "approach_choice",
"options": ["dev", "prod"],
},
}
)
first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
first_message = first.update["messages"][0]
second_message = second.update["messages"][0]
assert first_message.id == "clarification:call-clarify-1"
assert second_message.id == first_message.id
assert second_message.tool_call_id == first_message.tool_call_id
merged = add_messages(add_messages([], [first_message]), [second_message])
assert len(merged) == 1
assert merged[0].id == "clarification:call-clarify-1"
assert merged[0].content == first_message.content
def test_missing_tool_call_id_still_gets_stable_message_id(self, middleware):
request = SimpleNamespace(
tool_call={
"name": "ask_clarification",
"args": {
"question": "Which environment should I use?",
"clarification_type": "missing_info",
},
}
)
first = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
second = middleware.wrap_tool_call(request, lambda _req: pytest.fail("handler should not be called"))
first_message = first.update["messages"][0]
second_message = second.update["messages"][0]
assert first_message.id.startswith("clarification:")
assert second_message.id == first_message.id
merged = add_messages(add_messages([], [first_message]), [second_message])
assert len(merged) == 1
@@ -0,0 +1,249 @@
"""Tests for ClaudeChatModel._apply_prompt_caching.
Validates that the function never places more than 4 cache_control breakpoints
(the hard limit enforced by the Anthropic API and AWS Bedrock) regardless of
how many system blocks, message content blocks, or tool definitions are present.
"""
from unittest import mock
import pytest
from deerflow.models.claude_provider import ClaudeChatModel
def _make_model(prompt_cache_size: int = 3) -> ClaudeChatModel:
"""Return a minimal ClaudeChatModel instance without network calls."""
with mock.patch.object(ClaudeChatModel, "model_post_init"):
m = ClaudeChatModel(
model="claude-sonnet-4-6",
anthropic_api_key="sk-ant-fake", # type: ignore[call-arg]
prompt_cache_size=prompt_cache_size,
)
m._is_oauth = False
m.enable_prompt_caching = True
return m
def _count_cache_control(payload: dict) -> int:
"""Count the total number of cache_control markers in a payload."""
count = 0
system = payload.get("system", [])
if isinstance(system, list):
for block in system:
if isinstance(block, dict) and "cache_control" in block:
count += 1
for msg in payload.get("messages", []):
if not isinstance(msg, dict):
continue
content = msg.get("content", [])
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and "cache_control" in block:
count += 1
for tool in payload.get("tools", []):
if isinstance(tool, dict) and "cache_control" in tool:
count += 1
return count
@pytest.fixture()
def model() -> ClaudeChatModel:
return _make_model()
# ---------------------------------------------------------------------------
# Basic correctness
# ---------------------------------------------------------------------------
def test_single_system_block_gets_cached(model):
payload: dict = {"system": [{"type": "text", "text": "sys"}]}
model._apply_prompt_caching(payload)
assert payload["system"][0].get("cache_control") == {"type": "ephemeral"}
def test_string_system_converted_and_cached(model):
payload: dict = {"system": "you are helpful"}
model._apply_prompt_caching(payload)
assert isinstance(payload["system"], list)
assert payload["system"][0].get("cache_control") == {"type": "ephemeral"}
def test_last_tool_gets_cached_when_budget_allows(model):
payload: dict = {
"tools": [{"name": "t1"}, {"name": "t2"}],
}
model._apply_prompt_caching(payload)
# With no system or messages the last tool should be cached.
assert payload["tools"][-1].get("cache_control") == {"type": "ephemeral"}
assert "cache_control" not in payload["tools"][0]
def test_recent_messages_get_cached(model):
"""The last prompt_cache_size messages' content blocks should be cached."""
payload: dict = {
"messages": [
{"role": "user", "content": [{"type": "text", "text": "hello"}]},
],
}
model._apply_prompt_caching(payload)
assert payload["messages"][0]["content"][0].get("cache_control") == {"type": "ephemeral"}
def test_string_message_content_converted_and_cached(model):
payload: dict = {
"messages": [
{"role": "user", "content": "simple string"},
],
}
model._apply_prompt_caching(payload)
assert isinstance(payload["messages"][0]["content"], list)
assert payload["messages"][0]["content"][0].get("cache_control") == {"type": "ephemeral"}
# ---------------------------------------------------------------------------
# Budget enforcement (the core regression test for issue #2448)
# ---------------------------------------------------------------------------
def test_never_exceeds_4_breakpoints_with_large_system(model):
"""Many system text blocks must not produce more than 4 breakpoints total."""
payload: dict = {
"system": [{"type": "text", "text": f"sys {i}"} for i in range(6)],
"tools": [{"name": "t1"}],
}
model._apply_prompt_caching(payload)
assert _count_cache_control(payload) <= 4
def test_never_exceeds_4_breakpoints_multi_turn_with_multi_block_messages(model):
"""Multi-turn conversation where each message has multiple content blocks."""
# 1 system block + 3 messages × 2 blocks + 1 tool = 8 candidates → must cap at 4
payload: dict = {
"system": [{"type": "text", "text": "system prompt"}],
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "user text"},
{"type": "tool_result", "tool_use_id": "x", "content": "result"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "assistant text"},
{"type": "tool_use", "id": "y", "name": "bash", "input": {}},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "follow up"},
{"type": "text", "text": "second block"},
],
},
],
"tools": [{"name": "bash"}],
}
model._apply_prompt_caching(payload)
total = _count_cache_control(payload)
assert total <= 4, f"Expected ≤ 4 breakpoints, got {total}"
def test_never_exceeds_4_breakpoints_many_messages(model):
"""Large number of messages with multiple blocks per message."""
messages = []
for i in range(10):
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": f"msg {i} block a"},
{"type": "text", "text": f"msg {i} block b"},
],
}
)
payload: dict = {
"system": [{"type": "text", "text": "sys 1"}, {"type": "text", "text": "sys 2"}],
"messages": messages,
"tools": [{"name": "tool_a"}, {"name": "tool_b"}],
}
model._apply_prompt_caching(payload)
total = _count_cache_control(payload)
assert total <= 4, f"Expected ≤ 4 breakpoints, got {total}"
def test_exactly_4_breakpoints_when_4_or_more_candidates(model):
"""When there are at least 4 candidates, exactly 4 breakpoints are placed."""
payload: dict = {
"system": [{"type": "text", "text": f"sys {i}"} for i in range(3)],
"messages": [
{"role": "user", "content": [{"type": "text", "text": "user"}]},
{"role": "assistant", "content": [{"type": "text", "text": "asst"}]},
{"role": "user", "content": [{"type": "text", "text": "follow"}]},
],
"tools": [{"name": "bash"}],
}
model._apply_prompt_caching(payload)
total = _count_cache_control(payload)
assert total == 4
def test_breakpoints_placed_on_last_candidates(model):
"""Breakpoints should be on the *last* candidates, not the first."""
# 5 system blocks but budget = 4 → first system block should NOT be cached,
# last 4 (indices 1-4) should be.
payload: dict = {
"system": [{"type": "text", "text": f"sys {i}"} for i in range(5)],
}
model._apply_prompt_caching(payload)
# First block is NOT in the last-4 window
assert "cache_control" not in payload["system"][0]
# Last 4 blocks ARE cached
for i in range(1, 5):
assert payload["system"][i].get("cache_control") == {"type": "ephemeral"}, f"block {i} should be cached"
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
def test_no_candidates_is_a_no_op(model):
payload: dict = {}
model._apply_prompt_caching(payload)
assert _count_cache_control(payload) == 0
def test_non_text_system_blocks_not_added_as_candidates(model):
"""Image blocks in system should not receive cache_control."""
payload: dict = {
"system": [
{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "abc"}},
{"type": "text", "text": "text block"},
],
}
model._apply_prompt_caching(payload)
assert "cache_control" not in payload["system"][0]
assert payload["system"][1].get("cache_control") == {"type": "ephemeral"}
def test_old_messages_outside_cache_window_not_cached(model):
"""Messages older than prompt_cache_size should not be cached."""
m = _make_model(prompt_cache_size=1)
payload: dict = {
"messages": [
{"role": "user", "content": [{"type": "text", "text": "old message"}]},
{"role": "user", "content": [{"type": "text", "text": "recent message"}]},
],
}
m._apply_prompt_caching(payload)
# Only the last message should be within the cache window
assert "cache_control" not in payload["messages"][0]["content"][0]
assert payload["messages"][1]["content"][0].get("cache_control") == {"type": "ephemeral"}
+5
View File
@@ -38,6 +38,7 @@ def mock_app_config():
config = MagicMock() config = MagicMock()
config.models = [model] config.models = [model]
config.token_usage.enabled = False
return config return config
@@ -107,6 +108,7 @@ class TestConfigQueries:
def test_list_models(self, client): def test_list_models(self, client):
result = client.list_models() result = client.list_models()
assert "models" in result assert "models" in result
assert result["token_usage"] == {"enabled": False}
assert len(result["models"]) == 1 assert len(result["models"]) == 1
assert result["models"][0]["name"] == "test-model" assert result["models"][0]["name"] == "test-model"
# Verify Gateway-aligned fields are present # Verify Gateway-aligned fields are present
@@ -2196,7 +2198,9 @@ class TestGatewayConformance:
model.display_name = "Test Model" model.display_name = "Test Model"
model.description = "A test model" model.description = "A test model"
model.supports_thinking = False model.supports_thinking = False
model.supports_reasoning_effort = False
mock_app_config.models = [model] mock_app_config.models = [model]
mock_app_config.token_usage.enabled = True
with patch("deerflow.client.get_app_config", return_value=mock_app_config): with patch("deerflow.client.get_app_config", return_value=mock_app_config):
client = DeerFlowClient() client = DeerFlowClient()
@@ -2206,6 +2210,7 @@ class TestGatewayConformance:
assert len(parsed.models) == 1 assert len(parsed.models) == 1
assert parsed.models[0].name == "test-model" assert parsed.models[0].name == "test-model"
assert parsed.models[0].model == "gpt-test" assert parsed.models[0].model == "gpt-test"
assert parsed.token_usage.enabled is True
def test_get_model(self, mock_app_config): def test_get_model(self, mock_app_config):
model = MagicMock() model = MagicMock()
+67 -6
View File
@@ -9,6 +9,8 @@ import pytest
import yaml import yaml
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from deerflow.config.agents_api_config import AgentsApiConfig, get_agents_api_config, set_agents_api_config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -387,13 +389,38 @@ def _make_test_app(tmp_path: Path):
@pytest.fixture() @pytest.fixture()
def agent_client(tmp_path): def agent_client(tmp_path):
"""TestClient with agents router, using tmp_path as base_dir.""" """TestClient with agents router, using tmp_path as base_dir."""
paths_instance = _make_paths(tmp_path) import app.gateway.routers.agents as agents_router
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch("app.gateway.routers.agents.get_paths", return_value=paths_instance): paths_instance = _make_paths(tmp_path)
app = _make_test_app(tmp_path) previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
with TestClient(app) as client:
client._tmp_path = tmp_path # type: ignore[attr-defined] with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
yield client set_agents_api_config(AgentsApiConfig(enabled=True))
try:
app = _make_test_app(tmp_path)
with TestClient(app) as client:
client._tmp_path = tmp_path # type: ignore[attr-defined]
yield client
finally:
set_agents_api_config(previous_config)
@pytest.fixture()
def disabled_agent_client(tmp_path):
"""TestClient with agents router while the management API is disabled."""
import app.gateway.routers.agents as agents_router
paths_instance = _make_paths(tmp_path)
previous_config = AgentsApiConfig(**get_agents_api_config().model_dump())
with patch("deerflow.config.agents_config.get_paths", return_value=paths_instance), patch.object(agents_router, "get_paths", return_value=paths_instance):
set_agents_api_config(AgentsApiConfig(enabled=False))
try:
app = _make_test_app(tmp_path)
with TestClient(app) as client:
yield client
finally:
set_agents_api_config(previous_config)
class TestAgentsAPI: class TestAgentsAPI:
@@ -559,3 +586,37 @@ class TestUserProfileAPI:
response = agent_client.put("/api/user-profile", json={"content": ""}) response = agent_client.put("/api/user-profile", json={"content": ""})
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["content"] is None assert response.json()["content"] is None
class TestAgentsApiDisabled:
def test_agents_list_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents")
assert response.status_code == 403
assert "agents_api.enabled=true" in response.json()["detail"]
def test_agent_get_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents/example-agent")
assert response.status_code == 403
def test_agent_name_check_returns_403(self, disabled_agent_client):
response = disabled_agent_client.get("/api/agents/check", params={"name": "example-agent"})
assert response.status_code == 403
def test_agent_create_returns_403(self, disabled_agent_client):
response = disabled_agent_client.post("/api/agents", json={"name": "example-agent", "soul": "blocked"})
assert response.status_code == 403
def test_agent_update_returns_403(self, disabled_agent_client):
response = disabled_agent_client.put("/api/agents/example-agent", json={"description": "blocked"})
assert response.status_code == 403
def test_agent_delete_returns_403(self, disabled_agent_client):
response = disabled_agent_client.delete("/api/agents/example-agent")
assert response.status_code == 403
def test_user_profile_routes_return_403(self, disabled_agent_client):
get_response = disabled_agent_client.get("/api/user-profile")
put_response = disabled_agent_client.put("/api/user-profile", json={"content": "blocked"})
assert get_response.status_code == 403
assert put_response.status_code == 403
@@ -119,6 +119,31 @@ class TestBuildPatchedMessagesPatching:
assert "interrupted" in tool_msg.content.lower() assert "interrupted" in tool_msg.content.lower()
assert tool_msg.name == "bash" assert tool_msg.name == "bash"
def test_raw_provider_tool_calls_are_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [
AIMessage(
content="",
tool_calls=[],
additional_kwargs={
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "bash", "arguments": '{"command":"ls"}'},
}
]
},
)
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert len(patched) == 2
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert patched[1].name == "bash"
assert patched[1].status == "error"
class TestWrapModelCall: class TestWrapModelCall:
def test_no_patch_passthrough(self): def test_no_patch_passthrough(self):
+22 -3
View File
@@ -12,6 +12,7 @@ from deerflow.utils.file_conversion import (
_MIN_CHARS_PER_PAGE, _MIN_CHARS_PER_PAGE,
MAX_OUTLINE_ENTRIES, MAX_OUTLINE_ENTRIES,
_do_convert, _do_convert,
_get_pdf_converter,
_pymupdf_output_too_sparse, _pymupdf_output_too_sparse,
convert_file_to_markdown, convert_file_to_markdown,
extract_outline, extract_outline,
@@ -214,9 +215,27 @@ class TestDoConvert:
assert result == "MarkItDown fallback" assert result == "MarkItDown fallback"
# --------------------------------------------------------------------------- class TestGetPdfConverter:
# convert_file_to_markdown — async + file writing def test_reads_dict_backed_uploads_config(self):
# --------------------------------------------------------------------------- cfg = MagicMock()
cfg.uploads = {"pdf_converter": "markitdown"}
with patch("deerflow.utils.file_conversion.get_app_config", return_value=cfg):
assert _get_pdf_converter() == "markitdown"
def test_reads_attribute_backed_uploads_config(self):
cfg = MagicMock()
cfg.uploads = MagicMock(pdf_converter="pymupdf4llm")
with patch("deerflow.utils.file_conversion.get_app_config", return_value=cfg):
assert _get_pdf_converter() == "pymupdf4llm"
def test_invalid_value_falls_back_to_auto(self):
cfg = MagicMock()
cfg.uploads = {"pdf_converter": "not-a-real-converter"}
with patch("deerflow.utils.file_conversion.get_app_config", return_value=cfg):
assert _get_pdf_converter() == "auto"
class TestConvertFileToMarkdown: class TestConvertFileToMarkdown:
@@ -0,0 +1,68 @@
"""Regression tests for Gateway lifespan shutdown.
These tests guard the invariant that lifespan shutdown is *bounded*: a
misbehaving channel whose ``stop()`` blocks forever must not keep the
uvicorn worker alive. A hung worker is the precondition for the
signal-reentrancy deadlock described in
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``.
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from unittest.mock import MagicMock, patch
from fastapi import FastAPI
@asynccontextmanager
async def _noop_langgraph_runtime(_app):
yield
async def _run_lifespan_with_hanging_stop() -> float:
"""Drive the lifespan context with stop_channel_service hanging forever.
Returns the elapsed wall-clock seconds.
"""
from app.gateway.app import _SHUTDOWN_HOOK_TIMEOUT_SECONDS, lifespan
async def hang_forever() -> None:
await asyncio.sleep(3600)
app = FastAPI()
fake_service = MagicMock()
fake_service.get_status = MagicMock(return_value={})
async def fake_start():
return fake_service
with (
patch("app.gateway.app.get_app_config"),
patch("app.gateway.app.get_gateway_config", return_value=MagicMock(host="x", port=0)),
patch("app.gateway.app.langgraph_runtime", _noop_langgraph_runtime),
patch("app.channels.service.start_channel_service", side_effect=fake_start),
patch("app.channels.service.stop_channel_service", side_effect=hang_forever),
):
loop = asyncio.get_event_loop()
start = loop.time()
async with lifespan(app):
pass
elapsed = loop.time() - start
assert _SHUTDOWN_HOOK_TIMEOUT_SECONDS < 30.0, "Timeout constant must stay modest"
return elapsed
def test_shutdown_is_bounded_when_channel_stop_hangs():
"""Lifespan exit must complete near the configured timeout, not hang."""
from app.gateway.app import _SHUTDOWN_HOOK_TIMEOUT_SECONDS
elapsed = asyncio.run(_run_lifespan_with_hanging_stop())
# Generous upper bound: timeout + 2s slack for scheduling overhead.
assert elapsed < _SHUTDOWN_HOOK_TIMEOUT_SECONDS + 2.0, f"Lifespan shutdown took {elapsed:.2f}s; expected <= {_SHUTDOWN_HOOK_TIMEOUT_SECONDS + 2.0:.1f}s"
# Lower bound: the wait_for should actually have waited.
assert elapsed >= _SHUTDOWN_HOOK_TIMEOUT_SECONDS - 0.5, f"Lifespan exited too quickly ({elapsed:.2f}s); wait_for may not have been invoked."
+45
View File
@@ -145,6 +145,21 @@ def test_build_run_config_explicit_agent_name_not_overwritten():
assert config["configurable"]["agent_name"] == "explicit-agent" assert config["configurable"]["agent_name"] == "explicit-agent"
def test_build_run_config_context_custom_agent_injects_agent_name():
"""Custom assistant_id must be forwarded as context['agent_name'] in context mode."""
from app.gateway.services import build_run_config
config = build_run_config(
"thread-1",
{"context": {"model_name": "deepseek-v3"}},
None,
assistant_id="finalis",
)
assert config["context"]["agent_name"] == "finalis"
assert "configurable" not in config
def test_resolve_agent_factory_returns_make_lead_agent(): def test_resolve_agent_factory_returns_make_lead_agent():
"""resolve_agent_factory always returns make_lead_agent regardless of assistant_id.""" """resolve_agent_factory always returns make_lead_agent regardless of assistant_id."""
from app.gateway.services import resolve_agent_factory from app.gateway.services import resolve_agent_factory
@@ -298,6 +313,36 @@ def test_build_run_config_with_context():
assert config["recursion_limit"] == 100 assert config["recursion_limit"] == 100
def test_build_run_config_null_context_becomes_empty_context():
"""When caller sends context=null, treat it as an empty context object."""
from app.gateway.services import build_run_config
config = build_run_config("thread-1", {"context": None}, None)
assert config["context"] == {}
assert "configurable" not in config
def test_build_run_config_rejects_non_mapping_context():
"""When caller sends a non-object context, raise a clear error instead of a TypeError."""
import pytest
from app.gateway.services import build_run_config
with pytest.raises(ValueError, match="context"):
build_run_config("thread-1", {"context": "bad-context"}, None)
def test_build_run_config_null_context_custom_agent_injects_agent_name():
"""Custom assistant_id can still be injected when context=null starts context mode."""
from app.gateway.services import build_run_config
config = build_run_config("thread-1", {"context": None}, None, assistant_id="finalis")
assert config["context"] == {"agent_name": "finalis"}
assert "configurable" not in config
def test_build_run_config_context_plus_configurable_warns(caplog): def test_build_run_config_context_plus_configurable_warns(caplog):
"""When caller sends both 'context' and 'configurable', prefer 'context' and log a warning.""" """When caller sends both 'context' and 'configurable', prefer 'context' and log a warning."""
import logging import logging
+49
View File
@@ -80,6 +80,28 @@ async def test_crawl_network_error(jina_client, monkeypatch):
assert "failed" in result.lower() assert "failed" in result.lower()
@pytest.mark.anyio
async def test_crawl_transient_failure_logs_without_traceback(jina_client, monkeypatch, caplog):
"""Transient network failures must log at WARNING without a traceback and include the exception type."""
async def mock_post(self, url, **kwargs):
raise httpx.ConnectTimeout("timed out")
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
with caplog.at_level(logging.DEBUG, logger="deerflow.community.jina_ai.jina_client"):
result = await jina_client.crawl("https://example.com")
jina_records = [r for r in caplog.records if r.name == "deerflow.community.jina_ai.jina_client"]
assert len(jina_records) == 1, f"expected exactly one log record, got {len(jina_records)}"
record = jina_records[0]
assert record.levelno == logging.WARNING, f"expected WARNING, got {record.levelname}"
assert record.exc_info is None, "transient failures must not attach a traceback"
assert "ConnectTimeout" in record.getMessage()
assert result.startswith("Error:")
assert "ConnectTimeout" in result
@pytest.mark.anyio @pytest.mark.anyio
async def test_crawl_passes_headers(jina_client, monkeypatch): async def test_crawl_passes_headers(jina_client, monkeypatch):
"""Test that correct headers are sent.""" """Test that correct headers are sent."""
@@ -175,3 +197,30 @@ async def test_web_fetch_tool_returns_markdown_on_success(monkeypatch):
result = await web_fetch_tool.ainvoke("https://example.com") result = await web_fetch_tool.ainvoke("https://example.com")
assert "Hello world" in result assert "Hello world" in result
assert not result.startswith("Error:") assert not result.startswith("Error:")
@pytest.mark.anyio
async def test_web_fetch_tool_offloads_extraction_to_thread(monkeypatch):
"""Test that readability extraction is offloaded via asyncio.to_thread to avoid blocking the event loop."""
import asyncio
async def mock_crawl(self, url, **kwargs):
return "<html><body><p>threaded</p></body></html>"
mock_config = MagicMock()
mock_config.get_tool_config.return_value = None
monkeypatch.setattr("deerflow.community.jina_ai.tools.get_app_config", lambda: mock_config)
monkeypatch.setattr(JinaClient, "crawl", mock_crawl)
to_thread_called = False
original_to_thread = asyncio.to_thread
async def tracking_to_thread(func, *args, **kwargs):
nonlocal to_thread_called
to_thread_called = True
return await original_to_thread(func, *args, **kwargs)
monkeypatch.setattr("deerflow.community.jina_ai.tools.asyncio.to_thread", tracking_to_thread)
result = await web_fetch_tool.ainvoke("https://example.com")
assert to_thread_called, "extract_article must be called via asyncio.to_thread to avoid blocking the event loop"
assert "threaded" in result
@@ -8,6 +8,7 @@ import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.config.app_config import AppConfig from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig from deerflow.config.summarization_config import SummarizationConfig
@@ -112,6 +113,74 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey
assert result["model"] is not None assert result["model"] is not None
def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch):
app_config = _make_app_config(
[
_make_model("default-model", supports_thinking=False),
_make_model("context-model", supports_thinking=True),
]
)
import deerflow.tools as tools_module
get_available_tools = MagicMock(return_value=[])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(tools_module, "get_available_tools", get_available_tools)
monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None: [])
captured: dict[str, object] = {}
def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None):
captured["name"] = name
captured["thinking_enabled"] = thinking_enabled
captured["reasoning_effort"] = reasoning_effort
return object()
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)
result = lead_agent_module.make_lead_agent(
{
"context": {
"model_name": "context-model",
"thinking_enabled": False,
"reasoning_effort": "high",
"is_plan_mode": True,
"subagent_enabled": True,
"max_concurrent_subagents": 7,
}
}
)
assert captured == {
"name": "context-model",
"thinking_enabled": False,
"reasoning_effort": "high",
}
get_available_tools.assert_called_once_with(model_name="context-model", groups=None, subagent_enabled=True)
assert result["model"] is not None
def test_make_lead_agent_rejects_invalid_bootstrap_agent_name(monkeypatch):
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
with pytest.raises(ValueError, match="Invalid agent name"):
lead_agent_module.make_lead_agent(
{
"configurable": {
"model_name": "safe-model",
"thinking_enabled": False,
"is_plan_mode": False,
"subagent_enabled": False,
"is_bootstrap": True,
"agent_name": "../../../tmp/evil",
}
}
)
def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch): def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
app_config = _make_app_config( app_config = _make_app_config(
[ [
@@ -145,6 +214,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
"get_summarization_config", "get_summarization_config",
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"), lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
) )
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
captured: dict[str, object] = {} captured: dict[str, object] = {}
fake_model = object() fake_model = object()
@@ -156,10 +226,56 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
return fake_model return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs) monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware() middleware = lead_agent_module._create_summarization_middleware()
assert captured["name"] == "model-masswork" assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False assert captured["thinking_enabled"] is False
assert middleware["model"] is fake_model assert middleware["model"] is fake_model
def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch):
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True),
)
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
captured: dict[str, object] = {}
def _fake_middleware(**kwargs):
captured.update(kwargs)
return kwargs
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
lead_agent_module._create_summarization_middleware()
assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
def test_create_summarization_middleware_passes_skill_read_tool_names(monkeypatch):
app_config = _make_app_config([_make_model("default-model", supports_thinking=False)])
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, skill_file_read_tool_names=["read_file", "cat"]),
)
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
captured: dict[str, object] = {}
def _fake_middleware(**kwargs):
captured.update(kwargs)
return kwargs
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
lead_agent_module._create_summarization_middleware()
assert captured["skill_file_read_tool_names"] == ["read_file", "cat"]
@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any
import pytest import pytest
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
@@ -134,3 +135,303 @@ def test_async_model_call_propagates_graph_bubble_up() -> None:
with pytest.raises(GraphBubbleUp): with pytest.raises(GraphBubbleUp):
asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler)) asyncio.run(middleware.awrap_model_call(SimpleNamespace(), handler))
def test_circuit_half_open_graph_bubble_up_resets_probe() -> None:
"""Verify that GraphBubbleUp in half_open state resets probe_in_flight."""
middleware = _build_middleware()
# Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True
middleware._circuit_state = "half_open"
middleware._circuit_probe_in_flight = False
# Call _check_circuit() once to simulate the probe being allowed through
assert middleware._check_circuit() is False
assert middleware._circuit_probe_in_flight is True
# Step 2: Now trigger handler that raises GraphBubbleUp
def handler(_request) -> AIMessage:
raise GraphBubbleUp()
# Mock _check_circuit() to return False (since we already did the probe check)
import unittest.mock
with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False):
with pytest.raises(GraphBubbleUp):
middleware.wrap_model_call(SimpleNamespace(), handler)
# Verify probe_in_flight was reset, state should remain half_open
assert middleware._circuit_probe_in_flight is False
assert middleware._circuit_state == "half_open"
@pytest.mark.anyio
async def test_async_circuit_half_open_graph_bubble_up_resets_probe() -> None:
"""Verify that GraphBubbleUp in half_open state resets probe_in_flight (async version)."""
middleware = _build_middleware()
# Step 1: Manually set state to half_open and check_circuit() to set probe_in_flight=True
middleware._circuit_state = "half_open"
middleware._circuit_probe_in_flight = False
# Call _check_circuit() once to simulate the probe being allowed through
assert middleware._check_circuit() is False
assert middleware._circuit_probe_in_flight is True
# Step 2: Now trigger handler that raises GraphBubbleUp
async def handler(_request) -> AIMessage:
raise GraphBubbleUp()
# Mock _check_circuit() to return False (since we already did the probe check)
import unittest.mock
with unittest.mock.patch.object(middleware, "_check_circuit", return_value=False):
with pytest.raises(GraphBubbleUp):
await middleware.awrap_model_call(SimpleNamespace(), handler)
# Verify probe_in_flight was reset, state should remain half_open
assert middleware._circuit_probe_in_flight is False
assert middleware._circuit_state == "half_open"
# ---------- Circuit Breaker Tests ----------
def transient_failing_handler(request: Any) -> Any:
raise FakeError("Server Error", status_code=502) # Used for transient error
def quota_failing_handler(request: Any) -> Any:
raise FakeError("Quota exceeded", body={"error": {"code": "insufficient_quota"}}) # Used for quota error
def success_handler(request: Any) -> Any:
return AIMessage(content="Success")
def mock_classify_retriable(exc: BaseException) -> tuple[bool, str]:
return True, "transient"
def mock_classify_non_retriable(exc: BaseException) -> tuple[bool, str]:
return False, "quota"
def test_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that circuit breaker trips, fast fails, correctly transitions to Half-Open, and recovers or re-opens."""
# Mock time.sleep to avoid slow tests during retry loops (Speed up from ~4s to 0.1s)
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
# Mock time.time to decouple from private implementation details and enable time travel
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
request: Any = {"messages": []}
# --- 0. Test initial state & Success ---
# Success handler does not increase count. If it's already 0, it stays 0.
middleware.wrap_model_call(request, success_handler)
assert middleware._circuit_failure_count == 0
assert middleware._check_circuit() is False
# --- 1. Trip the circuit ---
# Fails 3 overall calls. Threshold (3) is reached.
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == 1
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == 2
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == 3
assert middleware._check_circuit() is True # Circuit is OPEN
# --- 2. Fast Fail ---
# 2nd call: fast fail immediately without calling handler.
# Count should not increase during OPEN state.
result = middleware.wrap_model_call(request, success_handler)
assert result.content == middleware._build_circuit_breaker_message()
assert middleware._circuit_failure_count == 3
# --- 3. Half-Open -> Fail -> Re-Open ---
# Time travel 11 seconds (timeout is 10s). Current time becomes 1011.0
current_time += 11.0
# Verify that the timeout was set EXACTLY relative to current_time + timeout_sec
assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec
# Fails again! The request will go through the 3-attempt retry loop again.
middleware.wrap_model_call(request, transient_failing_handler)
assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
assert middleware._circuit_state == "open" # Re-OPENed
# --- 4. Half-Open -> Success -> Reset ---
# Time travel another 11 seconds
current_time += 11.0
# Succeeds this time! Should completely reset.
result = middleware.wrap_model_call(request, success_handler)
assert result.content == "Success"
assert middleware._circuit_failure_count == 0 # Fully RESET!
assert middleware._check_circuit() is False
def test_circuit_breaker_does_not_trip_on_non_retriable_errors(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that circuit breaker ignores business errors like Quota or Auth."""
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
monkeypatch.setattr(middleware, "_classify_error", mock_classify_non_retriable)
request: Any = {"messages": []}
for _ in range(3):
middleware.wrap_model_call(request, quota_failing_handler)
assert middleware._circuit_failure_count == 0
assert middleware._check_circuit() is False
# ---------- ReadError / RemoteProtocolError retriable classification ----------
class _ReadError(Exception):
"""Local stand-in for httpx.ReadError — same class name, no httpx dependency."""
class _RemoteProtocolError(Exception):
"""Local stand-in for httpx.RemoteProtocolError — same class name, no httpx dependency."""
_ReadError.__name__ = "ReadError"
_RemoteProtocolError.__name__ = "RemoteProtocolError"
def test_classify_error_read_error_is_retriable() -> None:
middleware = _build_middleware()
exc = _ReadError("Connection dropped mid-stream")
exc.__class__.__name__ = "ReadError"
retriable, reason = middleware._classify_error(exc)
assert retriable is True
assert reason == "transient"
def test_classify_error_remote_protocol_error_is_retriable() -> None:
middleware = _build_middleware()
exc = _RemoteProtocolError("Server closed connection unexpectedly")
exc.__class__.__name__ = "RemoteProtocolError"
retriable, reason = middleware._classify_error(exc)
assert retriable is True
assert reason == "transient"
def test_sync_read_error_triggers_retry_loop(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=10, retry_cap_delay_ms=10)
attempts = 0
waits: list[float] = []
monkeypatch.setattr("time.sleep", lambda d: waits.append(d))
def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
raise _ReadError("Connection dropped mid-stream")
result = middleware.wrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert "temporarily unavailable" in result.content
assert attempts == 3 # exhausted all retries
assert len(waits) == 2 # slept between attempts 1→2 and 2→3
@pytest.mark.anyio
async def test_async_read_error_triggers_retry_loop(monkeypatch: pytest.MonkeyPatch) -> None:
middleware = _build_middleware(retry_max_attempts=3, retry_base_delay_ms=10, retry_cap_delay_ms=10)
attempts = 0
waits: list[float] = []
async def fake_sleep(d: float) -> None:
waits.append(d)
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
async def handler(_request) -> AIMessage:
nonlocal attempts
attempts += 1
raise _ReadError("Connection dropped mid-stream")
result = await middleware.awrap_model_call(SimpleNamespace(), handler)
assert isinstance(result, AIMessage)
assert "temporarily unavailable" in result.content
assert attempts == 3 # exhausted all retries
assert len(waits) == 2 # slept between attempts 1→2 and 2→3
@pytest.mark.anyio
async def test_async_circuit_breaker_trips_and_recovers(monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify async version of circuit breaker correctly handles state transitions."""
waits: list[float] = []
async def fake_sleep(d: float) -> None:
waits.append(d)
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
current_time = 1000.0
monkeypatch.setattr("time.time", lambda: current_time)
middleware = LLMErrorHandlingMiddleware()
middleware.circuit_failure_threshold = 3
middleware.circuit_recovery_timeout_sec = 10
monkeypatch.setattr(middleware, "_classify_error", mock_classify_retriable)
async def async_failing_handler(request: Any) -> Any:
raise FakeError("Server Error", status_code=502)
request: Any = {"messages": []}
# --- 1. Trip the circuit ---
# Fails 3 overall calls. Threshold (3) is reached.
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == 1
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == 2
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == 3
assert middleware._check_circuit() is True
# --- 2. Fast Fail ---
# 2nd call: fast fail immediately without calling handler
async def async_success_handler(request: Any) -> Any:
return AIMessage(content="Success")
result = await middleware.awrap_model_call(request, async_success_handler)
assert result.content == middleware._build_circuit_breaker_message()
assert middleware._circuit_failure_count == 3 # Unchanged
# --- 3. Half-Open -> Fail -> Re-Open ---
# Time travel 11 seconds
current_time += 11.0
# Verify timeout formula
assert middleware._circuit_open_until == current_time - 11.0 + middleware.circuit_recovery_timeout_sec
# Fails again! The request goes through the 3-attempt retry loop.
await middleware.awrap_model_call(request, async_failing_handler)
assert middleware._circuit_failure_count == middleware.circuit_failure_threshold
assert middleware._circuit_state == "open" # Re-OPENed
# --- 4. Half-Open -> Success -> Reset ---
# Time travel another 11 seconds
current_time += 11.0
result = await middleware.awrap_model_call(request, async_success_handler)
assert result.content == "Success"
assert middleware._circuit_failure_count == 0 # RESET
assert middleware._check_circuit() is False
@@ -255,7 +255,9 @@ class TestMultipleMounts:
sandbox.execute_command("cat /mnt/data/test.txt") sandbox.execute_command("cat /mnt/data/test.txt")
# Verify the command received the resolved local path # Verify the command received the resolved local path
assert str(data_dir) in captured.get("command", "") command = captured.get("command", [])
assert isinstance(command, list) and len(command) >= 3
assert str(data_dir) in command[2]
def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path): def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path):
foo_dir = tmp_path / "foo" foo_dir = tmp_path / "foo"
@@ -413,6 +413,45 @@ class TestHardStopWithListContent:
assert msg.content.startswith("thinking...") assert msg.content.startswith("thinking...")
assert _HARD_STOP_MSG in msg.content assert _HARD_STOP_MSG in msg.content
def test_hard_stop_clears_raw_tool_call_metadata(self):
"""Forced-stop messages must not retain provider-level raw tool-call payloads."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
def _make_provider_state():
return {
"messages": [
AIMessage(
content="thinking...",
tool_calls=call,
additional_kwargs={
"tool_calls": [
{
"id": "call_ls",
"type": "function",
"function": {"name": "bash", "arguments": '{"command":"ls"}'},
"thought_signature": "sig-1",
}
],
"function_call": {"name": "bash", "arguments": '{"command":"ls"}'},
},
response_metadata={"finish_reason": "tool_calls"},
)
]
}
for _ in range(3):
mw._apply(_make_provider_state(), runtime)
result = mw._apply(_make_provider_state(), runtime)
assert result is not None
msg = result["messages"][0]
assert msg.tool_calls == []
assert "tool_calls" not in msg.additional_kwargs
assert "function_call" not in msg.additional_kwargs
assert msg.response_metadata["finish_reason"] == "stop"
class TestToolFrequencyDetection: class TestToolFrequencyDetection:
"""Tests for per-tool-type frequency detection (Layer 2). """Tests for per-tool-type frequency detection (Layer 2).
@@ -0,0 +1,274 @@
"""Tests for custom MCP tool interceptors loaded via extensions_config.json."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from deerflow.mcp.tools import get_mcp_tools
def _make_patches(*, interceptor_paths=None):
"""Set up mocks for get_mcp_tools() with optional custom interceptors.
Returns a dict of patch context managers.
"""
mock_client = MagicMock()
mock_client.get_tools = AsyncMock(return_value=[])
extra = {}
if interceptor_paths is not None:
extra["mcpInterceptors"] = interceptor_paths
return {
"client_cls": patch(
"langchain_mcp_adapters.client.MultiServerMCPClient",
return_value=mock_client,
),
"from_file": patch(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
return_value=MagicMock(
model_extra=extra,
get_enabled_mcp_servers=MagicMock(return_value={}),
),
),
"build_servers": patch(
"deerflow.mcp.tools.build_servers_config",
return_value={"test-server": {}},
),
"oauth_headers": patch(
"deerflow.mcp.tools.get_initial_oauth_headers",
new_callable=AsyncMock,
return_value={},
),
"oauth_interceptor": patch(
"deerflow.mcp.tools.build_oauth_tool_interceptor",
return_value=None,
),
}
def _get_interceptors(mock_cls):
"""Extract the tool_interceptors list passed to MultiServerMCPClient."""
kw = mock_cls.call_args
return kw.kwargs.get("tool_interceptors") or kw[1].get("tool_interceptors", [])
def test_custom_interceptor_loaded_and_appended():
"""A valid interceptor builder path is resolved, called, and appended to tool_interceptors."""
async def fake_interceptor(request, handler):
return await handler(request)
def fake_builder():
return fake_interceptor
p = _make_patches(interceptor_paths=["my_package.auth:build_interceptor"])
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.resolve_variable", return_value=fake_builder),
):
asyncio.run(get_mcp_tools())
interceptors = _get_interceptors(mock_cls)
assert len(interceptors) == 1
assert interceptors[0] is fake_interceptor
def test_multiple_custom_interceptors():
"""Multiple interceptor paths are all loaded in order."""
async def interceptor_a(request, handler):
return await handler(request)
async def interceptor_b(request, handler):
return await handler(request)
builders = {
"pkg.a:build_a": lambda: interceptor_a,
"pkg.b:build_b": lambda: interceptor_b,
}
p = _make_patches(interceptor_paths=["pkg.a:build_a", "pkg.b:build_b"])
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.resolve_variable", side_effect=lambda path: builders[path]),
):
asyncio.run(get_mcp_tools())
interceptors = _get_interceptors(mock_cls)
assert len(interceptors) == 2
assert interceptors[0] is interceptor_a
assert interceptors[1] is interceptor_b
def test_custom_interceptor_builder_returning_none_is_skipped():
"""If a builder returns None, it is not appended to the interceptor list."""
p = _make_patches(interceptor_paths=["pkg.noop:build_noop"])
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: None),
):
asyncio.run(get_mcp_tools())
assert len(_get_interceptors(mock_cls)) == 0
def test_custom_interceptor_resolve_error_logs_warning_and_continues():
"""A broken interceptor path logs a warning and does not block tool loading."""
p = _make_patches(interceptor_paths=["broken.path:does_not_exist"])
with (
p["client_cls"],
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.resolve_variable", side_effect=ImportError("no such module")),
patch("deerflow.mcp.tools.logger.warning") as mock_warn,
):
tools = asyncio.run(get_mcp_tools())
assert tools == []
mock_warn.assert_called_once()
assert "broken.path:does_not_exist" in mock_warn.call_args[0][0]
def test_custom_interceptor_builder_exception_logs_warning_and_continues():
"""If the builder function itself raises, the error is caught and logged."""
def exploding_builder():
raise RuntimeError("builder exploded")
p = _make_patches(interceptor_paths=["pkg.bad:exploding_builder"])
with (
p["client_cls"],
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.resolve_variable", return_value=exploding_builder),
patch("deerflow.mcp.tools.logger.warning") as mock_warn,
):
tools = asyncio.run(get_mcp_tools())
assert tools == []
mock_warn.assert_called_once()
assert "pkg.bad:exploding_builder" in mock_warn.call_args[0][0]
def test_no_mcp_interceptors_field_is_safe():
"""When mcpInterceptors is absent from config, no interceptors are added."""
p = _make_patches(interceptor_paths=None)
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
):
asyncio.run(get_mcp_tools())
assert len(_get_interceptors(mock_cls)) == 0
def test_custom_interceptor_coexists_with_oauth_interceptor():
"""Custom interceptors are appended after the OAuth interceptor."""
async def oauth_fn(request, handler):
return await handler(request)
async def custom_fn(request, handler):
return await handler(request)
p = _make_patches(interceptor_paths=["pkg.custom:build_custom"])
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
patch("deerflow.mcp.tools.build_oauth_tool_interceptor", return_value=oauth_fn),
patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: custom_fn),
):
asyncio.run(get_mcp_tools())
interceptors = _get_interceptors(mock_cls)
assert len(interceptors) == 2
assert interceptors[0] is oauth_fn
assert interceptors[1] is custom_fn
def test_mcp_interceptors_single_string_is_normalized():
"""A single string value for mcpInterceptors is normalized to a list."""
async def fake_interceptor(request, handler):
return await handler(request)
p = _make_patches(interceptor_paths="pkg.single:build_it")
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: fake_interceptor),
):
asyncio.run(get_mcp_tools())
assert len(_get_interceptors(mock_cls)) == 1
def test_mcp_interceptors_invalid_type_logs_warning():
"""A non-list, non-string value for mcpInterceptors logs a warning and is skipped."""
p = _make_patches(interceptor_paths=42)
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.logger.warning") as mock_warn,
):
asyncio.run(get_mcp_tools())
assert len(_get_interceptors(mock_cls)) == 0
mock_warn.assert_called_once()
assert "must be a list" in mock_warn.call_args[0][0]
def test_custom_interceptor_non_callable_return_logs_warning():
"""If a builder returns a non-callable value, it is skipped with a warning."""
p = _make_patches(interceptor_paths=["pkg.bad:returns_string"])
with (
p["client_cls"] as mock_cls,
p["from_file"],
p["build_servers"],
p["oauth_headers"],
p["oauth_interceptor"],
patch("deerflow.mcp.tools.resolve_variable", return_value=lambda: "not_a_callable"),
patch("deerflow.mcp.tools.logger.warning") as mock_warn,
):
asyncio.run(get_mcp_tools())
assert len(_get_interceptors(mock_cls)) == 0
mock_warn.assert_called_once()
assert "non-callable" in mock_warn.call_args[0][0]
+73
View File
@@ -1,3 +1,5 @@
import threading
import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
@@ -89,3 +91,74 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
correction_detected=False, correction_detected=False,
reinforcement_detected=True, reinforcement_detected=True,
) )
def test_flush_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
queue = MemoryUpdateQueue()
existing_timer = MagicMock()
queue._timer = existing_timer
created_timer = MagicMock()
with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
queue.flush_nowait()
existing_timer.cancel.assert_called_once_with()
timer_cls.assert_called_once_with(0, queue._process_queue)
assert created_timer.daemon is True
created_timer.start.assert_called_once_with()
assert queue._timer is created_timer
def test_add_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
queue = MemoryUpdateQueue()
existing_timer = MagicMock()
queue._timer = existing_timer
created_timer = MagicMock()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls,
):
queue.add_nowait(thread_id="thread-1", messages=["conversation"], agent_name="lead-agent")
existing_timer.cancel.assert_called_once_with()
timer_cls.assert_called_once_with(0, queue._process_queue)
assert queue.pending_count == 1
assert queue._queue[0].agent_name == "lead-agent"
assert created_timer.daemon is True
created_timer.start.assert_called_once_with()
def test_process_queue_reschedules_immediately_when_already_processing() -> None:
queue = MemoryUpdateQueue()
queue._processing = True
created_timer = MagicMock()
with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
queue._process_queue()
timer_cls.assert_called_once_with(0, queue._process_queue)
assert created_timer.daemon is True
created_timer.start.assert_called_once_with()
def test_flush_nowait_is_non_blocking() -> None:
queue = MemoryUpdateQueue()
started = threading.Event()
finished = threading.Event()
def _slow_process_queue() -> None:
started.set()
time.sleep(0.2)
finished.set()
queue._process_queue = _slow_process_queue
start = time.perf_counter()
queue.flush_nowait()
elapsed = time.perf_counter() - start
assert started.wait(0.1) is True
assert elapsed < 0.1
assert finished.is_set() is False
assert finished.wait(1.0) is True
+87
View File
@@ -110,6 +110,93 @@ class TestFileMemoryStorage:
assert result is True assert result is True
assert memory_file.exists() assert memory_file.exists()
def test_save_does_not_mutate_caller_dict(self, tmp_path):
"""save() must not mutate the caller's dict (lastUpdated side-effect)."""
memory_file = tmp_path / "memory.json"
def mock_get_paths():
mock_paths = MagicMock()
mock_paths.memory_file = memory_file
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
storage = FileMemoryStorage()
original = {"version": "1.0", "facts": []}
before_keys = set(original.keys())
storage.save(original)
assert set(original.keys()) == before_keys, "save() must not add keys to caller's dict"
assert "lastUpdated" not in original
def test_cache_not_corrupted_when_save_fails(self, tmp_path):
"""Cache must remain clean when save() raises OSError.
If save() fails, the cache must NOT be updated with the new data.
Together with the deepcopy in updater._finalize_update(), this prevents
stale mutations from leaking into the cache when persistence fails.
"""
memory_file = tmp_path / "memory.json"
memory_file.parent.mkdir(parents=True, exist_ok=True)
original_data = {"version": "1.0", "facts": [{"content": "original"}]}
import json as _json
memory_file.write_text(_json.dumps(original_data))
def mock_get_paths():
mock_paths = MagicMock()
mock_paths.memory_file = memory_file
return mock_paths
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
storage = FileMemoryStorage()
# Warm the cache
cached = storage.load()
assert cached["facts"][0]["content"] == "original"
# Simulate save failure: mkdir succeeds but open() raises
modified = {"version": "1.0", "facts": [{"content": "mutated"}]}
with patch("builtins.open", side_effect=OSError("disk full")):
result = storage.save(modified)
assert result is False
# Cache must still reflect the original data, not the failed write
after = storage.load()
assert after["facts"][0]["content"] == "original"
def test_cache_thread_safety(self, tmp_path):
"""Concurrent load/reload calls must not race on _memory_cache."""
memory_file = tmp_path / "memory.json"
memory_file.parent.mkdir(parents=True, exist_ok=True)
import json as _json
memory_file.write_text(_json.dumps({"version": "1.0", "facts": []}))
def mock_get_paths():
mock_paths = MagicMock()
mock_paths.memory_file = memory_file
return mock_paths
errors: list[Exception] = []
def load_many(storage: FileMemoryStorage) -> None:
try:
for _ in range(50):
storage.load()
except Exception as exc:
errors.append(exc)
with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths):
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
storage = FileMemoryStorage()
threads = [threading.Thread(target=load_many, args=(storage,)) for _ in range(8)]
for t in threads:
t.start()
for t in threads:
t.join()
assert not errors, f"Thread-safety errors: {errors}"
def test_reload_forces_cache_invalidation(self, tmp_path): def test_reload_forces_cache_invalidation(self, tmp_path):
"""Should force reload from file and invalidate cache.""" """Should force reload from file and invalidate cache."""
memory_file = tmp_path / "memory.json" memory_file = tmp_path / "memory.json"
+169 -9
View File
@@ -1,9 +1,13 @@
from unittest.mock import MagicMock, patch import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from deerflow.agents.memory.prompt import format_conversation_for_update from deerflow.agents.memory.prompt import format_conversation_for_update
from deerflow.agents.memory.updater import ( from deerflow.agents.memory.updater import (
MemoryUpdater, MemoryUpdater,
_extract_text, _extract_text,
_run_async_update_sync,
clear_memory_data, clear_memory_data,
create_memory_fact, create_memory_fact,
delete_memory_fact, delete_memory_fact,
@@ -523,15 +527,16 @@ class TestUpdateMemoryStructuredResponse:
model = MagicMock() model = MagicMock()
response = MagicMock() response = MagicMock()
response.content = content response.content = content
model.invoke.return_value = response model.ainvoke = AsyncMock(return_value=response)
return model return model
def test_string_response_parses(self): def test_string_response_parses(self):
updater = MemoryUpdater() updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with ( with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)), patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
@@ -546,6 +551,7 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg]) result = updater.update_memory([msg, ai_msg])
assert result is True assert result is True
model.ainvoke.assert_awaited_once()
def test_list_content_response_parses(self): def test_list_content_response_parses(self):
"""LLM response as list-of-blocks should be extracted, not repr'd.""" """LLM response as list-of-blocks should be extracted, not repr'd."""
@@ -570,6 +576,30 @@ class TestUpdateMemoryStructuredResponse:
assert result is True assert result is True
def test_async_update_memory_uses_ainvoke(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi there"
ai_msg.tool_calls = []
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
assert result is True
model.ainvoke.assert_awaited_once()
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "memory_agent"}
def test_correction_hint_injected_when_detected(self): def test_correction_hint_injected_when_detected(self):
updater = MemoryUpdater() updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}' valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
@@ -592,7 +622,7 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=True) result = updater.update_memory([msg, ai_msg], correction_detected=True)
assert result is True assert result is True
prompt = model.invoke.call_args[0][0] prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" in prompt assert "Explicit correction signals were detected" in prompt
def test_correction_hint_empty_when_not_detected(self): def test_correction_hint_empty_when_not_detected(self):
@@ -617,9 +647,89 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=False) result = updater.update_memory([msg, ai_msg], correction_detected=False)
assert result is True assert result is True
prompt = model.invoke.call_args[0][0] prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" not in prompt assert "Explicit correction signals were detected" not in prompt
def test_sync_update_memory_wrapper_works_in_running_loop(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is True
model.ainvoke.assert_awaited_once()
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
updater = MemoryUpdater()
with (
patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is False
class TestRunAsyncUpdateSync:
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
class CloseableAwaitable:
def __init__(self):
self.closed = False
def __await__(self):
pytest.fail("awaitable should not have been awaited")
yield
def close(self):
self.closed = True
awaitable = CloseableAwaitable()
with patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
):
async def run_in_loop():
return _run_async_update_sync(awaitable)
result = asyncio.run(run_in_loop())
assert result is False
assert awaitable.closed is True
class TestFactDeduplicationCaseInsensitive: class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive.""" """Tests that fact deduplication is case-insensitive."""
@@ -694,7 +804,7 @@ class TestReinforcementHint:
model = MagicMock() model = MagicMock()
response = MagicMock() response = MagicMock()
response.content = f"```json\n{json_response}\n```" response.content = f"```json\n{json_response}\n```"
model.invoke.return_value = response model.ainvoke = AsyncMock(return_value=response)
return model return model
def test_reinforcement_hint_injected_when_detected(self): def test_reinforcement_hint_injected_when_detected(self):
@@ -719,7 +829,7 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True) result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True assert result is True
prompt = model.invoke.call_args[0][0] prompt = model.ainvoke.await_args.args[0]
assert "Positive reinforcement signals were detected" in prompt assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self): def test_reinforcement_hint_absent_when_not_detected(self):
@@ -744,7 +854,7 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False) result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True assert result is True
prompt = model.invoke.call_args[0][0] prompt = model.ainvoke.await_args.args[0]
assert "Positive reinforcement signals were detected" not in prompt assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self): def test_both_hints_present_when_both_detected(self):
@@ -769,6 +879,56 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True) result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True assert result is True
prompt = model.invoke.call_args[0][0] prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" in prompt assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt assert "Positive reinforcement signals were detected" in prompt
class TestFinalizeCacheIsolation:
"""_finalize_update must not mutate the cached memory object."""
def test_deepcopy_prevents_cache_corruption_on_save_failure(self):
"""If save() fails, the in-memory snapshot used by _finalize_update
must remain independent of any object the storage layer may still hold in
its cache. The deepcopy in _finalize_update achieves this the object
passed to _apply_updates is always a fresh copy, never the cache reference.
"""
updater = MemoryUpdater()
original_memory = _make_memory(facts=[{"id": "fact_orig", "content": "original", "category": "context", "confidence": 0.9, "createdAt": "2024-01-01T00:00:00Z", "source": "t1"}])
import json as _json
new_fact_json = _json.dumps(
{
"user": {},
"history": {},
"newFacts": [{"content": "new fact", "category": "context", "confidence": 0.9}],
"factsToRemove": [],
}
)
mock_response = MagicMock()
mock_response.content = new_fact_json
mock_model = AsyncMock()
mock_model.ainvoke = AsyncMock(return_value=mock_response)
saved_objects: list[dict] = []
save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False) # always fails
with (
patch.object(updater, "_get_model", return_value=mock_model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=original_memory),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=save_mock)),
):
msg = MagicMock()
msg.type = "human"
msg.content = "hello"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "world"
ai_msg.tool_calls = []
updater.update_memory([msg, ai_msg], thread_id="t1")
# original_memory must not have been mutated — deepcopy isolates the mutation
assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
assert original_memory["facts"][0]["content"] == "original"
+10 -10
View File
@@ -3,14 +3,14 @@
Covers two functions introduced to prevent ephemeral file-upload context from Covers two functions introduced to prevent ephemeral file-upload context from
persisting in long-term memory: persisting in long-term memory:
- _filter_messages_for_memory (memory_middleware) - filter_messages_for_memory (message_processing)
- _strip_upload_mentions_from_memory (updater) - _strip_upload_mentions_from_memory (updater)
""" """
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
@@ -31,7 +31,7 @@ def _ai(text: str, tool_calls=None) -> AIMessage:
# =========================================================================== # ===========================================================================
# _filter_messages_for_memory # filter_messages_for_memory
# =========================================================================== # ===========================================================================
@@ -45,7 +45,7 @@ class TestFilterMessagesForMemory:
_human(_UPLOAD_BLOCK), _human(_UPLOAD_BLOCK),
_ai("I have read the file. It says: Hello."), _ai("I have read the file. It says: Hello."),
] ]
result = _filter_messages_for_memory(msgs) result = filter_messages_for_memory(msgs)
assert result == [] assert result == []
def test_upload_with_real_question_preserves_question(self): def test_upload_with_real_question_preserves_question(self):
@@ -56,7 +56,7 @@ class TestFilterMessagesForMemory:
_human(combined), _human(combined),
_ai("The file contains: Hello DeerFlow."), _ai("The file contains: Hello DeerFlow."),
] ]
result = _filter_messages_for_memory(msgs) result = filter_messages_for_memory(msgs)
assert len(result) == 2 assert len(result) == 2
human_result = result[0] human_result = result[0]
@@ -71,7 +71,7 @@ class TestFilterMessagesForMemory:
_human("What is the capital of France?"), _human("What is the capital of France?"),
_ai("The capital of France is Paris."), _ai("The capital of France is Paris."),
] ]
result = _filter_messages_for_memory(msgs) result = filter_messages_for_memory(msgs)
assert len(result) == 2 assert len(result) == 2
assert result[0].content == "What is the capital of France?" assert result[0].content == "What is the capital of France?"
assert result[1].content == "The capital of France is Paris." assert result[1].content == "The capital of France is Paris."
@@ -84,7 +84,7 @@ class TestFilterMessagesForMemory:
ToolMessage(content="Search results", tool_call_id="1"), ToolMessage(content="Search results", tool_call_id="1"),
_ai("Here are the results."), _ai("Here are the results."),
] ]
result = _filter_messages_for_memory(msgs) result = filter_messages_for_memory(msgs)
human_msgs = [m for m in result if m.type == "human"] human_msgs = [m for m in result if m.type == "human"]
ai_msgs = [m for m in result if m.type == "ai"] ai_msgs = [m for m in result if m.type == "ai"]
assert len(human_msgs) == 1 assert len(human_msgs) == 1
@@ -101,7 +101,7 @@ class TestFilterMessagesForMemory:
_human("What is 2 + 2?"), _human("What is 2 + 2?"),
_ai("4"), _ai("4"),
] ]
result = _filter_messages_for_memory(msgs) result = filter_messages_for_memory(msgs)
human_contents = [m.content for m in result if m.type == "human"] human_contents = [m.content for m in result if m.type == "human"]
ai_contents = [m.content for m in result if m.type == "ai"] ai_contents = [m.content for m in result if m.type == "ai"]
@@ -121,14 +121,14 @@ class TestFilterMessagesForMemory:
] ]
) )
msgs = [msg, _ai("Done.")] msgs = [msg, _ai("Done.")]
result = _filter_messages_for_memory(msgs) result = filter_messages_for_memory(msgs)
assert result == [] assert result == []
def test_file_path_not_in_filtered_content(self): def test_file_path_not_in_filtered_content(self):
"""After filtering, no upload file path should appear in any message.""" """After filtering, no upload file path should appear in any message."""
combined = _UPLOAD_BLOCK + "\n\nSummarise the file please." combined = _UPLOAD_BLOCK + "\n\nSummarise the file please."
msgs = [_human(combined), _ai("It says hello.")] msgs = [_human(combined), _ai("It says hello.")]
result = _filter_messages_for_memory(msgs) result = filter_messages_for_memory(msgs)
all_content = " ".join(m.content for m in result if isinstance(m.content, str)) all_content = " ".join(m.content for m in result if isinstance(m.content, str))
assert "/mnt/user-data/uploads/" not in all_content assert "/mnt/user-data/uploads/" not in all_content
assert "<uploaded_files>" not in all_content assert "<uploaded_files>" not in all_content
+397
View File
@@ -0,0 +1,397 @@
"""
Unit tests for MindIEChatModel adapter.
"""
from unittest.mock import AsyncMock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
# ── Import the module under test ──────────────────────────────────────────────
from deerflow.models.mindie_provider import (
MindIEChatModel,
_fix_messages,
_parse_xml_tool_call_to_dict,
)
# ═════════════════════════════════════════════════════════════════════════════
# Helpers
# ═════════════════════════════════════════════════════════════════════════════
def _make_chat_result(content: str, tool_calls=None) -> ChatResult:
msg = AIMessage(content=content)
if tool_calls:
msg.tool_calls = tool_calls
gen = ChatGeneration(message=msg)
return ChatResult(generations=[gen])
# ═════════════════════════════════════════════════════════════════════════════
# 1. _fix_messages
# ═════════════════════════════════════════════════════════════════════════════
class TestFixMessages:
# ── list content → str ────────────────────────────────────────────────────
def test_list_content_extracted_to_str(self):
msg = HumanMessage(
content=[
{"type": "text", "text": "Hello"},
{"type": "text", "text": " world"},
]
)
result = _fix_messages([msg])
assert result[0].content == "Hello world"
def test_list_content_ignores_non_text_blocks(self):
msg = HumanMessage(
content=[
{"type": "image_url", "image_url": "http://x.com/img.png"},
{"type": "text", "text": "caption"},
]
)
result = _fix_messages([msg])
assert result[0].content == "caption"
def test_empty_list_content_becomes_space(self):
msg = HumanMessage(content=[])
result = _fix_messages([msg])
assert result[0].content == " "
# ── plain str content ─────────────────────────────────────────────────────
def test_plain_string_content_preserved(self):
msg = HumanMessage(content="hi there")
result = _fix_messages([msg])
assert result[0].content == "hi there"
def test_empty_string_content_becomes_space(self):
msg = HumanMessage(content="")
result = _fix_messages([msg])
assert result[0].content == " "
# ── AIMessage with tool_calls → XML ───────────────────────────────────────
def test_ai_message_with_tool_calls_serialised_to_xml(self):
msg = AIMessage(
content="Sure",
tool_calls=[
{
"name": "get_weather",
"args": {"city": "London"},
"id": "call_abc",
}
],
)
result = _fix_messages([msg])
out = result[0]
assert isinstance(out, AIMessage)
assert "<tool_call>" in out.content
assert "<function=get_weather>" in out.content
assert '<parameter=city>"London"</parameter>' in out.content
assert not getattr(out, "tool_calls", [])
def test_ai_message_text_preserved_before_xml(self):
msg = AIMessage(
content="Here you go",
tool_calls=[{"name": "search", "args": {"q": "pytest"}, "id": "x"}],
)
result = _fix_messages([msg])
assert result[0].content.startswith("Here you go")
def test_ai_message_multiple_tool_calls(self):
msg = AIMessage(
content="",
tool_calls=[
{"name": "tool_a", "args": {"x": 1}, "id": "id1"},
{"name": "tool_b", "args": {"y": 2}, "id": "id2"},
],
)
result = _fix_messages([msg])
content = result[0].content
assert content.count("<tool_call>") == 2
assert "<function=tool_a>" in content
assert "<function=tool_b>" in content
# ── ToolMessage → HumanMessage ────────────────────────────────────────────
def test_tool_message_becomes_human_message(self):
msg = ToolMessage(content="42 degrees", tool_call_id="call_abc")
result = _fix_messages([msg])
out = result[0]
assert isinstance(out, HumanMessage)
assert "<tool_response>" in out.content
assert "42 degrees" in out.content
def test_tool_message_with_list_content(self):
msg = ToolMessage(
content=[{"type": "text", "text": "result"}],
tool_call_id="call_xyz",
)
result = _fix_messages([msg])
assert isinstance(result[0], HumanMessage)
assert "result" in result[0].content
# ── Mixed message list ────────────────────────────────────────────────────
def test_mixed_message_types_ordering_preserved(self):
msgs = [
HumanMessage(content="q"),
AIMessage(content="a"),
ToolMessage(content="tool out", tool_call_id="c1"),
HumanMessage(content="follow up"),
]
result = _fix_messages(msgs)
assert len(result) == 4
assert isinstance(result[2], HumanMessage)
assert result[3].content == "follow up"
# ── SystemMessage pass-through ────────────────────────────────────────────
def test_system_message_passed_through_unchanged(self):
msg = SystemMessage(content="You are helpful.")
result = _fix_messages([msg])
assert result[0].content == "You are helpful."
# ═════════════════════════════════════════════════════════════════════════════
# 2. _parse_xml_tool_call_to_dict
# ═════════════════════════════════════════════════════════════════════════════
class TestParseXmlToolCalls:
def test_no_tool_call_returns_original(self):
content = "Just a normal reply."
clean, calls = _parse_xml_tool_call_to_dict(content)
assert clean == content
assert calls == []
def test_single_tool_call_parsed(self):
content = "<tool_call> <function=search> <parameter=query>pytest</parameter> </function> </tool_call>"
clean, calls = _parse_xml_tool_call_to_dict(content)
assert clean == ""
assert len(calls) == 1
assert calls[0]["name"] == "search"
assert calls[0]["args"]["query"] == "pytest"
assert calls[0]["id"].startswith("call_")
def test_multiple_tool_calls_parsed(self):
content = "<tool_call><function=a><parameter=x>1</parameter></function></tool_call><tool_call><function=b><parameter=y>2</parameter></function></tool_call>"
_, calls = _parse_xml_tool_call_to_dict(content)
assert len(calls) == 2
assert calls[0]["name"] == "a"
assert calls[1]["name"] == "b"
def test_text_before_tool_call_preserved(self):
content = "Here is the answer.\n<tool_call><function=f><parameter=k>v</parameter></function></tool_call>"
clean, calls = _parse_xml_tool_call_to_dict(content)
assert clean == "Here is the answer."
assert len(calls) == 1
def test_integer_param_deserialised(self):
content = "<tool_call><function=f><parameter=n>42</parameter></function></tool_call>"
_, calls = _parse_xml_tool_call_to_dict(content)
assert calls[0]["args"]["n"] == 42
def test_list_param_deserialised(self):
content = '<tool_call><function=f><parameter=lst>["a","b"]</parameter></function></tool_call>'
_, calls = _parse_xml_tool_call_to_dict(content)
assert calls[0]["args"]["lst"] == ["a", "b"]
def test_dict_param_deserialised(self):
content = '<tool_call><function=f><parameter=d>{"k": 1}</parameter></function></tool_call>'
_, calls = _parse_xml_tool_call_to_dict(content)
assert calls[0]["args"]["d"] == {"k": 1}
def test_bool_param_deserialised(self):
content = "<tool_call><function=f><parameter=flag>true</parameter></function></tool_call>"
_, calls = _parse_xml_tool_call_to_dict(content)
assert calls[0]["args"]["flag"] is True
def test_malformed_param_stays_string(self):
content = "<tool_call><function=f><parameter=bad>{broken json</parameter></function></tool_call>"
_, calls = _parse_xml_tool_call_to_dict(content)
assert calls[0]["args"]["bad"] == "{broken json"
def test_non_string_input_returned_as_is(self):
result = _parse_xml_tool_call_to_dict(None)
assert result == (None, [])
def test_unique_ids_generated(self):
block = "<tool_call><function=f><parameter=k>v</parameter></function></tool_call>"
_, c1 = _parse_xml_tool_call_to_dict(block)
_, c2 = _parse_xml_tool_call_to_dict(block)
assert c1[0]["id"] != c2[0]["id"]
# ═════════════════════════════════════════════════════════════════════════════
# 3. MindIEChatModel._patch_result_with_tools
# ═════════════════════════════════════════════════════════════════════════════
class TestPatchResult:
def _model(self):
with patch.object(MindIEChatModel, "__init__", return_value=None):
m = MindIEChatModel.__new__(MindIEChatModel)
return m
def test_escaped_newlines_fixed(self):
model = self._model()
result = _make_chat_result("line1\\nline2")
patched = model._patch_result_with_tools(result)
assert patched.generations[0].message.content == "line1\nline2"
def test_xml_tool_calls_extracted(self):
model = self._model()
content = "<tool_call><function=calc><parameter=expr>1+1</parameter></function></tool_call>"
result = _make_chat_result(content)
patched = model._patch_result_with_tools(result)
msg = patched.generations[0].message
assert msg.content == ""
assert len(msg.tool_calls) == 1
assert msg.tool_calls[0]["name"] == "calc"
def test_patch_result_appends_to_existing_tool_calls(self):
model = self._model()
existing = [{"name": "existing", "args": {}, "id": "e1"}]
content = "<tool_call><function=new_tool><parameter=k>v</parameter></function></tool_call>"
result = _make_chat_result(content, tool_calls=existing)
patched = model._patch_result_with_tools(result)
msg = patched.generations[0].message
assert len(msg.tool_calls) == 2
names = [tc["name"] for tc in msg.tool_calls]
assert "existing" in names
assert "new_tool" in names
def test_no_tool_call_content_unchanged(self):
model = self._model()
result = _make_chat_result("plain reply")
patched = model._patch_result_with_tools(result)
assert patched.generations[0].message.content == "plain reply"
def test_non_string_content_skipped(self):
model = self._model()
msg = AIMessage(content=[{"type": "text", "text": "hi"}])
gen = ChatGeneration(message=msg)
result = ChatResult(generations=[gen])
patched = model._patch_result_with_tools(result)
assert patched is not None
# ═════════════════════════════════════════════════════════════════════════════
# 4. MindIEChatModel._generate (sync)
# ═════════════════════════════════════════════════════════════════════════════
class TestGenerate:
def test_generate_calls_fix_messages_and_patch(self):
with patch("deerflow.models.mindie_provider.ChatOpenAI._generate") as mock_super_gen, patch.object(MindIEChatModel, "__init__", return_value=None):
mock_super_gen.return_value = _make_chat_result("hello")
model = MindIEChatModel.__new__(MindIEChatModel)
msgs = [HumanMessage(content="ping")]
result = model._generate(msgs)
assert mock_super_gen.called
called_msgs = mock_super_gen.call_args[0][0]
assert all(isinstance(m.content, str) for m in called_msgs)
assert result.generations[0].message.content == "hello"
# ═════════════════════════════════════════════════════════════════════════════
# 5. MindIEChatModel._agenerate (async)
# ═════════════════════════════════════════════════════════════════════════════
class TestAGenerate:
@pytest.mark.asyncio
async def test_agenerate_patches_result(self):
with patch("deerflow.models.mindie_provider.ChatOpenAI._agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
mock_ag.return_value = _make_chat_result("world\\nfoo")
model = MindIEChatModel.__new__(MindIEChatModel)
result = await model._agenerate([HumanMessage(content="hi")])
assert result.generations[0].message.content == "world\nfoo"
# ═════════════════════════════════════════════════════════════════════════════
# 6. MindIEChatModel._astream (async generator)
# ═════════════════════════════════════════════════════════════════════════════
class TestAStream:
async def _collect(self, gen):
chunks = []
async for chunk in gen:
chunks.append(chunk)
return chunks
@pytest.mark.asyncio
async def test_no_tools_uses_real_stream(self):
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
async def fake_stream(*args, **kwargs):
for char in ["hel", "lo"]:
yield ChatGenerationChunk(message=AIMessageChunk(content=char))
with patch("deerflow.models.mindie_provider.ChatOpenAI._astream", side_effect=fake_stream), patch.object(MindIEChatModel, "__init__", return_value=None):
model = MindIEChatModel.__new__(MindIEChatModel)
chunks = await self._collect(model._astream([HumanMessage(content="hi")]))
assert "".join(c.message.content for c in chunks) == "hello"
@pytest.mark.asyncio
async def test_no_tools_fixes_escaped_newlines_in_stream(self):
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
async def fake_stream(*args, **kwargs):
yield ChatGenerationChunk(message=AIMessageChunk(content="a\\nb"))
with patch("deerflow.models.mindie_provider.ChatOpenAI._astream", side_effect=fake_stream), patch.object(MindIEChatModel, "__init__", return_value=None):
model = MindIEChatModel.__new__(MindIEChatModel)
chunks = await self._collect(model._astream([HumanMessage(content="x")]))
assert chunks[0].message.content == "a\nb"
@pytest.mark.asyncio
async def test_with_tools_fake_streams_text_in_chunks(self):
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
long_text = "A" * 50
mock_ag.return_value = _make_chat_result(long_text)
model = MindIEChatModel.__new__(MindIEChatModel)
chunks = await self._collect(model._astream([HumanMessage(content="q")], tools=[{"type": "function", "function": {"name": "dummy"}}]))
full = "".join(c.message.content for c in chunks)
assert full == long_text
assert len(chunks) > 1
@pytest.mark.asyncio
async def test_with_tools_emits_tool_call_chunk(self):
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
model = MindIEChatModel.__new__(MindIEChatModel)
chunks = await self._collect(model._astream([HumanMessage(content="q")], tools=[{"type": "function", "function": {"name": "fn"}}]))
tool_chunks = [c for c in chunks if getattr(c.message, "tool_calls", [])]
assert tool_chunks, "No chunk carried tool_calls"
assert tool_chunks[-1].message.tool_calls[0]["name"] == "fn"
@pytest.mark.asyncio
async def test_with_tools_empty_text_still_emits_tool_chunk(self):
tool_calls = [{"name": "x", "args": {}, "id": "c2"}]
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
mock_ag.return_value = _make_chat_result("", tool_calls=tool_calls)
model = MindIEChatModel.__new__(MindIEChatModel)
chunks = await self._collect(model._astream([HumanMessage(content="q")], tools=[{"type": "function", "function": {"name": "x"}}]))
assert any(getattr(c.message, "tool_calls", []) for c in chunks)

Some files were not shown because too many files have changed in this diff Show More