Compare commits

23 Commits

Author SHA1 Message Date
rayhpeng 9ed83c84dc fix(runtime): use pass for protocol stubs 2026-06-01 15:31:46 +08:00
rayhpeng 30bb2d5149 refactor(runtime): add run DDD boundary skeleton 2026-06-01 09:22:32 +08:00
AochenShen99 9f3be2a9fa fix(agents): offload UploadsMiddleware uploads scan off the event loop (#3311)
UploadsMiddleware defines only the sync `before_agent` hook. LangChain wires a
sync-only hook as `RunnableCallable(before_agent, None)`, and LangGraph's
`ainvoke` runs it directly on the event loop when `afunc is None` — so the
per-message uploads-directory scan (`exists`/`iterdir`/`stat` plus reading
sibling `.md` outlines) blocks the asyncio event loop on every message that has
an uploads directory.

Add `abefore_agent` that offloads the scan to a worker thread via
`run_in_executor`; it copies the current context, preserving the `user_id`
contextvar read by `get_effective_user_id()`.

Add a runtime anchor under `tests/blocking_io/` that drives the real
`create_agent` graph via `ainvoke` under the strict Blockbuster gate, so a
regression back onto the event loop fails CI. Update blocking-IO docs.
2026-05-30 21:46:35 +08:00
Ryker_Feng e8e9edcb6e fix(channels): ignore hidden control messages when extracting replies (#3219) (#3270) 2026-05-29 23:06:58 +08:00
AochenShen99 4093c83383 refactor(provider): share assistant payload replay matching (#3307)
* Share assistant payload replay matching

* fix(provider): recover assistant field when ordinal AI index is taken

The mismatch-length fallback in `_match_ai_message` only tried the exact
`fallback_ordinal` AI index. When serialization drops or reorders an
assistant message, a unique signature match can consume a non-ordinal
index, leaving a later ambiguous payload's ordinal already used — so its
provider field (e.g. `reasoning_content`) was silently dropped.

Scan forward from the ordinal for the next unused `AIMessage` (wrapping to
earlier indices) to preserve the positional bias while still recovering
the field. Forward scanning avoids a naive min-unused pick that could
restore the wrong field after a leading message is dropped.
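
The forward scan in this paragraph could look roughly like the sketch below. `next_unused_index` and its parameters are hypothetical names, and the sketch follows only the description in this paragraph (scan from the ordinal, wrap to earlier indices); it is not the body of `_match_ai_message`.

```python
from typing import Optional


def next_unused_index(ordinal: int, total: int, used: set[int]) -> Optional[int]:
    """Return the first unused index at or after `ordinal`, wrapping to earlier ones."""
    if total <= 0:
        return None
    for step in range(total):
        # Walk forward from the ordinal; the modulo wraps to earlier indices.
        candidate = (ordinal + step) % total
        if candidate not in used:
            return candidate
    return None  # every index is already taken
```

Compared with picking the smallest unused index, starting at the ordinal keeps the positional bias: with index 1 taken out of four, the scan from ordinal 1 yields 2 rather than 0.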

Add a regression test for the dropped-leading-message case.

* fix(provider): avoid earlier assistant fallback replay
2026-05-29 23:05:59 +08:00
AochenShen99 052b1e2102 test(runtime): add Blockbuster runtime anchor for JsonlRunEventStore async IO (#3313)
* test(runtime): add Blockbuster runtime anchor for JsonlRunEventStore async IO

#3084 offloaded `JsonlRunEventStore`'s file IO via `asyncio.to_thread` and added
a mock-based offload assertion (`tests/test_jsonl_event_store_async_io.py`) that
covers `put()` only. That guard is not part of the Blockbuster runtime gate
(`tests/blocking_io/`) run by `backend-blocking-io-tests.yml`.

Add a runtime anchor that drives the full async surface (`put`, `put_batch`,
`list_messages`, `list_events`, `list_messages_by_run`, `count_messages`,
`delete_by_run`, `delete_by_thread`) under the strict Blockbuster gate, so any
blocking IO reintroduced on the event loop in any of these methods fails CI —
not only removal of a specific `to_thread` call. Verified each offloaded method
goes red when its offload is reverted. Test-only; no production change.

* test(runtime): exercise list_events event_types filter branch

Per review feedback: the anchor called list_events without event_types,
so the filter branch never ran after _read_run_events' filesystem IO.
Add a second list_events call with event_types=["message"] so the full
read path -- including the filter branch -- executes under the gate.
2026-05-29 23:02:41 +08:00
Xinmin Zeng ca487578a4 feat(agent): add ToolOutputBudgetMiddleware for oversized tool output protection (#3303)
* feat(agent): add ToolOutputBudgetMiddleware for oversized tool output protection

Closes #3289. Adds a unified middleware that enforces per-result budgets
on ALL tool outputs (MCP, sandbox, community, custom), preventing
oversized external tool results from blowing the model context window.

Design informed by claude-code (persistToolResult), hermes-agent
(tool_result_storage), and pi (OutputAccumulator) — the three most
mature implementations in production coding-agent frameworks.

Key features:
- Disk externalization: oversized outputs written to thread-local
  .tool-results/ directory, replaced with compact preview + file
  reference. Model can read full output via read_file with offset/limit.
- Fallback truncation: head+tail truncation when disk is unavailable
  (no thread_data, write failure), ensuring the context is always
  protected.
- read_file exemption: prevents persist-read-persist infinite loops
  (independently discovered by claude-code, hermes-agent, and pi).
- Per-tool threshold overrides via config.
- Line-boundary-aware truncation (no partial lines in previews).
- Multimodal content passthrough (images/structured blocks skip budget).
- Historical ToolMessage patching in wrap_model_call for checkpoint
  recovery scenarios.

Related: #3222 (design RFC), #1844 (comprehensive context management),
#3137 (write_file args compaction), #1677 (sandbox tool truncation).

* test: add MCP content_and_artifact format coverage

Add 5 tests for MCP tool output format (list of content blocks):
- text content blocks are extracted and budgeted
- multiple text blocks are joined and budgeted
- image content blocks are skipped (multimodal passthrough)
- mixed text+image blocks are skipped
- small text blocks pass through unchanged

Total test count: 59 (was 54).

* fix(agent): address Codex review findings for ToolOutputBudgetMiddleware

Three issues identified by Codex code review, all fixed:

1. `enabled` config field was unused — middleware now checks
   `config.enabled` and skips all processing when disabled.

2. `_build_fallback` could exceed `fallback_max_chars` — the marker
   text itself (~139 chars) was not deducted from the budget. Now
   pre-computes marker overhead and falls back to hard slice when
   max_chars is smaller than the marker.

3. Sync file I/O in async path — `awrap_tool_call` now delegates
   `_patch_result` to `asyncio.to_thread` to avoid blocking the
   event loop during disk writes.

Tests updated to use realistic fallback_max_chars values (500+)
that can accommodate the marker overhead, plus two new tests:
- `test_result_never_exceeds_max_chars` (parametric across sizes)
- `test_very_small_max_chars_does_not_crash`
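
As an illustration of the fallback behaviour in item 2, here is a hedged sketch of head+tail truncation that deducts the marker from the budget and hard-slices when the budget cannot fit the marker. `MARKER` and `truncate_head_tail` are hypothetical; the real `_build_fallback` and its marker text may differ.

```python
MARKER = "\n... [output truncated; middle omitted] ...\n"  # hypothetical marker text


def truncate_head_tail(text: str, max_chars: int) -> str:
    """Keep the start and end of `text` within `max_chars`, marker included."""
    if len(text) <= max_chars:
        return text
    budget = max_chars - len(MARKER)
    if budget <= 0:
        # The marker alone would exceed the budget: fall back to a hard slice.
        return text[: max(max_chars, 0)]
    head_len = budget // 2
    tail_len = budget - head_len
    head = text[:head_len]
    tail = text[len(text) - tail_len:]
    # Trim to line boundaries where possible so the preview has no partial lines.
    cut = head.rfind("\n")
    if cut != -1:
        head = head[:cut]
    cut = tail.find("\n")
    if cut != -1:
        tail = tail[cut + 1:]
    return head + MARKER + tail
```

Because trimming only shortens `head` and `tail`, the result never exceeds `max_chars`, which is the invariant the parametric test above is described as checking.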

* fix(agent): address Copilot review — path traversal, async perf, shared config

1. Path traversal defense: sanitize tool_name via _sanitize_tool_name()
   (strips separators, .., absolute paths), validate storage_subdir is
   relative, and verify resolved filepath stays inside storage_dir.

2. Async hot-path optimization: add _needs_budget() cheap check before
   asyncio.to_thread offload — small outputs (99% of calls) skip the
   thread overhead entirely.

3. Replace shared module-level _DEFAULT_CONFIG with _default_config()
   factory to prevent cross-instance mutation of mutable fields.

12 new tests: TestSanitizeToolName (5), TestExternalizePathTraversal (3),
TestNeedsBudget (4).

* fix(agent): correct preview hint to match read_file actual API

read_file uses start_line/end_line (1-indexed line numbers), not
offset/limit. The previous wording was copied from hermes-agent
which has a different read_file interface.

* perf(agent): hoist hot-path imports, add model-call pre-scan (review #3303)

Address maintainer review feedback:

1. Hoist inline imports to module level — `import asyncio` (was in
   awrap_tool_call hot path) and `from dataclasses import replace`
   (was in _patch_result) now live at module top.

2. Add a cheap pre-scan to _patch_model_messages so the historical
   message list is not rebuilt on every model call when nothing is
   oversized (the common case once results are budgeted at tool-call
   time). Also adds the same _needs_budget gate to the sync
   wrap_tool_call for symmetry with awrap_tool_call.

The pre-scan is refactored into per-tool-aware helpers
(_effective_trigger / _tool_message_over_budget) that mirror the exact
trigger conditions in _budget_content — including tool_overrides — so
the fast-path can never produce a false negative (silently skipping
budgeting for a tool with a low per-tool threshold).
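
To make the "no false negative" point concrete, the sketch below shows a pre-scan that resolves the per-tool threshold before comparing sizes. The function names and the `(tool_name, content)` tuple shape are assumptions for illustration; they are not the signatures of `_effective_trigger` or `_tool_message_over_budget`.

```python
from typing import Iterable, Mapping


def effective_trigger(tool_name: str, default_chars: int, overrides: Mapping[str, int]) -> int:
    """Threshold that applies to this tool: its override if present, else the default."""
    return overrides.get(tool_name, default_chars)


def over_budget(tool_name: str, content: str, default_chars: int, overrides: Mapping[str, int]) -> bool:
    return len(content) > effective_trigger(tool_name, default_chars, overrides)


def any_over_budget(
    messages: Iterable[tuple[str, str]],
    default_chars: int,
    overrides: Mapping[str, int],
) -> bool:
    """Cheap pre-scan: True as soon as one (tool_name, content) pair needs budgeting."""
    return any(over_budget(name, content, default_chars, overrides) for name, content in messages)
```

A pre-scan that compared every message against the default alone would skip a 50-character result from a tool whose override is 10, which is exactly the silent miss the commit guards against.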

7 new regression tests lock the per-tool-override-through-pre-scan path
and the model-call early return.

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-29 22:59:26 +08:00
Nan Gao e683ed6a76 fix(runtime): guide malformed write_file recovery (#3040)
* fix(runtime): guide malformed write_file recovery

* fix(runtime): align write_file recovery guidance
2026-05-29 17:46:24 +08:00
Eilen Shin 872079b894 docs: clean standalone LangGraph server remnants (#3301) 2026-05-29 11:36:45 +08:00
john lee cbf8b194e8 fix(runtime): harden JSONL async I/O and DB put_batch thread validation (#3084)
* fix(runtime): harden JSONL async I/O and DB put_batch thread validation (#2816)

- JsonlRunEventStore: offload all file I/O to asyncio.to_thread() so the
  event loop is never blocked; add per-thread asyncio.Lock to serialise
  concurrent puts and prevent interleaved JSONL lines
- Split _ensure_seq_loaded into a sync _compute_max_seq (runs in thread)
  and an async wrapper; seq counter is recovered from disk on fresh store init
- DbRunEventStore.put_batch: raise ValueError when events span multiple
  thread_ids (previously silently assumed same thread)
- Add test_jsonl_event_store_async_io.py: 12 tests covering lock reuse,
  concurrent seq monotonicity, disk recovery, and mixed-thread batch rejection

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: address Copilot review comments

- delete_by_thread: pop _write_locks after releasing the lock to prevent
  unbounded growth when threads are repeatedly created and deleted
- tests: add regression guard asserting asyncio.to_thread is called for
  _write_record in put(); assert _write_locks entry removed on delete

* fix(lint): move patch import to local scope to fix ruff I001

* fix(lint): apply ruff check+format fixes to test file

* fix(runtime): address review feedback for JSONL async I/O hardening (#2816)

Use setdefault for atomic lock init in _get_write_lock; pop _write_locks
inside the held lock scope in delete_by_thread; update test docstring
and assert lock entry also cleared on delete.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: rayhpeng <rayhpeng@gmail.com>
2026-05-29 09:27:53 +08:00
Nan Gao d46a5779bc fix(chat): preserve messages after summarization (#3280)
* fix(chat): preserve messages after summarization

* make format

* fix(chat): address summarization review comments
2026-05-29 08:24:47 +08:00
Xinmin Zeng 2ace78d1e5 fix(frontend): surface backend detail when agent name check fails (#3048)
* fix(frontend): surface backend detail when agent name check fails

The new-agent page caught AgentNameCheckError but only branched on
reason === "backend_unreachable". Everything else (notably the 422
"Invalid agent name '...'. Must match ^[A-Za-z0-9-]+$" response from
GET /api/agents/check when the user submits a name with disallowed
characters — trailing space, dot, Chinese, invisible whitespace from
copy-paste) fell through to the generic fallback "Could not verify
name availability — please try again", swallowing the detail that
already told the user exactly what to fix.

Add a request_failed branch that surfaces err.message (which
checkAgentName already populates from the backend's detail at
core/agents/api.ts). The disabled / backend_unreachable / unknown-
error paths are unchanged.

Pin the contract with unit tests covering: 200 success, fetch
rejection, 502/503/504 network errors, agents_api disabled detail,
422 validation detail carried verbatim, statusText fallback when
detail is absent, and a regression guard against misclassifying a
422 as agents_api disabled.

Closes #3041

* fix(frontend): localise the error prefix when surfacing backend detail

The previous commit surfaced the backend's raw `err.message` on the
new-agent page when the name check failed. The detail itself is
English (backend's `_validate_agent_name` text, any 5xx business
message, etc.) and dropping it bare into a zh-CN page produced a
jarring English-among-Chinese line that didn't match neighbouring
strings like "已存在同名智能体" / "无法验证名称可用性".

Add `nameStepCheckErrorWithDetail` as a templated string ("Name
check failed: {detail}" / "名称校验失败:{detail}"), mirroring the
existing `nameStepBootstrapMessage` `{name}` template pattern. The
page wraps `err.message` in it when present and falls back to the
plain `nameStepCheckError` when the detail is empty.

Rendered output (verified locally with a Console fetch mock that
returns 500 + detail):

  zh-CN: 名称校验失败:Database connection lost: SQLAlchemy connection
         pool exhausted (max 5 connections, all in use)
  en-US: Name check failed: Database connection lost: SQLAlchemy
         connection pool exhausted (max 5 connections, all in use)

The localised prefix tells the user *what operation* failed; the
raw detail tells them *why*. Translating the detail itself would
be lossy (any unbounded backend string would need a translation
table) and would break the debuggability the previous commit
delivered.

Refs #3041

* fix(frontend): distinguish backend detail from generated fallback in AgentNameCheckError

Addresses Copilot's review on #3048: the previous commits keyed off
`err.message`, but `checkAgentName` substitutes a generated fallback
string ("Failed to check agent name: ${statusText}") when the backend
sent no detail. That guaranteed `err.message` was always truthy, made
the `nameStepCheckError` fallback branch unreachable in practice, and
could surface awkward strings like "名称校验失败:Failed to check
agent name: Bad Gateway" in the UI.

Add an explicit `detail: string | null` field to AgentNameCheckError.
`checkAgentName` populates it only when the backend response actually
carried a string `detail` (defensive guard against the dict-shaped
detail that other deer-flow endpoints use for typed error codes).
The new-agent page now selects on `err.detail` instead of `err.message`
so the localised fallback wins when no real detail exists.

Also fix the prettier formatting that broke lint-frontend CI on the
previous push.

Test changes:
- The 422 carry-through test now asserts both `detail` and `message`
  hold the backend string verbatim.
- A new "falls back to statusText in message but leaves detail null"
  test pins the contract that no real detail ⇒ no UI surface leak.
- A new "treats non-string detail as null" test guards against future
  backend schema drift toward dict-shaped detail.

Refs #3041 #3048
2026-05-28 18:38:45 +08:00
AochenShen99 8330b244a9 docs: add blocking IO detection usage and maintenance (#3233)
* docs: add blocking IO detection usage and maintenance

* docs: address blocking io doc review feedback
2026-05-28 18:26:26 +08:00
AochenShen99 44677c5eb4 feat(provider) Add patched MiMo reasoning content support (#3298)
* Add patched MiMo reasoning content support

* Clarify MiMo patched model coverage

* Remove unused MiMo payload index

* Address MiMo review nits
2026-05-28 18:24:32 +08:00
Admire 2fdfff0db3 fix(frontend): fix Mermaid preview failure in historical messages (#3196)
* fix(frontend): render historical mermaid diagrams

* fix(frontend): address mermaid review feedback

* Stabilize cancel lifecycle test

* fix(frontend): handle mermaid fence variants

* fix(frontend): normalize mermaid arrow spacing

* fix(frontend): handle mermaid CRLF fences

* chore: keep mermaid fix frontend-scoped

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-28 18:20:02 +08:00
zgenu 737abc0e45 fix: ignore stale run reconnect conflicts (#3284)
* fix: ignore stale run reconnect conflicts

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: ignore stale run reconnect conflicts

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-28 17:29:30 +08:00
AochenShen99 8decfd327e Fix custom skill install permissions (#3241)
* Fix custom skill install permissions

* Fix skill upload test portability

* Keep custom skill writes sandbox readable

* Clear sandbox write bits on skill permissions

* Limit custom skill write permission updates
2026-05-28 15:48:32 +08:00
Xinmin Zeng 0287240728 fix(frontend): show new thread in sidebar immediately on creation (#3276) (#3283)
When a user starts a new conversation, the sidebar list did not display
it until the AI finished streaming and generated a title. This made it
impossible to switch back to an in-progress conversation when working
with multiple threads concurrently.

Optimistically insert the new thread into the TanStack Query cache
during the `onCreated` callback so the sidebar renders a placeholder
entry ("New chat") as soon as the backend acknowledges thread creation.
The existing `onUpdateEvent` title handler and `onFinish` query
invalidation then update the entry in-place with the real title.
2026-05-28 15:27:38 +08:00
Lucy Shen 37451500eb fix(gateway): split stream_existing_run into per-method routes for unique OpenAPI operationIds (#3228)
* fix(gateway): split stream_existing_run into per-method routes for unique OpenAPI operationIds

`@router.api_route("/.../stream", methods=["GET", "POST"])` registers a
single FastAPI route that holds both methods. FastAPI's auto-generated
`operationId` is computed once per route from a single method picked out
of `route.methods`, so when OpenAPI generation iterates over every method
on that route both end up sharing the same `operationId`. That triggers
`UserWarning: Duplicate Operation ID stream_existing_run_..._stream_(get|post) for function stream_existing_run`
during `app.openapi()` and produces an invalid OpenAPI spec for SDK /
codegen consumers.

Register GET and POST as two separate routes on the same handler so each
method gets a distinct auto-generated `operationId` ("..._stream_get" and
"..._stream_post"). Behavior is otherwise unchanged: same handler, same
`require_permission` decoration, same response.

Add `tests/test_openapi_operation_ids.py` to lock in the invariant:
no duplicate-operationId warnings during spec generation, globally unique
operationIds across the spec, and distinct GET / POST operationIds on the
stream endpoint specifically. Reverted the source change locally and
confirmed all three tests fail before the fix.

* test(runtime): widen CancelledError catch in _ScriptedAgent to fix cancel-race flake

`_ScriptedAgent.astream()` previously only caught `asyncio.CancelledError`
inside the inner `if self.block_after_first_chunk:` while-loop. Cancellation
arriving during any earlier `await` in the same body
(`self.model.ainvoke`, `_write_checkpoint`, the `yield`) would propagate
without setting `controller.cancelled`, so callers waiting on
`controller.cancelled.wait(5)` after `POST /cancel` returned 204 could race
and time out.

`test_cancel_interrupt_stops_running_background_run` waits only for the
`started` event (set on the first line of `astream`) before issuing cancel,
so its race window spans all three pre-loop `await`s. On a clean `main`
checkout, stress-running the test 20× reproduces the failure 6/20
(~30%). `test_cancel_rollback_restores_pre_run_checkpoint`, which waits
for the later `checkpoint_written` event, passes 20/20 — confirming the
race lives entirely in the gap between `started.set()` and the
cancellation-aware block.

Widen the try/except to cover the entire `astream` body so any
`CancelledError` sets the controller event; the non-cancel path is
unchanged (no exception means no event set). After this change the
previously flaky test passes 50/50, the rollback test still passes 30/30,
and the full backend suite remains at 3649 passed / 19 skipped.

Test-only change — `backend/tests/test_runtime_lifecycle_e2e.py` is the
only file touched; the production cancel pipeline is unaffected.
2026-05-28 08:20:52 +08:00
Lawrance_YXLiao 3cb75887c1 fix(memory): parse wrapped memory update json responses (#3252)
* fix(memory): parse wrapped memory update json responses

* test(memory): format wrapped response coverage

* fix(memory): guard malformed nested memory facts

* fix(memory): require full update object when parsing responses

* fix(memory): fail closed on unsafe partial removals

* style(memory): format updater tests
2026-05-28 07:46:44 +08:00
AochenShen99 a5599c100c fix(gateway): honour on_disconnect on /wait endpoints (#3267)
* fix(gateway): honour on_disconnect on /wait endpoints (#3265)

The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used
to await record.task directly with no disconnect handling and silently
swallow CancelledError. When a long tool call (e.g. pip install inside
a custom skill) kept the connection idle long enough for an
intermediate HTTP layer to time out, the handler would still read the
in-progress checkpoint and return it as if the run had completed
normally -- masking a half-finished run as a successful response.

Add wait_for_run_completion in app.gateway.services that mirrors
sse_consumer's bridge-consumption pattern: subscribe to the stream
bridge until END_SENTINEL, poll request.is_disconnected on every
wake-up, and on real client disconnect cancel the background run when
record.on_disconnect is "cancel". Wire it into both wait endpoints.

The streaming path was unaffected because sse_consumer already has
this loop; this just brings /wait to parity.

* fix(gateway): skip checkpoint serialization on /wait disconnect

Copilot review on #3267 caught a follow-on of the same #3265 bug: when
the client disconnects, wait_for_run_completion breaks out of the bridge
loop and cancels the run, but the /wait endpoint then continues to read
the checkpointer and serializes whatever partial checkpoint exists as a
normal 200 response.

Have the helper return a bool — True only when END_SENTINEL was observed
— and skip the checkpoint serialization path on False. Also reorder the
inner check so END_SENTINEL is honoured even when is_disconnected() flips
true in the same iteration; the run truly finished so the real final
checkpoint is still valid.
2026-05-28 07:22:39 +08:00
dependabot[bot] 9e332c594a chore(deps): bump uuid from 10.0.0 to 14.0.0 in /frontend (#3281)
Bumps [uuid](https://github.com/uuidjs/uuid) from 10.0.0 to 14.0.0.
- [Release notes](https://github.com/uuidjs/uuid/releases)
- [Changelog](https://github.com/uuidjs/uuid/blob/main/CHANGELOG.md)
- [Commits](https://github.com/uuidjs/uuid/compare/v10.0.0...v14.0.0)

---
updated-dependencies:
- dependency-name: uuid
  dependency-version: 14.0.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-28 07:14:44 +08:00
Willem Jiang 162fb2143e fix(mcp): skip session pooling for HTTP/SSE transports to avoid anyio RuntimeError (#3203) (#3224)
* fix(mcp): skip session pooling for HTTP/SSE transports to avoid anyio RuntimeError (#3203)

  HTTP/SSE transports use anyio.TaskGroup internally for streamable
  connections. These task groups have cancel scopes bound to the async task
  that created them, so closing a pooled session from a different task
  raises RuntimeError. Restrict session pooling to stdio transports only.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* docs: clarify MCP pooling applies only to stdio tools

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/2dd9881d-54c6-45fd-90bc-154a09e29841

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
2026-05-27 08:32:57 +08:00
105 changed files with 5991 additions and 583 deletions
+5 -5
@@ -59,7 +59,7 @@ smoke-test/
2. **Check pnpm** - Package manager
3. **Check uv** - Python package manager
4. **Check nginx** - Reverse proxy
-5. **Check required ports** - Confirm that ports 2026, 3000, 8001, and 2024 are not occupied
+5. **Check required ports** - Confirm that ports 2026, 3000, and 8001 are not occupied
**Docker mode environment check** (if Docker is selected):
1. **Check whether Docker is installed** - Run `docker --version`
@@ -93,17 +93,17 @@ smoke-test/
### Phase 5: Service Health Check
**Local mode health check**:
-1. **Check process status** - Confirm that LangGraph, Gateway, Frontend, and Nginx processes are all running
+1. **Check process status** - Confirm that Gateway, Frontend, and Nginx processes are all running
2. **Check frontend service** - Visit `http://localhost:2026` and verify that the page loads
3. **Check API Gateway** - Verify the `http://localhost:2026/health` endpoint
-4. **Check LangGraph service** - Verify the availability of relevant endpoints
+4. **Check LangGraph-compatible API** - Verify the `/api/langgraph/*` route exposed by Gateway
5. **Frontend route smoke check** - Run `bash .agent/skills/smoke-test/scripts/frontend_check.sh` to verify key routes under `/workspace`
**Docker mode health check** (when using Docker):
1. **Check container status** - Run `docker ps` and confirm that all containers are running
2. **Check frontend service** - Visit `http://localhost:2026` and verify that the page loads
3. **Check API Gateway** - Verify the `http://localhost:2026/health` endpoint
-4. **Check LangGraph service** - Verify the availability of relevant endpoints
+4. **Check LangGraph-compatible API** - Verify the `/api/langgraph/*` route exposed by Gateway
5. **Frontend route smoke check** - Run `bash .agent/skills/smoke-test/scripts/frontend_check.sh` to verify key routes under `/workspace`
### Optional Functional Verification
@@ -135,7 +135,7 @@ smoke-test/
The following warnings can appear during smoke testing and do not block a successful result:
- Feishu/Lark SSL errors in Gateway logs (certificate verification failure) can be ignored if that channel is not enabled
-- Warnings in LangGraph logs about missing methods in the custom checkpointer, such as `adelete_for_runs` or `aprune`, do not affect the core functionality
+- Warnings in Gateway logs about missing methods in the custom checkpointer, such as `adelete_for_runs` or `aprune`, do not affect the core functionality
## Key Tools
+8 -10
@@ -138,7 +138,6 @@ This document describes the detailed operating steps for each phase of the DeerF
lsof -i :2026 # Main port
lsof -i :3000 # Frontend
lsof -i :8001 # Gateway
-lsof -i :2024 # LangGraph
```
**Success Criteria**: All ports are free, or they are occupied only by DeerFlow-related processes.
@@ -258,7 +257,7 @@ This document describes the detailed operating steps for each phase of the DeerF
**Steps**:
1. Run `make dev-daemon` (background mode)
-**Description**: This command starts all services (LangGraph, Gateway, Frontend, Nginx).
+**Description**: This command starts all services (Gateway embedded runtime, Frontend, Nginx).
**Notes**:
- `make dev` runs in the foreground and stops with Ctrl+C
@@ -272,7 +271,6 @@ This document describes the detailed operating steps for each phase of the DeerF
**Steps**:
1. Wait 90-120 seconds for all services to start completely
2. You can monitor startup progress by checking these log files:
-- `logs/langgraph.log`
- `logs/gateway.log`
- `logs/frontend.log`
- `logs/nginx.log`
@@ -316,11 +314,10 @@ This document describes the detailed operating steps for each phase of the DeerF
**Steps**:
1. Run the following command to check processes:
```bash
-ps aux | grep -E "(langgraph|uvicorn|next|nginx)" | grep -v grep
+ps aux | grep -E "(uvicorn|next|nginx)" | grep -v grep
```
**Success Criteria**: Confirm that the following processes are running:
-- LangGraph (`langgraph dev`)
- Gateway (`uvicorn app.gateway.app:app`)
- Frontend (`next dev` or `next start`)
- Nginx (`nginx`)
@@ -356,10 +353,11 @@ curl http://localhost:2026/health
---
-#### 5.1.4 Check LangGraph Service
+#### 5.1.4 Check LangGraph-compatible API
**Steps**:
-1. Visit relevant LangGraph endpoints to verify availability
+1. Visit `http://localhost:2026/api/langgraph/assistants/lead_agent` to verify Gateway's LangGraph-compatible API route is reachable.
+2. A `401` response is acceptable when authentication is enabled and no session cookie is provided.
---
@@ -373,7 +371,6 @@ curl http://localhost:2026/health
- `deer-flow-nginx`
- `deer-flow-frontend`
- `deer-flow-gateway`
-- `deer-flow-langgraph` (if not in gateway mode)
---
@@ -406,10 +403,11 @@ curl http://localhost:2026/health
---
-#### 5.2.4 Check LangGraph Service
+#### 5.2.4 Check LangGraph-compatible API
**Steps**:
-1. Visit relevant LangGraph endpoints to verify availability
+1. Visit `http://localhost:2026/api/langgraph/assistants/lead_agent` to verify Gateway's LangGraph-compatible API route is reachable.
+2. A `401` response is acceptable when authentication is enabled and no session cookie is provided.
---
@@ -254,7 +254,6 @@ Processes exit quickly after running `make dev-daemon`.
**Solutions**:
1. Check log files:
```bash
-tail -f logs/langgraph.log
tail -f logs/gateway.log
tail -f logs/frontend.log
tail -f logs/nginx.log
@@ -367,24 +366,7 @@ Errors appear in `gateway.log`.
uv sync
```
-4. Confirm that the LangGraph service is running normally (if not in gateway mode)
----
-### Issue: LangGraph Fails to Start
-**Symptoms**:
-Errors appear in `langgraph.log`.
-**Solutions**:
-1. Check LangGraph logs:
-```bash
-tail -f logs/langgraph.log
-```
-2. Check config.yaml
-3. Check whether Python dependencies are complete
-4. Confirm that port 2024 is not occupied
+4. Confirm that the Gateway process is running normally.
---
@@ -519,7 +501,7 @@ Accessing `/health` returns an error or times out.
2. Confirm that config.yaml exists and has valid formatting
3. Check whether Python dependencies are complete
-4. Confirm that the LangGraph service is running normally
+4. Confirm that the Gateway process is running normally.
**Solutions** (Docker mode):
1. Check gateway container logs:
@@ -529,7 +511,7 @@ Accessing `/health` returns an error or times out.
2. Confirm that config.yaml is mounted correctly
3. Check whether Python dependencies are complete
-4. Confirm that the LangGraph service is running normally
+4. Confirm that the Gateway process is running normally.
---
@@ -539,7 +521,7 @@ Accessing `/health` returns an error or times out.
#### View All Service Processes
```bash
-ps aux | grep -E "(langgraph|uvicorn|next|nginx)" | grep -v grep
+ps aux | grep -E "(uvicorn|next|nginx)" | grep -v grep
```
#### View Service Logs
@@ -548,7 +530,6 @@ ps aux | grep -E "(langgraph|uvicorn|next|nginx)" | grep -v grep
tail -f logs/*.log
# View specific service logs
-tail -f logs/langgraph.log
tail -f logs/gateway.log
tail -f logs/frontend.log
tail -f logs/nginx.log
@@ -65,7 +65,7 @@ if ! command -v lsof >/dev/null 2>&1; then
echo " Install lsof and rerun this check"
all_passed=false
else
-for port in 2026 3000 8001 2024; do
+for port in 2026 3000 8001; do
if lsof -i :$port >/dev/null 2>&1; then
echo "⚠ Port $port is already in use:"
lsof -i :$port | head -2
@@ -54,7 +54,6 @@ echo "=========================================="
echo ""
echo "🌐 Access URL: http://localhost:2026"
echo "📋 View logs:"
-echo " - logs/langgraph.log"
echo " - logs/gateway.log"
echo " - logs/frontend.log"
echo " - logs/nginx.log"
@@ -76,12 +76,11 @@ if [ "$mode" = "docker" ]; then
all_passed=false
fi
else
-summary_hint="logs/{langgraph,gateway,frontend,nginx}.log"
+summary_hint="logs/{gateway,frontend,nginx}.log"
print_step "1. Checking local service ports..."
check_listen_port "Nginx" 2026
check_listen_port "Frontend" 3000
check_listen_port "Gateway" 8001
check_listen_port "LangGraph" 2024
fi
echo ""
@@ -104,8 +103,8 @@ else
fi
echo ""
echo "5. Checking LangGraph service..."
check_http_status "LangGraph service" "http://localhost:2024/" "200|301|302|307|308|404"
echo "5. Checking LangGraph-compatible Gateway API..."
check_http_status "LangGraph-compatible Gateway API" "http://localhost:2026/api/langgraph/assistants/lead_agent" "200|401"
echo ""
echo "=========================================="
@@ -78,7 +78,7 @@
- [x] Container status - {{status_containers}}
- [x] Frontend service - {{status_frontend}}
- [x] API Gateway - {{status_api_gateway}}
- [x] LangGraph service - {{status_langgraph}}
- [x] LangGraph-compatible Gateway API - {{status_langgraph}}
**Phase Status**: {{stage5_status}}
@@ -147,7 +147,6 @@ Commit Message: {{git_commit_message}}
| deer-flow-nginx | {{nginx_status}} | {{nginx_uptime}} |
| deer-flow-frontend | {{frontend_status}} | {{frontend_uptime}} |
| deer-flow-gateway | {{gateway_status}} | {{gateway_uptime}} |
| deer-flow-langgraph | {{langgraph_status}} | {{langgraph_uptime}} |
---
@@ -80,7 +80,7 @@
- [x] Process status - {{status_processes}}
- [x] Frontend service - {{status_frontend}}
- [x] API Gateway - {{status_api_gateway}}
- [x] LangGraph service - {{status_langgraph}}
- [x] LangGraph-compatible Gateway API - {{status_langgraph}}
**Phase Status**: {{stage5_status}}
@@ -152,7 +152,7 @@ Commit Message: {{git_commit_message}}
| Nginx | {{nginx_status}} | {{nginx_endpoint}} |
| Frontend | {{frontend_status}} | {{frontend_endpoint}} |
| Gateway | {{gateway_status}} | {{gateway_endpoint}} |
| LangGraph | {{langgraph_status}} | {{langgraph_endpoint}} |
| Gateway LangGraph API | {{langgraph_status}} | {{langgraph_endpoint}} |
---
@@ -166,7 +166,7 @@ Commit Message: {{git_commit_message}}
### If the Test Fails
1. [ ] Review references/troubleshooting.md for common solutions
2. [ ] Check local logs: `logs/{langgraph,gateway,frontend,nginx}.log`
2. [ ] Check local logs: `logs/{gateway,frontend,nginx}.log`
3. [ ] Verify configuration file format and content
4. [ ] If needed, fully reset the environment: `make stop && make clean && make install && make dev-daemon`
+10 -5
@@ -122,10 +122,14 @@ Blocking-IO runtime gate (`tests/blocking_io/`):
`tests/support/detectors/blocking_io_runtime.py`). Any sync blocking IO
call whose stack passes through DeerFlow business code while running on
the asyncio event loop raises `BlockingError` and fails the test.
- Two regression anchors live there: `test_skills_load.py` (locks the
- Regression anchors live there: `test_skills_load.py` (locks the
`asyncio.to_thread` offload around `LocalSkillStorage.load_skills`, fix
for #1917) and `test_sqlite_lifespan.py` (locks the offload around
SQLite path resolution plus `ensure_sqlite_parent_dir`, fix for #1912).
for #1917); `test_sqlite_lifespan.py` (locks the offload around
SQLite path resolution plus `ensure_sqlite_parent_dir`, fix for #1912);
`test_jsonl_run_event_store.py` (locks `JsonlRunEventStore`'s async
API offloading its file IO via `asyncio.to_thread`, fix #3084); and
`test_uploads_middleware.py` (locks `UploadsMiddleware.abefore_agent`
offloading the uploads-directory scan off the event loop).
- `test_gate_smoke.py` is a meta-test asserting the gate actually catches
unoffloaded blocking IO and that the `@pytest.mark.allow_blocking_io`
opt-out works.
@@ -277,6 +281,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs.
- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions.
- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task.
- `POST /wait` (both thread-scoped and `/api/runs/wait`) drains the stream bridge via `wait_for_run_completion()` instead of bare `await record.task`, so it honours the run's `on_disconnect` setting and cancels the background run on real client disconnect rather than returning a stale checkpoint (issue #3265).
Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.
@@ -342,7 +347,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
- **Cache invalidation**: Detects config file changes via mtime comparison
- **Transports**: stdio (command-based), SSE, HTTP
- **OAuth (HTTP/SSE)**: Supports token endpoint flows (`client_credentials`, `refresh_token`) with automatic token refresh + Authorization header injection
- **Runtime updates**: Gateway API saves to extensions_config.json; LangGraph detects via mtime
- **Runtime updates**: Gateway API saves to extensions_config.json; the Gateway-embedded runtime detects changes via mtime
### Skills System (`packages/harness/deerflow/skills/`)
@@ -369,7 +374,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
### IM Channels System (`app/channels/`)
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via the LangGraph Server.
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
+16
@@ -173,6 +173,8 @@ def _extract_response_text(result: dict | list) -> str:
# Stop at the last human message — anything before it is a previous turn
if msg_type == "human":
if _is_hidden_human_control_message(msg):
continue
break
# Check for tool messages from ask_clarification (interrupt case)
@@ -313,6 +315,8 @@ def _extract_artifacts(result: dict | list) -> list[str]:
continue
# Stop at the last human message — anything before it is a previous turn
if msg.get("type") == "human":
if _is_hidden_human_control_message(msg):
continue
break
# Look for AI messages with present_files tool calls
if msg.get("type") == "ai":
@@ -325,6 +329,18 @@ def _extract_artifacts(result: dict | list) -> list[str]:
return artifacts
def _is_hidden_human_control_message(msg: Mapping[str, Any]) -> bool:
"""Return whether a human message is an internal control message hidden from UI."""
if msg.get("type") != "human":
return False
additional_kwargs = msg.get("additional_kwargs")
if not isinstance(additional_kwargs, Mapping):
return False
return additional_kwargs.get("hide_from_ui") is True
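A minimal sketch of how the check above changes the backward scan, assuming plain dict messages in the shape the extractors iterate over. `last_turn_ai_texts` and the sample conversation are hypothetical illustrations, not part of the channel code.

```python
from collections.abc import Mapping
from typing import Any


def is_hidden_human_control_message(msg: Mapping[str, Any]) -> bool:
    """Mirror of the helper above: True only for human messages flagged hide_from_ui."""
    if msg.get("type") != "human":
        return False
    additional_kwargs = msg.get("additional_kwargs")
    if not isinstance(additional_kwargs, Mapping):
        return False
    return additional_kwargs.get("hide_from_ui") is True


def last_turn_ai_texts(messages: list[dict]) -> list[str]:
    """Hypothetical helper: collect AI text from the latest turn, scanning backwards."""
    texts: list[str] = []
    for msg in reversed(messages):
        if msg.get("type") == "human":
            if is_hidden_human_control_message(msg):
                continue  # internal control message: not a turn boundary
            break  # a visible human message ends the current turn
        if msg.get("type") == "ai" and isinstance(msg.get("content"), str):
            texts.append(msg["content"])
    return list(reversed(texts))


# Hypothetical conversation with one hidden control message in the latest turn.
messages = [
    {"type": "human", "content": "old question"},
    {"type": "ai", "content": "old answer"},
    {"type": "human", "content": "new question"},
    {"type": "ai", "content": "partial answer"},
    {"type": "human", "content": "<control>", "additional_kwargs": {"hide_from_ui": True}},
    {"type": "ai", "content": "final answer"},
]

if __name__ == "__main__":
    print(last_turn_ai_texts(messages))
```

Without the `continue`, the scan would stop at the hidden message and drop "partial answer" from the reply. The strict `is True` comparison means a string such as `"true"` does not count as hidden.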
def _format_artifact_text(artifacts: list[str]) -> str:
"""Format artifact paths into a human-readable text block listing filenames."""
import posixpath
+3 -4
@@ -276,10 +276,9 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
logger.info(f"MCP configuration updated and saved to: {config_path}")
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache
# Reload the Gateway configuration and update the global cache. The
# agent runtime lives in Gateway, so this keeps API reads and tool
# execution aligned after extensions_config.json changes.
reloaded_config = reload_extensions_config()
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
return McpConfigResponse(mcp_servers=servers)
+16 -16
@@ -7,7 +7,6 @@ is reused so that conversation history is preserved across calls.
from __future__ import annotations
import asyncio
import logging
import uuid
@@ -17,7 +16,7 @@ from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
@@ -66,24 +65,25 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
completed = True
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
if completed:
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
+22 -16
@@ -21,7 +21,7 @@ from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
logger = logging.getLogger(__name__)
@@ -175,24 +175,25 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
completed = True
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
if completed:
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
@@ -277,7 +278,12 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe
)
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
# Register GET and POST as separate routes so each method gets a unique OpenAPI
# operationId. ``api_route(methods=["GET", "POST"])`` shares one route registration
# across both methods, which makes FastAPI emit the same ``operationId`` twice and
# warn about a duplicate operation id during OpenAPI generation.
@router.get("/{thread_id}/runs/{run_id}/stream", response_model=None)
@router.post("/{thread_id}/runs/{run_id}/stream", response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
thread_id: str,
+48
@@ -402,3 +402,51 @@ async def sse_consumer(
if record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)
async def wait_for_run_completion(
bridge: StreamBridge,
record: RunRecord,
request: Request,
run_mgr: RunManager,
) -> bool:
"""Block until the run publishes ``END_SENTINEL``, honouring on_disconnect.
The non-streaming ``/wait`` endpoints used to ``await record.task``
directly with no disconnect handling. When the client (or an
intermediate HTTP proxy) timed out during a long tool call such as
``pip install``, the handler would swallow ``CancelledError`` and
serialize whatever checkpoint happened to exist, masking a half-finished
run as a normal completion (issue #3265).
This helper consumes the same bridge that ``sse_consumer`` does so the
wait path shares its disconnect semantics: each wake-up polls
``request.is_disconnected()``; on a real disconnect it cancels the
background run when ``record.on_disconnect`` is ``cancel``. The bridge's
heartbeat sentinels guarantee at least one wake-up per
``heartbeat_interval`` even when the agent emits no events for a while.
Returns:
``True`` when ``END_SENTINEL`` was observed (run reached a terminal
state), ``False`` when the loop exited because the client
disconnected. Callers must skip checkpoint serialization on
``False`` so a partial checkpoint is not returned as a normal
response.
"""
completed = False
try:
async for entry in bridge.subscribe(record.run_id):
# END_SENTINEL means the run reached a terminal state; honour it
# even if the client just disconnected so the caller still serializes
# the real final checkpoint.
if entry is END_SENTINEL:
completed = True
return True
if await request.is_disconnected():
break
# Heartbeats and regular events: keep waiting for END_SENTINEL.
return completed
finally:
if not completed and record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)
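The control flow of `wait_for_run_completion` can be sketched in isolation. This is a simplified illustration under stated assumptions: `fake_subscribe`, `wait_until_end`, and the callback parameters are hypothetical stand-ins for the stream bridge, the request, and the run manager, and the sketch omits the `record.status` and `on_disconnect` checks that the real helper performs before cancelling.

```python
import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable

END_SENTINEL = object()  # stand-in for the bridge's end marker
HEARTBEAT = object()     # stand-in for a heartbeat entry


async def fake_subscribe(entries: list[object]) -> AsyncIterator[object]:
    """Hypothetical bridge: yields queued entries one at a time."""
    for entry in entries:
        await asyncio.sleep(0)
        yield entry


async def wait_until_end(
    entries: list[object],
    is_disconnected: Callable[[], Awaitable[bool]],
    cancel_run: Callable[[], Awaitable[None]],
) -> bool:
    """Return True if END_SENTINEL was seen, False if the client went away first."""
    completed = False
    try:
        async for entry in fake_subscribe(entries):
            if entry is END_SENTINEL:
                completed = True
                return True
            if await is_disconnected():
                break
        return completed
    finally:
        # Simplification: always cancel when the wait ended without END_SENTINEL.
        if not completed:
            await cancel_run()


async def demo() -> tuple[bool, bool, list[str]]:
    cancelled: list[str] = []

    async def connected() -> bool:
        return False

    async def disconnected() -> bool:
        return True

    async def cancel() -> None:
        cancelled.append("cancelled")

    finished = await wait_until_end([HEARTBEAT, {"event": "values"}, END_SENTINEL], connected, cancel)
    dropped = await wait_until_end([HEARTBEAT, END_SENTINEL], disconnected, cancel)
    return finished, dropped, cancelled


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The first call reaches the end marker and triggers no cancellation. The second call sees a disconnect on the first heartbeat, returns `False`, and cancels exactly once, which is the case where callers skip checkpoint serialization.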
+30 -51
@@ -4,10 +4,12 @@
| Mode | Start command | Auth layer | Port |
|------|---------|---------|------|
| Standard mode | `make dev` | Gateway AuthMiddleware + LangGraph auth | 2026 (nginx) |
| Gateway mode | `make dev-pro` | Gateway AuthMiddleware (full coverage) | 2026 (nginx) |
| Standard mode | `make dev` | Gateway AuthMiddleware (full coverage) | 2026 (nginx) |
| Direct Gateway | `cd backend && make gateway` | Gateway AuthMiddleware | 8001 |
| Direct LangGraph | `cd backend && make dev` | LangGraph auth | 2024 |
| Direct LangGraph compatibility | Used when running the LangGraph toolchain manually | LangGraph auth | 2024 |
`make dev`, Docker dev, and production deployments all run the Gateway embedded runtime by default.
`app.gateway.langgraph_auth` is used only for the retained direct-LangGraph toolchain / Studio compatibility tests; it is not the standard service startup path.
Run the following tests in each mode.
@@ -21,10 +23,8 @@
# Clear existing data
rm -f backend/.deer-flow/data/deerflow.db
# Choose a mode to start
make dev # Standard mode
# or
make dev-pro # Gateway mode
# Start standard mode (Gateway embedded runtime)
make dev
```
**Verification points:**
@@ -57,7 +57,7 @@ make dev
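## 2. API Flow Tests
> The examples below use `BASE=http://localhost:2026`. Both standard mode and Gateway mode use this address.
> The examples below use `BASE=http://localhost:2026`. Standard mode exposes this address through nginx.
> For direct-connection tests, substitute the corresponding port.
>
> **CSRF token extraction**: several steps extract the CSRF token from the cookie jar; use the following consistently: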
@@ -211,20 +211,18 @@ curl -s -X POST $BASE/api/threads/search \
**Expected:** returns 0, or only user2's own threads
### 2.3 LangGraph Server isolation in standard mode
### 2.3 LangGraph-compatible Gateway route isolation
> Test in standard mode only. Gateway mode does not run LangGraph Server.
#### TC-API-10: LangGraph endpoints require a cookie
#### TC-API-10: LangGraph-compatible endpoints require a cookie
```bash
# Access the LangGraph API without a cookie
# Access the LangGraph-compatible API without a cookie
curl -s -w "%{http_code}" $BASE/api/langgraph/threads
```
**Expected:** 401
#### TC-API-11: LangGraph is accessible with a cookie
#### TC-API-11: LangGraph-compatible routes are accessible with a cookie
```bash
curl -s $BASE/api/langgraph/threads -b user1.txt | jq length
@@ -232,10 +230,10 @@ curl -s $BASE/api/langgraph/threads -b user1.txt | jq length
**Expected:** 200, returns user1's thread list
#### TC-API-12: LangGraph isolation: users see only their own threads
#### TC-API-12: LangGraph-compatible route isolation: users see only their own threads
```bash
# user2 queries LangGraph threads
# user2 queries threads
curl -s $BASE/api/langgraph/threads -b user2.txt | jq length
```
@@ -1234,21 +1232,11 @@ P2=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p2.txt)
## 7. Mode Difference Tests
> Below, `GW=http://localhost:8001` denotes a direct connection to Gateway, and `BASE=http://localhost:2026` denotes access through nginx.
> Gateway mode start command: `make dev-pro` (or `./scripts/serve.sh --dev --gateway`).
> Standard start command: `make dev` (or `./scripts/serve.sh --dev`).
### 7.1 Standard mode only
### 7.1 Standard startup mode
> Start command: `make dev` (or `./scripts/serve.sh --dev`)
#### TC-MODE-01: LangGraph Server runs standalone and requires a cookie
```bash
# Access LangGraph without a cookie
curl -s -w "%{http_code}" -o /dev/null $BASE/api/langgraph/threads/search
# Expected: 403 (rejected by the LangGraph auth handler)
```
#### TC-MODE-02: LangGraph auth token_version check
#### TC-MODE-01: Gateway AuthMiddleware token_version check
```bash
# Log in to obtain a cookie
@@ -1261,9 +1249,9 @@ curl -s -X POST $BASE/api/v1/auth/change-password \
-b cookies.txt -H "Content-Type: application/json" -H "X-CSRF-Token: $CSRF" \
-d '{"current_password":"正确密码","new_password":"NewPass1!"}' -c new_cookies.txt
# Access LangGraph with the old cookie
# Access the LangGraph-compatible route with the old cookie
curl -s -w "%{http_code}" $BASE/api/langgraph/threads/search -b cookies.txt
# Expected: 403 (token_version mismatch)
# Expected: 401 (token_version mismatch)
# Access with the new cookie
CSRF2=$(grep csrf_token new_cookies.txt | awk '{print $NF}')
@@ -1272,7 +1260,7 @@ curl -s -w "%{http_code}" -X POST $BASE/api/langgraph/threads/search \
# Expected: 200
```
#### TC-MODE-03: LangGraph auth owner-filter isolation
#### TC-MODE-02: Gateway owner-filter isolation
```bash
# user1 creates a thread
@@ -1297,18 +1285,9 @@ print('OK: user2 sees', len(threads), 'threads, none belong to user1')
"
```
### 7.2 Gateway mode only
> Start command: `make dev-pro` (or `./scripts/serve.sh --dev --gateway`)
> No LangGraph Server process; the agent runtime is embedded in Gateway.
#### TC-MODE-04: All requests pass through AuthMiddleware
#### TC-MODE-03: All requests pass through AuthMiddleware
```bash
# Confirm LangGraph Server is not running
curl -s -w "%{http_code}" -o /dev/null http://localhost:2024/ok
# Expected: 000 (connection refused)
# Gateway API is protected
curl -s -w "%{http_code}" -o /dev/null $BASE/api/models
# Expected: 401
@@ -1319,7 +1298,7 @@ curl -s -w "%{http_code}" -o /dev/null -X POST $BASE/api/langgraph/threads/searc
# Expected: 401
```
#### TC-MODE-05: Full auth flow in Gateway mode
#### TC-MODE-04: Full auth flow in standard mode
```bash
# Log in
@@ -1334,7 +1313,7 @@ curl -s -X POST $BASE/api/langgraph/threads \
-d '{"metadata":{}}' | python3 -c "import sys,json; print(json.load(sys.stdin)['thread_id'])"
# Expected: returns thread_id
# CSRF protection (in Gateway mode, CSRFMiddleware covers all routes directly)
# CSRF protection (CSRFMiddleware covers all Gateway routes)
curl -s -w "%{http_code}" -o /dev/null -X POST $BASE/api/langgraph/threads \
-b cookies.txt -H "Content-Type: application/json" -d '{"metadata":{}}'
# Expected: 403 (CSRF token missing)
@@ -1433,7 +1412,7 @@ done
### 7.4 Docker deployment
> Start command: `./scripts/deploy.sh` (standard) or `./scripts/deploy.sh --gateway` (Gateway mode)
> Start command: `./scripts/deploy.sh`
> Docker Compose file: `docker/docker-compose.yaml`
>
> Prerequisites:
@@ -1542,16 +1521,16 @@ docker logs deer-flow-gateway 2>&1 | grep -iE "Password: .{15,}" && echo "FAIL:
- Container logs print the **path** (not the password itself), in line with the CodeQL `py/clear-text-logging-sensitive-data` rule
- `grep "Password:"` **should find no match** in the logs (the old behavior is deprecated; the simplify pass removed the log-leak path)
#### TC-DOCKER-06: Gateway mode Docker deployment
#### TC-DOCKER-06: Docker deployment
```bash
# Gateway mode: no langgraph container
./scripts/deploy.sh --gateway
# Standard Docker mode: runtime embedded in the gateway container
./scripts/deploy.sh
sleep 15
# Check whether the langgraph container exists
docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l
# Expected: 0
# Confirm the gateway container exists
docker ps --filter name=deer-flow-gateway --format '{{.Names}}'
# Expected: deer-flow-gateway
# Auth flow works: protected endpoints return 401 when not logged in
curl -s -w "%{http_code}" -o /dev/null $BASE/api/models
+154
@@ -0,0 +1,154 @@
# Blocking IO detection usage and maintenance
This document describes how to use and maintain DeerFlow backend blocking-IO
detection for async event-loop safety.
The goal is narrow: find and prevent synchronous IO from blocking backend
async event-loop paths. Static and runtime detection are complementary, but
they have different jobs.
## Static detector
The static detector is the discovery tool. It scans backend source code and
reports candidate blocking-IO call sites that may need human review.
Run it from the repository root:
```bash
make detect-blocking-io
```
Or from `backend/`:
```bash
make detect-blocking-io
```
The report is written to:
```text
.deer-flow/blocking-io-findings.json
```
Use this output for review and triage. A static finding is a candidate, not
proof that production blocks the event loop at runtime. The current static
rules are intentionally broad; prefer triaging existing output before adding
new static rules.
Add a static rule only when review finds a recurring high-risk blocking
pattern that is invisible to the current detector.
## Runtime detector
The runtime detector is the CI regression guard. It uses Blockbuster to fail a
focused test when code under `app.*` or `deerflow.*` performs blocking IO on
the asyncio event-loop thread.
Run it from `backend/`:
```bash
make test-blocking-io
```
The runtime gate starts from confirmed production bugs and protects those
paths from regressing. It does not prove that the entire backend is free of
blocking IO; it only covers the production paths exercised by
`backend/tests/blocking_io/`.
## Maintenance workflow
Use the static detector to find candidates, then use review to decide which
async production paths are worth protecting in CI.
The normal workflow is:
1. Run the static detector to find backend blocking-IO candidates.
2. Use human review to pick high-risk production async paths.
3. Add or update a focused runtime anchor in `backend/tests/blocking_io/`.
4. Let CI prevent that path from regressing.
Runtime detection has two maintenance paths.
### Add a runtime rule
Add a runtime rule when Blockbuster's default rules do not cover a generic
blocking primitive used by production code.
Rules belong in:
```text
backend/tests/support/detectors/blocking_io_runtime.py
```
Add them to `_PROJECT_BLOCKING_RULES`, not directly inside individual tests.
Keeping rules centralized makes it clear which extra primitives DeerFlow
expects Blockbuster to catch.
Example shape:
```python
import subprocess
from blockbuster import BlockBusterFunction
_PROJECT_BLOCKING_RULES = (
(
"subprocess.Popen.__init__",
BlockBusterFunction(
subprocess.Popen,
"__init__",
scanned_modules=["app", "deerflow"],
),
),
)
```
Do not add a runtime rule just because a business path is not tested. A rule
only expands what Blockbuster can intercept after code runs.
### Add a runtime anchor
Add a runtime anchor when a high-risk async production path should be protected
by CI but no existing `backend/tests/blocking_io/` test executes it.
Anchors belong in:
```text
backend/tests/blocking_io/
```
A good anchor should:
- Call the real production async entry point.
- Avoid bypassing the blocking surface with test-only `asyncio.to_thread`
wrappers.
- Use real local filesystem inputs when the bug shape is filesystem IO.
- Mock only the external dependency boundary, such as a network service or
third-party saver class.
- Fail if a future change moves the blocking operation back onto the event
loop.
Avoid testing only the low-level helper unless that helper is the production
async entry point. The runtime gate is most useful when it protects the caller
that production actually executes.
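As a rough illustration of what "offloaded" means from an anchor's point of view, the sketch below uses only the standard library rather than Blockbuster. `list_names`, `list_names_async`, and `run_demo` are hypothetical names, not DeerFlow APIs. The sketch checks that a real filesystem scan, driven through an async entry point, runs on a thread other than the event-loop thread.

```python
import asyncio
import tempfile
import threading
from pathlib import Path


def list_names(directory: Path) -> tuple[list[str], int]:
    """Blocking scan: returns sorted file names and the id of the thread it ran on."""
    names = sorted(p.name for p in directory.iterdir())
    return names, threading.get_ident()


async def list_names_async(directory: Path) -> tuple[list[str], int]:
    """Async entry point that offloads the blocking scan to a worker thread."""
    return await asyncio.to_thread(list_names, directory)


def run_demo() -> tuple[list[str], bool]:
    """Create real files, call the async entry point, report whether the scan left the loop thread."""
    with tempfile.TemporaryDirectory() as tmp:
        directory = Path(tmp)
        (directory / "b.txt").write_text("data")
        (directory / "a.md").write_text("outline")

        async def scenario() -> tuple[list[str], bool]:
            loop_thread = threading.get_ident()
            names, worker_thread = await list_names_async(directory)
            return names, worker_thread != loop_thread

        return asyncio.run(scenario())


if __name__ == "__main__":
    print(run_demo())
```

A real anchor would rely on the Blockbuster gate to fail the test instead of comparing thread ids, but the shape is the same: real files, the real async caller, and no test-only offload around it.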
## Current runtime coverage
The runtime anchors protect confirmed blocking-IO bug shapes:
- SQLite checkpointer setup, including path resolution and parent-directory
creation.
- Subagent skill metadata loading through `SubagentExecutor._load_skills()`.
- `JsonlRunEventStore` async API (`put` / `list_*` / `delete_*`): the JSONL
run-event backend offloads its synchronous file IO via `asyncio.to_thread`
(fix #3084); this anchor drives the real async API under the gate so any
blocking IO reintroduced on the loop fails, not only removal of one
`to_thread` call.
- `UploadsMiddleware.before_agent` uploads-directory scan: a sync-only middleware
hook runs on the event loop under async graph execution, so the scan is
offloaded via `abefore_agent` + `run_in_executor`.
- Gate health checks: Blockbuster catches unoffloaded calls, opt-out works, and
patches are restored after exceptions.
As static detection and review identify more high-risk async paths, add new
runtime anchors incrementally.
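The `abefore_agent` + `run_in_executor` shape named in the coverage list can be sketched with the standard library alone. This is an assumption-laden illustration: `user_id_var`, `scan_uploads`, and `offloaded_scan` are hypothetical stand-ins, not the middleware's actual code. Unlike `asyncio.to_thread`, `loop.run_in_executor` does not propagate context variables on its own, so the sketch copies the current context explicitly.

```python
import asyncio
import contextvars
import threading

# Hypothetical stand-in for the per-request user id context variable.
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id", default="anonymous")


def scan_uploads() -> tuple[str, int]:
    """Stand-in for a blocking directory scan; reports the visible user id and its thread id."""
    return user_id_var.get(), threading.get_ident()


async def offloaded_scan() -> tuple[str, bool]:
    """Run the scan in the default executor while preserving the caller's context."""
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()  # snapshot taken on the event-loop side
    user_id, worker_thread = await loop.run_in_executor(None, ctx.run, scan_uploads)
    return user_id, worker_thread != threading.get_ident()


async def main() -> tuple[str, bool]:
    user_id_var.set("user-1")
    return await offloaded_scan()


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Passing `scan_uploads` directly to `run_in_executor` would make the worker thread see the default `"anonymous"` value, because executor threads start with an empty context.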
+33
@@ -36,6 +36,7 @@ models:
- OpenAI (`langchain_openai:ChatOpenAI`)
- Anthropic (`langchain_anthropic:ChatAnthropic`)
- DeepSeek (`langchain_deepseek:ChatDeepSeek`)
- Xiaomi MiMo (`deerflow.models.patched_mimo:PatchedChatMiMo`)
- Claude Code OAuth (`deerflow.models.claude_provider:ClaudeChatModel`)
- Codex CLI (`deerflow.models.openai_codex_provider:CodexChatModel`)
- Any LangChain-compatible provider
@@ -166,6 +167,37 @@ models:
For Gemini accessed **without** thinking (e.g. via OpenRouter where thinking is not activated), the plain `langchain_openai:ChatOpenAI` with `supports_thinking: false` is sufficient and no patch is needed.
**MiMo with thinking via OpenAI-compatible API**:
MiMo returns `reasoning_content` on assistant messages in thinking mode. In multi-turn agent conversations with tool calls, subsequent requests must preserve that historical `reasoning_content` on assistant messages or the MiMo API can return HTTP 400. Standard `langchain_openai:ChatOpenAI` drops this provider-specific field, so use `deerflow.models.patched_mimo:PatchedChatMiMo`:
For pay-as-you-go API keys (`sk-...`), use `https://api.xiaomimimo.com/v1`. For Token Plan keys (`tp-...`), use the regional Token Plan Base URL shown in the MiMo console, such as `https://token-plan-cn.xiaomimimo.com/v1`. MiMo documents these key types as separate and non-interchangeable.
`PatchedChatMiMo` is model-id agnostic. Use it for every MiMo thinking model entry you configure, including model entries referenced by `subagents.*.model` overrides (for example `mimo-v2.5-pro`, `mimo-v2.5`, `mimo-v2-pro`, `mimo-v2-omni`, or `mimo-v2-flash`).
```yaml
models:
- name: mimo-v2.5-pro
display_name: MiMo V2.5 Pro
use: deerflow.models.patched_mimo:PatchedChatMiMo
model: mimo-v2.5-pro
api_key: $MIMO_API_KEY
base_url: https://api.xiaomimimo.com/v1
max_tokens: 8192
supports_thinking: true
supports_vision: false
when_thinking_enabled:
extra_body:
thinking:
type: enabled
when_thinking_disabled:
extra_body:
thinking:
type: disabled
```
`PatchedChatMiMo` preserves MiMo's `choices[].message.reasoning_content`, streaming `delta.reasoning_content`, and request-history assistant `reasoning_content` fields. It does not reuse the DeepSeek provider.
### Tool Groups
Organize tools into logical groups:
@@ -319,6 +351,7 @@ models:
- `OPENAI_API_KEY` - OpenAI API key
- `ANTHROPIC_API_KEY` - Anthropic API key
- `DEEPSEEK_API_KEY` - DeepSeek API key
- `MIMO_API_KEY` - Xiaomi MiMo API key
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
- `TAVILY_API_KEY` - Tavily search API key
- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
+1 -1
@@ -26,7 +26,7 @@
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
- For production: tune Gateway worker/runtime settings for long-running agent workloads
## Resolved Issues
@@ -227,6 +227,110 @@ def _extract_text(content: Any) -> str:
return str(content)
_REQUIRED_MEMORY_UPDATE_TOP_LEVEL_KEYS = frozenset({"user", "history", "newFacts", "factsToRemove"})
def _normalize_memory_update_fact(fact: Any) -> dict[str, Any] | None:
"""Normalize a single fact entry from a model-produced memory update."""
if not isinstance(fact, dict):
return None
raw_content = fact.get("content")
if not isinstance(raw_content, str):
return None
content = raw_content.strip()
if not content:
return None
raw_category = fact.get("category")
category = raw_category.strip() if isinstance(raw_category, str) and raw_category.strip() else "context"
raw_confidence = fact.get("confidence", 0.5)
if isinstance(raw_confidence, bool):
return None
if isinstance(raw_confidence, str):
raw_confidence = raw_confidence.strip()
if not raw_confidence:
return None
try:
raw_confidence = float(raw_confidence)
except ValueError:
return None
elif isinstance(raw_confidence, (int, float)):
raw_confidence = float(raw_confidence)
else:
return None
if not math.isfinite(raw_confidence):
return None
normalized_fact = {
"content": content,
"category": category,
"confidence": raw_confidence,
}
source_error = fact.get("sourceError")
if isinstance(source_error, str):
normalized_source_error = source_error.strip()
if normalized_source_error:
normalized_fact["sourceError"] = normalized_source_error
return normalized_fact
def _normalize_memory_update_data(update_data: dict[str, Any]) -> dict[str, Any]:
"""Coerce parsed memory update data into the shape consumed by _apply_updates."""
user = update_data.get("user")
history = update_data.get("history")
new_facts = update_data.get("newFacts")
facts_to_remove = update_data.get("factsToRemove")
normalized_facts_to_remove = [fact_id for fact_id in facts_to_remove if isinstance(fact_id, str)] if isinstance(facts_to_remove, list) else []
normalized_new_facts = []
dropped_new_fact = not isinstance(new_facts, list)
if isinstance(new_facts, list):
for fact in new_facts:
normalized_fact = _normalize_memory_update_fact(fact)
if normalized_fact is not None:
normalized_new_facts.append(normalized_fact)
else:
dropped_new_fact = True
if normalized_facts_to_remove and dropped_new_fact:
raise json.JSONDecodeError(
"Unsafe partial memory update: factsToRemove with malformed newFacts",
json.dumps(update_data, ensure_ascii=False),
0,
)
return {
"user": user if isinstance(user, dict) else {},
"history": history if isinstance(history, dict) else {},
"newFacts": normalized_new_facts,
"factsToRemove": normalized_facts_to_remove,
}
def _parse_memory_update_response(response_content: Any) -> dict[str, Any]:
"""Parse the first valid memory-update JSON object from an LLM response.
Some providers may wrap JSON in thinking traces, prose, or markdown fences
even when prompted to return JSON only. This parser accepts safely
extractable JSON objects but does not repair truncated or malformed JSON.
"""
response_text = _extract_text(response_content).strip()
decoder = json.JSONDecoder()
for match in re.finditer(r"\{", response_text):
try:
parsed, _end = decoder.raw_decode(response_text[match.start() :])
except json.JSONDecodeError:
continue
if isinstance(parsed, dict) and _REQUIRED_MEMORY_UPDATE_TOP_LEVEL_KEYS.issubset(parsed):
return _normalize_memory_update_data(parsed)
raise json.JSONDecodeError("No valid memory update JSON object found", response_text, 0)
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
@@ -353,13 +457,7 @@ class MemoryUpdater:
user_id: str | None = None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
update_data = _parse_memory_update_response(response_content)
# Deep-copy before in-place mutation so a subsequent save() failure
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
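The scanning approach used by `_parse_memory_update_response` can be sketched in isolation as below. This is a minimal illustration, not the project's code: `first_json_object` and `REQUIRED_KEYS` are hypothetical names, and the key set is an assumption because the real `_REQUIRED_MEMORY_UPDATE_TOP_LEVEL_KEYS` value is not shown in this diff.

```python
import json
import re
from typing import Any

# Assumed key set for illustration only; the real constant is not shown in the diff.
REQUIRED_KEYS = frozenset({"user", "history", "newFacts", "factsToRemove"})


def first_json_object(text: str, required_keys: frozenset = REQUIRED_KEYS) -> dict[str, Any]:
    """Return the first decodable JSON object in *text* that has every required key."""
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", text):
        try:
            # raw_decode parses one JSON value starting at the given index
            # and tolerates trailing text such as prose or a closing fence.
            parsed, _end = decoder.raw_decode(text, match.start())
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict) and required_keys.issubset(parsed):
            return parsed
    raise json.JSONDecodeError("No valid JSON object found", text, 0)


wrapped = 'Thinking {not json}...\n```json\n{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}\n```'
result = first_json_object(wrapped)
```

Because every `{` is only a candidate start, a truncated object never decodes and the function raises instead of repairing it, which matches the "does not repair" note in the docstring.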
@@ -26,6 +26,11 @@ from langchain_core.messages import ToolMessage
logger = logging.getLogger(__name__)
# Workaround for issue #2894: malformed write_file calls can carry huge Markdown
# payloads in invalid tool-call args. Keep recovery error details short so the
# synthetic ToolMessage does not echo large or malformed content back to the model.
_MAX_RECOVERY_ERROR_DETAIL_LEN = 500
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
@@ -98,9 +103,25 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
@staticmethod
def _synthetic_tool_message_content(tool_call: dict) -> str:
if tool_call.get("invalid"):
name = tool_call.get("name")
error = tool_call.get("error")
if isinstance(error, str) and error:
return f"[Tool call could not be executed because its arguments were invalid: {error}]"
error_text = error[:_MAX_RECOVERY_ERROR_DETAIL_LEN] if isinstance(error, str) and error else ""
# Workaround for issue #2894: malformed write_file calls can carry huge Markdown
# payloads in invalid tool-call args. Keep recovery guidance actionable without
# echoing large or malformed content back to the model.
if name == "write_file":
details = f" Parser error: {error_text}" if error_text else ""
return (
"[write_file failed before execution: the tool-call arguments were not valid JSON, "
"so no file was written. This often happens when the model tries to write a very "
"large Markdown file in a single tool call, especially when `content` contains "
"unescaped quotes, inline JSON, backslashes, or code fences. Do not retry the same "
"large `write_file` payload for this artifact; provide the report/content directly "
"as normal assistant text in your next response. If a file write is still needed "
f"later, split the file into smaller sections instead of one large payload.{details}]"
)
if error_text:
return f"[Tool call could not be executed because its arguments were invalid: {error_text}]"
return "[Tool call could not be executed because its arguments were invalid.]"
return "[Tool call was interrupted and did not return a result.]"
@@ -77,9 +77,11 @@ def _build_runtime_middlewares(
"""Build shared base middlewares for agent execution."""
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
middlewares: list[AgentMiddleware] = [
ToolOutputBudgetMiddleware.from_app_config(app_config),
ThreadDataMiddleware(lazy_init=lazy_init),
SandboxMiddleware(lazy_init=lazy_init),
]
@@ -87,7 +89,7 @@ def _build_runtime_middlewares(
if include_uploads:
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
middlewares.insert(1, UploadsMiddleware())
middlewares.insert(2, UploadsMiddleware())
if include_dangling_tool_call_patch:
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
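The index change from `insert(1, ...)` to `insert(2, ...)` appears to follow from `ToolOutputBudgetMiddleware` now occupying slot 0. A small sketch with plain strings standing in for the middleware instances (an assumption made purely for illustration) shows the effect on ordering:

```python
# Strings stand in for middleware instances; only the ordering matters here.
base = ["ToolOutputBudgetMiddleware", "ThreadDataMiddleware", "SandboxMiddleware"]

with_old_index = list(base)
with_old_index.insert(1, "UploadsMiddleware")  # would land before ThreadDataMiddleware

with_new_index = list(base)
with_new_index.insert(2, "UploadsMiddleware")  # stays directly after ThreadDataMiddleware
```

With the new index, the uploads middleware keeps its position relative to `ThreadDataMiddleware`, as it had before the budget middleware was prepended.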
@@ -0,0 +1,489 @@
"""Middleware that enforces a per-result budget on tool outputs.
Oversized tool results are persisted to disk and replaced with a compact
preview containing a file reference. When disk persistence is
unavailable the middleware falls back to head+tail truncation so the
model context is never blown by a single large tool return.
"""
from __future__ import annotations
import asyncio
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import replace as dc_replace
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.config.tool_output_config import ToolOutputConfig
logger = logging.getLogger(__name__)
def _default_config() -> ToolOutputConfig:
return ToolOutputConfig()
# ---------------------------------------------------------------------------
# Text helpers
# ---------------------------------------------------------------------------
def _message_text(content: Any) -> str | None:
"""Extract a plain-text representation from a ToolMessage content field.
Returns ``None`` for non-string / multimodal content so the caller
can skip budget enforcement (images, structured blocks, etc.).
"""
if isinstance(content, str):
return content
if content is None:
return None
if isinstance(content, list):
pieces: list[str] = []
for part in content:
if isinstance(part, str):
pieces.append(part)
elif isinstance(part, dict) and isinstance(part.get("text"), str):
pieces.append(part["text"])
else:
return None
return "\n".join(pieces) if pieces else None
return None
def _snap_to_line_boundary(text: str, pos: int) -> int:
"""Return *pos*, or the index just after the nearest preceding newline if one is close enough.
Used so that previews and truncations end on a complete line when
possible. If no newline exists in the second half of ``text[:pos]``
the original *pos* is returned unchanged.
"""
if pos <= 0 or pos >= len(text):
return pos
half = pos // 2
nl = text.rfind("\n", half, pos)
if nl >= 0:
return nl + 1
return pos
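A worked example of the snapping rule, with the function body copied from the diff under a non-underscore name so it runs standalone. The sample text is made up for illustration.

```python
def snap_to_line_boundary(text: str, pos: int) -> int:
    """Copy of the helper above: snap back to a newline found in the second half of text[:pos]."""
    if pos <= 0 or pos >= len(text):
        return pos
    half = pos // 2
    nl = text.rfind("\n", half, pos)
    if nl >= 0:
        return nl + 1
    return pos


text = "alpha\nbeta\ngamma\n"  # newlines at indexes 5, 10 and 16

snapped = snap_to_line_boundary(text, 13)        # newline at 10 lies in [6, 13)
unchanged = snap_to_line_boundary(text, 4)       # no newline in [2, 4)
at_end = snap_to_line_boundary(text, len(text))  # out-of-range positions pass through
```

Here `text[:snapped]` is `"alpha\nbeta\n"`, so the cut ends on a complete line instead of in the middle of `gamma`.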
# ---------------------------------------------------------------------------
# Disk persistence
# ---------------------------------------------------------------------------
_EXT_MAP: dict[str, str] = {
"bash": "log",
"bash_tool": "log",
"web_fetch": "log",
}
def _sanitize_tool_name(name: str) -> str:
"""Strip path separators and traversal components from a tool name."""
base = os.path.basename(name)
safe = base.replace("..", "").replace("/", "_").replace("\\", "_")
return safe or "unknown"
def _externalize(
content: str,
*,
tool_name: str,
tool_call_id: str,
outputs_path: str,
storage_subdir: str,
) -> str | None:
"""Write *content* to disk and return the virtual path, or ``None`` on failure."""
if os.path.isabs(storage_subdir) or ".." in storage_subdir:
return None
storage_dir = os.path.join(outputs_path, storage_subdir)
try:
os.makedirs(storage_dir, exist_ok=True)
except OSError:
return None
safe_name = _sanitize_tool_name(tool_name)
ext = _EXT_MAP.get(tool_name, "txt")
short_id = uuid.uuid4().hex[:12]
filename = f"{safe_name}-{short_id}.{ext}"
filepath = os.path.join(storage_dir, filename)
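# Compare whole path components; a bare string-prefix test would also accept sibling directories.
if os.path.commonpath([os.path.abspath(filepath), os.path.abspath(storage_dir)]) != os.path.abspath(storage_dir):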
return None
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write(content)
except OSError:
return None
virtual_base = "/mnt/user-data/outputs"
return f"{virtual_base}/{storage_subdir}/{filename}"
# ---------------------------------------------------------------------------
# Preview / fallback builders
# ---------------------------------------------------------------------------
def _build_preview(
content: str,
*,
tool_name: str,
virtual_path: str,
head_chars: int,
tail_chars: int,
) -> str:
"""Build a preview with a file reference for externalized output."""
total = len(content)
head_end = _snap_to_line_boundary(content, min(head_chars, total))
tail_start = max(head_end, total - tail_chars)
tail_start_snapped = _snap_to_line_boundary(content, tail_start)
if tail_start_snapped > head_end:
tail_start = tail_start_snapped
head = content[:head_end]
tail = content[tail_start:] if tail_start < total else ""
omitted = total - len(head) - len(tail)
ref = f"\n\n[Full {tool_name} output saved to {virtual_path} ({total} chars, ~{total // 4} tokens). Use read_file with start_line and end_line to access specific sections. {omitted} chars omitted from this preview.]\n\n"
parts = [head, ref]
if tail:
parts.append(tail)
return "".join(parts)
def _build_fallback(
content: str,
*,
tool_name: str,
max_chars: int,
head_chars: int,
tail_chars: int,
) -> str:
"""Build a head+tail truncation when disk persistence is unavailable.
The returned string is guaranteed to be no longer than *max_chars*.
"""
total = len(content)
if max_chars <= 0 or total <= max_chars:
return content
marker_template = "\n\n[... {n} chars omitted from {tn} output. Persistent storage unavailable. Consider narrowing the query or using more specific parameters.]\n\n"
marker_overhead = len(marker_template.format(n=total, tn=tool_name))
if marker_overhead >= max_chars:
return content[:max_chars]
budget = max_chars - marker_overhead
effective_head = min(head_chars, budget)
effective_tail = min(tail_chars, max(0, budget - effective_head))
head_end = _snap_to_line_boundary(content, min(effective_head, total))
tail_start = max(head_end, total - effective_tail)
tail_start_snapped = _snap_to_line_boundary(content, tail_start)
if tail_start_snapped > head_end:
tail_start = tail_start_snapped
head = content[:head_end]
tail = content[tail_start:] if tail_start < total else ""
omitted = total - len(head) - len(tail)
marker = marker_template.format(n=omitted, tn=tool_name)
parts = [head, marker]
if tail:
parts.append(tail)
return "".join(parts)
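The head+tail idea behind `_build_fallback` can be sketched without line snapping, which keeps the length bound easy to verify. This is a simplified illustration under assumptions: `head_tail_truncate` and its shorter marker text are hypothetical and are not the function from the diff.

```python
def head_tail_truncate(content: str, max_chars: int, head_chars: int, tail_chars: int) -> str:
    """Keep the head and tail of *content*, never returning more than *max_chars* characters."""
    total = len(content)
    if max_chars <= 0 or total <= max_chars:
        return content
    template = "\n[... {n} chars omitted ...]\n"
    # Reserve room for the widest possible marker (n can never exceed total).
    overhead = len(template.format(n=total))
    if overhead >= max_chars:
        return content[:max_chars]
    budget = max_chars - overhead
    head_len = min(head_chars, budget)
    tail_len = min(tail_chars, budget - head_len)
    head = content[:head_len]
    tail = content[total - tail_len:] if tail_len > 0 else ""
    omitted = total - len(head) - len(tail)
    return head + template.format(n=omitted) + tail


sample = "x" * 1000
truncated = head_tail_truncate(sample, max_chars=200, head_chars=100, tail_chars=50)
```

The marker overhead is computed from the total length before the real omitted count is known, so the final marker can only be the same width or narrower than the space reserved for it.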
# ---------------------------------------------------------------------------
# Core budget logic
# ---------------------------------------------------------------------------
def _resolve_outputs_path(request: ToolCallRequest) -> str | None:
"""Best-effort extraction of the thread outputs path."""
runtime = getattr(request, "runtime", None)
if runtime is None:
return None
state = getattr(runtime, "state", None)
if state is None:
return None
thread_data = state.get("thread_data")
if not isinstance(thread_data, dict):
return None
outputs_path = thread_data.get("outputs_path")
return outputs_path if isinstance(outputs_path, str) else None
def _budget_content(
content: str,
*,
tool_name: str,
tool_call_id: str,
outputs_path: str | None,
config: ToolOutputConfig,
) -> str | None:
"""Apply budget to *content*. Returns ``None`` if no change needed."""
threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars)
if threshold <= 0 and config.fallback_max_chars <= 0:
return None
if len(content) <= threshold and len(content) <= config.fallback_max_chars:
return None
if threshold > 0 and len(content) > threshold and outputs_path:
virtual_path = _externalize(
content,
tool_name=tool_name,
tool_call_id=tool_call_id,
outputs_path=outputs_path,
storage_subdir=config.storage_subdir,
)
if virtual_path is not None:
logger.info(
"Externalized %s output (%d chars) to %s",
tool_name,
len(content),
virtual_path,
)
return _build_preview(
content,
tool_name=tool_name,
virtual_path=virtual_path,
head_chars=config.preview_head_chars,
tail_chars=config.preview_tail_chars,
)
if config.fallback_max_chars > 0 and len(content) > config.fallback_max_chars:
logger.warning(
"Fallback-truncating %s output: %d chars → %d max",
tool_name,
len(content),
config.fallback_max_chars,
)
return _build_fallback(
content,
tool_name=tool_name,
max_chars=config.fallback_max_chars,
head_chars=config.fallback_head_chars,
tail_chars=config.fallback_tail_chars,
)
return None
# ---------------------------------------------------------------------------
# Result patchers
# ---------------------------------------------------------------------------
def _patch_tool_message(msg: ToolMessage, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage:
"""Apply budget to a single ToolMessage. Returns the original if unchanged."""
tool_name = msg.name or "unknown"
if tool_name in config.exempt_tools:
return msg
text = _message_text(msg.content)
if text is None:
return msg
replacement = _budget_content(
text,
tool_name=tool_name,
tool_call_id=msg.tool_call_id or "",
outputs_path=outputs_path,
config=config,
)
if replacement is None:
return msg
update: dict[str, Any] = {"content": replacement}
if getattr(msg, "response_metadata", None):
update["response_metadata"] = dict(msg.response_metadata)
if getattr(msg, "additional_kwargs", None):
update["additional_kwargs"] = dict(msg.additional_kwargs)
return msg.model_copy(update=update)
def _effective_trigger(tool_name: str, config: ToolOutputConfig) -> int:
"""Smallest content length that could trigger budgeting for *tool_name*.
Mirrors the trigger conditions in :func:`_budget_content` (per-tool
externalize threshold OR global fallback), so the pre-scan never produces
a false negative. Returns ``-1`` when nothing could ever trigger.
"""
candidates: list[int] = []
externalize = config.tool_overrides.get(tool_name, config.externalize_min_chars)
if externalize > 0:
candidates.append(externalize)
if config.fallback_max_chars > 0:
candidates.append(config.fallback_max_chars)
return min(candidates) if candidates else -1
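How the two thresholds combine can be checked with a small standalone sketch. `BudgetConfig` is a hypothetical dataclass stand-in for `ToolOutputConfig`, carrying only the three fields read here, with the defaults shown in this diff.

```python
from dataclasses import dataclass, field


@dataclass
class BudgetConfig:
    """Hypothetical stand-in for ToolOutputConfig (illustration only)."""

    externalize_min_chars: int = 12_000
    fallback_max_chars: int = 30_000
    tool_overrides: dict = field(default_factory=dict)


def effective_trigger(tool_name: str, config: BudgetConfig) -> int:
    """Smallest length that could trigger budgeting, or -1 if nothing can."""
    candidates = []
    externalize = config.tool_overrides.get(tool_name, config.externalize_min_chars)
    if externalize > 0:
        candidates.append(externalize)
    if config.fallback_max_chars > 0:
        candidates.append(config.fallback_max_chars)
    return min(candidates) if candidates else -1


default_trigger = effective_trigger("bash", BudgetConfig())
override_off = effective_trigger("bash", BudgetConfig(tool_overrides={"bash": 0}))
override_high = effective_trigger("web_fetch", BudgetConfig(tool_overrides={"web_fetch": 50_000}))
never = effective_trigger("bash", BudgetConfig(externalize_min_chars=0, fallback_max_chars=0))
```

A per-tool override of 0 disables externalization for that tool only, so the global fallback limit becomes the trigger. An override above the fallback limit is capped by it in the same way.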
def _tool_message_over_budget(msg: ToolMessage, config: ToolOutputConfig) -> bool:
"""Cheap, per-tool-aware check: is this ToolMessage non-exempt and over its trigger?"""
if (msg.name or "") in config.exempt_tools:
return False
trigger = _effective_trigger(msg.name or "", config)
if trigger < 0:
return False
text = _message_text(msg.content)
return text is not None and len(text) > trigger
def _needs_budget(result: ToolMessage | Command, config: ToolOutputConfig) -> bool:
"""Fast check whether *result* could need budgeting (avoids thread offload for small outputs)."""
if isinstance(result, ToolMessage):
return _tool_message_over_budget(result, config)
update = getattr(result, "update", None)
if isinstance(update, dict):
for msg in update.get("messages", []):
if isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config):
return True
return False
def _patch_result(result: ToolMessage | Command, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage | Command:
"""Apply budget to a tool call result (ToolMessage or Command)."""
if isinstance(result, ToolMessage):
return _patch_tool_message(result, config, outputs_path)
update = getattr(result, "update", None)
if not isinstance(update, dict):
return result
messages = update.get("messages")
if not isinstance(messages, list):
return result
new_messages: list[Any] = []
changed = False
for msg in messages:
if isinstance(msg, ToolMessage):
patched = _patch_tool_message(msg, config, outputs_path)
if patched is not msg:
changed = True
new_messages.append(patched)
else:
new_messages.append(msg)
if not changed:
return result
return dc_replace(result, update={**update, "messages": new_messages})
def _patch_model_messages(messages: list[Any], config: ToolOutputConfig) -> list[Any] | None:
"""Apply budget to historical ToolMessages in a model request. Returns ``None`` if unchanged.
A cheap pre-scan bails out before allocating a new list when no historical
ToolMessage exceeds the budget. This is the common case once every result has
already been budgeted at tool-call time, so a long history is not rebuilt
on every model call.
"""
if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages):
return None
updated: list[Any] = []
changed = False
for msg in messages:
if isinstance(msg, ToolMessage):
patched = _patch_tool_message(msg, config, outputs_path=None)
if patched is not msg:
changed = True
updated.append(patched)
else:
updated.append(msg)
return updated if changed else None
# ---------------------------------------------------------------------------
# Middleware class
# ---------------------------------------------------------------------------
class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
"""Enforce per-result budget on tool outputs via externalization or truncation."""
def __init__(self, config: ToolOutputConfig | None = None) -> None:
super().__init__()
self._config = config if config is not None else _default_config()
@classmethod
def from_app_config(cls, app_config: Any) -> ToolOutputBudgetMiddleware:
tool_output = getattr(app_config, "tool_output", None)
if isinstance(tool_output, ToolOutputConfig):
return cls(config=tool_output)
return cls()
# -- tool call hooks ---------------------------------------------------
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
result = handler(request)
if not self._config.enabled:
return result
if not _needs_budget(result, self._config):
return result
outputs_path = _resolve_outputs_path(request)
return _patch_result(result, self._config, outputs_path)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
result = await handler(request)
if not self._config.enabled:
return result
if not _needs_budget(result, self._config):
return result
outputs_path = _resolve_outputs_path(request)
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path)
# -- model call hooks (historical message truncation) ------------------
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
if self._config.enabled:
messages = getattr(request, "messages", None)
if isinstance(messages, list):
patched = _patch_model_messages(messages, self._config)
if patched is not None:
request = request.override(messages=patched)
return handler(request)
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
if self._config.enabled:
messages = getattr(request, "messages", None)
if isinstance(messages, list):
patched = _patch_model_messages(messages, self._config)
if patched is not None:
request = request.override(messages=patched)
return await handler(request)
@@ -7,6 +7,7 @@ from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langchain_core.runnables import run_in_executor
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
@@ -293,3 +294,16 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
"uploaded_files": new_files,
"messages": messages,
}
@override
async def abefore_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
"""Async hook that offloads the synchronous uploads scan off the event loop.
``before_agent`` performs blocking filesystem IO (directory enumeration,
``stat``, reading sibling ``.md`` outlines). When the graph runs async,
langgraph would otherwise execute the sync hook directly on the event
loop, so it is dispatched to a worker thread via ``run_in_executor``.
``run_in_executor`` copies the current context, so the ``user_id``
contextvar read by ``get_effective_user_id()`` is preserved.
"""
return await run_in_executor(None, self.before_agent, state, runtime)
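The context-preservation point in the docstring can be illustrated with the standard library alone. This sketch uses `asyncio.to_thread`, which also copies the current context, as a stand-in for `langchain_core`'s `run_in_executor`. The variable and function names are hypothetical.

```python
import asyncio
import contextvars
import threading

# Hypothetical contextvar standing in for the one read by get_effective_user_id().
user_id_var = contextvars.ContextVar("user_id", default=None)


def blocking_scan() -> tuple:
    """Pretend filesystem scan; reports the user id it sees and whether it ran on the main thread."""
    return user_id_var.get(), threading.current_thread() is threading.main_thread()


async def main() -> tuple:
    user_id_var.set("user-123")
    # to_thread copies the current context, so the worker thread sees the same user id.
    return await asyncio.to_thread(blocking_scan)


seen_user_id, ran_on_main_thread = asyncio.run(main())
```

The scan runs on a worker thread, so the event loop stays free, while the value set before the offload is still visible inside it.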
@@ -119,7 +119,6 @@ class AioSandboxProvider(SandboxProvider):
port: 8080 # Base port for local containers
container_prefix: deer-flow-sandbox
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
auto_restart: true # Restart crashed containers automatically
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
mounts: # Volume mounts for local containers
- host_path: /path/on/host
@@ -204,14 +203,12 @@ class AioSandboxProvider(SandboxProvider):
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
auto_restart = getattr(sandbox_config, "auto_restart", True)
return {
"image": sandbox_config.image or DEFAULT_IMAGE,
"port": sandbox_config.port or DEFAULT_PORT,
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
"auto_restart": auto_restart,
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
"mounts": sandbox_config.mounts or [],
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
@@ -774,57 +771,17 @@ class AioSandboxProvider(SandboxProvider):
def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox by ID. Updates last activity timestamp.
When ``auto_restart`` is enabled (the default), the container's liveness
is verified on each lookup. If the underlying container has crashed, the
sandbox is evicted from all caches so that the next ``acquire()`` call will
transparently create a fresh container.
Args:
sandbox_id: The ID of the sandbox.
Returns:
The sandbox instance if found and alive, None otherwise.
The sandbox instance if found, None otherwise.
"""
with self._lock:
sandbox = self._sandboxes.get(sandbox_id)
if sandbox is None:
return None
self._last_activity[sandbox_id] = time.time()
auto_restart = self._config.get("auto_restart", True)
info = self._sandbox_infos.get(sandbox_id) if auto_restart else None
if not info:
return sandbox
if self._backend.is_alive(info):
return sandbox
info_to_destroy = None
with self._lock:
current_sandbox = self._sandboxes.get(sandbox_id)
current_info = self._sandbox_infos.get(sandbox_id)
if current_sandbox is None:
return None
if current_info is not info:
if sandbox is not None:
self._last_activity[sandbox_id] = time.time()
return current_sandbox
logger.warning(f"Sandbox {sandbox_id} container is not alive, evicting from cache for auto-restart")
self._sandboxes.pop(sandbox_id, None)
self._sandbox_infos.pop(sandbox_id, None)
self._last_activity.pop(sandbox_id, None)
self._warm_pool.pop(sandbox_id, None)
thread_ids = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids:
del self._thread_sandboxes[tid]
info_to_destroy = info
if info_to_destroy:
try:
self._backend.destroy(info_to_destroy)
except Exception as e:
logger.warning(f"Failed to cleanup dead sandbox {sandbox_id}: {e}")
return None
return sandbox
def release(self, sandbox_id: str) -> None:
"""Release a sandbox from active use into the warm pool.
@@ -30,6 +30,7 @@ from deerflow.config.summarization_config import SummarizationConfig, load_summa
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_output_config import ToolOutputConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
load_dotenv()
@@ -93,6 +94,7 @@ class AppConfig(BaseModel):
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
skill_evolution: SkillEvolutionConfig = Field(default_factory=SkillEvolutionConfig, description="Agent-managed skill evolution configuration")
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
tool_output: ToolOutputConfig = Field(default_factory=ToolOutputConfig, description="Tool output budget protection configuration")
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
@@ -23,9 +23,6 @@ class SandboxConfig(BaseModel):
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
container_prefix: Prefix for container names (default: deer-flow-sandbox)
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
auto_restart: Automatically restart sandbox containers that have crashed (default: true). When a tool call
detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated
on the next acquire. Set to false to disable.
mounts: List of volume mounts to share directories with the container
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
"""
@@ -58,10 +55,6 @@ class SandboxConfig(BaseModel):
default=None,
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
)
auto_restart: bool = Field(
default=True,
description="Automatically restart sandbox containers that have crashed. When a tool call detects the container is no longer alive, the sandbox is evicted from cache and transparently recreated on the next acquire.",
)
mounts: list[VolumeMountConfig] = Field(
default_factory=list,
description="List of volume mounts to share directories between host and container",
@@ -0,0 +1,62 @@
"""Configuration for tool output budget protection."""
from __future__ import annotations
from pydantic import BaseModel, Field
class ToolOutputConfig(BaseModel):
"""Config section for tool-result output budget enforcement.
When a tool returns more than ``externalize_min_chars`` characters,
the full output is persisted to disk and replaced with a compact
preview + file reference. If disk persistence is unavailable the
output falls back to head+tail truncation.
"""
enabled: bool = Field(
default=True,
description="Enable the tool output budget middleware.",
)
externalize_min_chars: int = Field(
default=12_000,
ge=0,
description="Character threshold to trigger disk externalization. Outputs below this pass through unchanged. Set to 0 to disable externalization (fallback truncation still applies when output exceeds fallback_max_chars).",
)
preview_head_chars: int = Field(
default=2_000,
ge=0,
description="Characters to keep from the head of the output in the preview.",
)
preview_tail_chars: int = Field(
default=1_000,
ge=0,
description="Characters to keep from the tail of the output in the preview.",
)
fallback_max_chars: int = Field(
default=30_000,
ge=0,
description="Maximum characters when disk persistence is unavailable. 0 disables fallback truncation.",
)
fallback_head_chars: int = Field(
default=8_000,
ge=0,
description="Head characters for fallback truncation.",
)
fallback_tail_chars: int = Field(
default=3_000,
ge=0,
description="Tail characters for fallback truncation.",
)
storage_subdir: str = Field(
default=".tool-results",
description="Subdirectory under the thread outputs path for persisted tool results.",
)
exempt_tools: list[str] = Field(
default_factory=lambda: ["read_file", "read_file_tool"],
description="Tool names exempt from budget enforcement (prevents persist→read→persist loops).",
)
tool_overrides: dict[str, int] = Field(
default_factory=dict,
description="Per-tool externalize_min_chars overrides. Keys are tool names, values are char thresholds. Use 0 to disable externalization for a specific tool.",
)
@@ -87,8 +87,7 @@ def get_cached_mcp_tools() -> list[BaseTool]:
Also checks if the config file has been modified since last initialization,
and re-initializes if needed. This ensures that changes made through the
Gateway API (which runs in a separate process) are reflected in the
LangGraph Server.
Gateway API are reflected in the Gateway-embedded LangGraph runtime.
Returns:
List of cached MCP tools.
@@ -1,4 +1,4 @@
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
"""Load MCP tools using langchain-mcp-adapters with stdio session pooling."""
from __future__ import annotations
@@ -173,8 +173,10 @@ def _make_session_pool_tool(
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
Tools are wrapped with persistent-session logic so that consecutive
calls within the same thread reuse the same MCP session.
Tools using stdio transport are wrapped with persistent-session logic so
consecutive calls within the same thread reuse the same MCP session.
HTTP/SSE tools are returned unwrapped to avoid cross-task TaskGroup
cleanup errors.
Returns:
List of LangChain tools from all enabled MCP servers.
@@ -251,6 +253,9 @@ async def get_mcp_tools() -> list[BaseTool]:
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Wrap each tool with persistent-session logic.
# Only pool stdio sessions. HTTP/SSE transports use anyio TaskGroups
# internally which cannot be closed from a different async task, so
# pooling them causes RuntimeError on cleanup (see #3203).
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
@@ -260,7 +265,11 @@ async def get_mcp_tools() -> list[BaseTool]:
break
if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
transport = servers_config[tool_server].get("transport", "stdio")
if transport == "stdio":
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
else:
wrapped_tools.append(tool)
else:
wrapped_tools.append(tool)
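The transport gate added here reduces to a one-line predicate. In the sketch below, `should_pool_session`, the server names, and the config values are all hypothetical. Only the `"transport"` key and its `"stdio"` default come from the diff.

```python
from typing import Any


def should_pool_session(server_config: dict[str, Any]) -> bool:
    """Pool only stdio servers; a missing transport defaults to stdio, as in the diff."""
    return server_config.get("transport", "stdio") == "stdio"


# Made-up server entries for illustration.
servers_config = {
    "local-files": {"command": "my-mcp-server"},
    "remote-search": {"transport": "sse", "url": "http://localhost:9000/sse"},
    "remote-api": {"transport": "http", "url": "http://localhost:9001/mcp"},
}

pooled = sorted(name for name, cfg in servers_config.items() if should_pool_session(cfg))
unpooled = sorted(name for name, cfg in servers_config.items() if not should_pool_session(cfg))
```

Any transport other than stdio falls through to the unwrapped path, so a new transport type would stay unpooled by default.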
@@ -0,0 +1,124 @@
"""Helpers for replaying provider-specific assistant message fields.
Several provider adapters need to preserve fields that LangChain stores on the
original ``AIMessage`` but drops when serializing request payloads. This module
keeps the assistant-message matching logic shared while letting each provider
decide which fields to restore.
"""
from __future__ import annotations
import json
from collections.abc import Callable, Sequence
from typing import Any
from langchain_core.messages import AIMessage, BaseMessage
AssistantPayloadRestorer = Callable[[dict[str, Any], AIMessage], None]
def restore_assistant_payloads(
payload_messages: Sequence[dict[str, Any]],
original_messages: Sequence[BaseMessage],
restore: AssistantPayloadRestorer,
) -> None:
"""Restore provider-specific fields onto serialized assistant payloads."""
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
restore(payload_msg, orig_msg)
return
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
used_ai_indexes: set[int] = set()
for ordinal, payload_msg in enumerate(assistant_payloads):
ai_msg = _match_ai_message(payload_msg, ai_messages, used_ai_indexes, ordinal)
if ai_msg is not None:
restore(payload_msg, ai_msg)
def restore_additional_kwargs_field(payload_msg: dict[str, Any], orig_msg: AIMessage, field_name: str) -> None:
"""Copy a provider-specific ``additional_kwargs`` field onto a payload message."""
value = orig_msg.additional_kwargs.get(field_name)
if value is not None:
payload_msg[field_name] = value
def restore_reasoning_content(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
"""Copy provider reasoning content onto a serialized assistant payload."""
restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
def _match_ai_message(
payload_msg: dict[str, Any],
ai_messages: Sequence[AIMessage],
used_ai_indexes: set[int],
fallback_ordinal: int,
) -> AIMessage | None:
payload_key = _assistant_signature(payload_msg)
if payload_key is not None:
matches = [index for index, ai_msg in enumerate(ai_messages) if index not in used_ai_indexes and _ai_signature(ai_msg) == payload_key]
if len(matches) == 1:
used_ai_indexes.add(matches[0])
return ai_messages[matches[0]]
fallback_index = _next_unused_index_at_or_after(len(ai_messages), used_ai_indexes, fallback_ordinal)
if fallback_index is not None:
used_ai_indexes.add(fallback_index)
return ai_messages[fallback_index]
return None
def _next_unused_index_at_or_after(count: int, used_ai_indexes: set[int], start: int) -> int | None:
"""Return the next unused AI index at or after ``start``.
Scanning forward from the payload's ordinal preserves the positional bias of
the previous behaviour while still recovering when serialization drops or
reorders messages so the exact ordinal index is already taken. It does not
wrap to earlier indexes because those messages may be represented by payload
entries that were already dropped.
"""
if count == 0 or start >= count:
return None
for index in range(start, count):
if index not in used_ai_indexes:
return index
return None
def _assistant_signature(payload_msg: dict[str, Any]) -> tuple[str, str] | None:
return _signature(
payload_msg.get("content"),
_tool_call_ids(payload_msg.get("tool_calls") or []),
)
def _ai_signature(message: AIMessage) -> tuple[str, str] | None:
tool_calls = message.tool_calls or message.additional_kwargs.get("tool_calls") or []
return _signature(message.content, _tool_call_ids(tool_calls))
def _signature(content: Any, tool_call_ids: tuple[str, ...]) -> tuple[str, str] | None:
if content in (None, "") and not tool_call_ids:
return None
return (_stable_repr(content), "|".join(tool_call_ids))
def _stable_repr(value: Any) -> str:
try:
return json.dumps(value, sort_keys=True, ensure_ascii=False)
except TypeError:
return repr(value)
def _tool_call_ids(tool_calls: Sequence[Any]) -> tuple[str, ...]:
ids: list[str] = []
for tool_call in tool_calls:
if isinstance(tool_call, dict):
call_id = tool_call.get("id")
if isinstance(call_id, str) and call_id:
ids.append(call_id)
return tuple(ids)
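The forward-only fallback scan in this module can be illustrated in isolation. The sketch below is a simplified stand-in, not the module's own code: it works on bare indexes rather than `AIMessage` objects, and the function name without the leading underscore is an assumption made for the example.

```python
from __future__ import annotations


def next_unused_index_at_or_after(count: int, used: set[int], start: int) -> int | None:
    """Return the first index in [start, count) that is not in ``used``."""
    for index in range(start, count):
        if index not in used:
            return index
    return None


# Three original assistant messages; index 1 was already claimed by a unique
# signature match for an earlier payload entry.
used: set[int] = {1}

# The payload with ordinal 1 finds its slot taken and moves forward to index 2.
second = next_unused_index_at_or_after(3, used, 1)
used.add(second)

# A later payload with ordinal 2 has nothing left at or after its position.
# Index 0 is still free, but the scan never wraps backwards to it.
third = next_unused_index_at_or_after(3, used, 2)

print(second, third)
```

The second lookup returns `None` instead of reusing index 0, which mirrors the docstring's reasoning that earlier messages may belong to payload entries that were dropped.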
@@ -10,9 +10,10 @@ on all assistant messages when thinking mode is enabled.
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_deepseek import ChatDeepSeek
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
class PatchedChatDeepSeek(ChatDeepSeek):
"""ChatDeepSeek with proper reasoning_content preservation.
@@ -49,25 +50,10 @@ class PatchedChatDeepSeek(ChatDeepSeek):
# Call parent to get the base payload
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
# Match payload messages with original messages to restore reasoning_content
payload_messages = payload.get("messages", [])
# The payload messages and original messages should be in the same order
# Iterate through both and match by position
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning_content is not None:
payload_msg["reasoning_content"] = reasoning_content
else:
# Fallback: match by counting assistant messages
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
if reasoning_content is not None:
payload_messages[idx]["reasoning_content"] = reasoning_content
restore_assistant_payloads(
payload.get("messages", []),
original_messages,
restore_reasoning_content,
)
return payload
@@ -0,0 +1,140 @@
"""Patched ChatOpenAI adapter for Xiaomi MiMo reasoning_content replay.
MiMo's OpenAI-compatible API returns ``reasoning_content`` in thinking mode and
requires that value to be replayed on historical assistant messages in
multi-turn agent conversations. Standard ``langchain_openai.ChatOpenAI`` drops
that provider-specific field, which can cause HTTP 400 errors once tool calls
enter the conversation history.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
_MISSING = object()
def _extract_reasoning_content(value: Any) -> str | object:
"""Return ``reasoning_content`` from a mapping or Pydantic object, or ``_MISSING`` when it is absent or ``None``; empty strings are preserved."""
if isinstance(value, Mapping):
if "reasoning_content" in value and value["reasoning_content"] is not None:
return value["reasoning_content"]
return _MISSING
reasoning = getattr(value, "reasoning_content", _MISSING)
if reasoning is not _MISSING and reasoning is not None:
return reasoning
model_extra = getattr(value, "model_extra", None)
if isinstance(model_extra, Mapping) and "reasoning_content" in model_extra and model_extra["reasoning_content"] is not None:
return model_extra["reasoning_content"]
return _MISSING
def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str) -> AIMessage | AIMessageChunk:
additional_kwargs = dict(message.additional_kwargs)
if additional_kwargs.get("reasoning_content") != reasoning:
additional_kwargs["reasoning_content"] = reasoning
return message.model_copy(update={"additional_kwargs": additional_kwargs})
def _get_typed_choice_message(response: Any, index: int) -> Any:
choices = getattr(response, "choices", None)
if choices is None:
return None
try:
return choices[index].message
except (AttributeError, IndexError, TypeError):
return None
class PatchedChatMiMo(ChatOpenAI):
"""ChatOpenAI with ``reasoning_content`` preservation for MiMo thinking mode."""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_secrets(self) -> dict[str, str]:
return {"api_key": "MIMO_API_KEY", "openai_api_key": "MIMO_API_KEY"}
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
original_messages = self._convert_input(input_).to_messages()
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
restore_assistant_payloads(
payload.get("messages", []),
original_messages,
restore_reasoning_content,
)
return payload
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: type,
base_generation_info: dict | None,
) -> ChatGenerationChunk | None:
generation_chunk = super()._convert_chunk_to_generation_chunk(
chunk,
default_chunk_class,
base_generation_info,
)
if generation_chunk is None:
return None
choices = chunk.get("choices", [])
if choices:
delta = choices[0].get("delta") or {}
reasoning = _extract_reasoning_content(delta)
if reasoning is not _MISSING and isinstance(generation_chunk.message, AIMessageChunk):
generation_chunk = ChatGenerationChunk(
message=_with_reasoning_content(generation_chunk.message, reasoning),
generation_info=generation_chunk.generation_info,
)
return generation_chunk
def _create_chat_result(
self,
response: dict | Any,
generation_info: dict | None = None,
) -> ChatResult:
result = super()._create_chat_result(response, generation_info)
response_dict = response if isinstance(response, dict) else response.model_dump()
choices = response_dict.get("choices", [])
patched_generations: list[ChatGeneration] | None = None
for index, generation in enumerate(result.generations):
choice = choices[index] if index < len(choices) else {}
choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
reasoning = _extract_reasoning_content(choice_message)
if reasoning is _MISSING and not isinstance(response, dict):
reasoning = _extract_reasoning_content(_get_typed_choice_message(response, index))
message = generation.message
if reasoning is not _MISSING and isinstance(message, AIMessage):
if patched_generations is None:
patched_generations = list(result.generations)
patched_generations[index] = ChatGeneration(
message=_with_reasoning_content(message, reasoning),
generation_info=generation.generation_info,
)
return ChatResult(generations=patched_generations or result.generations, llm_output=result.llm_output)
@@ -27,6 +27,8 @@ from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from deerflow.models.assistant_payload_replay import restore_assistant_payloads
class PatchedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
@@ -75,18 +77,7 @@ class PatchedChatOpenAI(ChatOpenAI):
# Obtain the base payload from the parent implementation.
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
payload_messages = payload.get("messages", [])
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
_restore_tool_call_signatures(payload_msg, orig_msg)
else:
# Fallback: match assistant-role entries positionally against AIMessages.
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
_restore_tool_call_signatures(payload_msg, ai_msg)
restore_assistant_payloads(payload.get("messages", []), original_messages, _restore_tool_call_signatures)
return payload
@@ -144,10 +144,13 @@ class DbRunEventStore(RunEventStore):
async def put_batch(self, events):
if not events:
return []
thread_ids = {e["thread_id"] for e in events}
if len(thread_ids) > 1:
raise ValueError(f"put_batch requires all events to belong to the same thread; got {thread_ids!r}")
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
# All events belong to the same thread (validated above).
thread_id = events[0]["thread_id"]
max_seq = await self._max_seq_for_thread(session, thread_id)
seq = max_seq or 0
@@ -6,6 +6,15 @@ Each run's events are stored in a single file:
All categories (message, trace, lifecycle) are in the same file.
This backend is suitable for lightweight single-node deployments.
**Single-process only**: the in-memory seq counter is process-local.
Multi-process deployments sharing the same directory will produce duplicate
or non-monotonic seq values. Use ``DbRunEventStore`` for multi-process or
high-concurrency deployments.
File I/O is offloaded to worker threads via ``asyncio.to_thread`` so it does
not block the event loop. One ``asyncio.Lock`` per ``thread_id`` serialises
writes within a single process to prevent interleaved JSONL lines.
Known trade-off: ``list_messages()`` must scan all run files for a
thread since messages from multiple runs need unified seq ordering.
``list_events()`` reads only one file -- the fast path.
@@ -13,6 +22,7 @@ thread since messages from multiple runs need unified seq ordering.
from __future__ import annotations
import asyncio
import json
import logging
import re
@@ -30,6 +40,11 @@ class JsonlRunEventStore(RunEventStore):
def __init__(self, base_dir: str | Path | None = None):
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
# One asyncio.Lock per thread_id; serialises concurrent writes within one process.
self._write_locks: dict[str, asyncio.Lock] = {}
def _get_write_lock(self, thread_id: str) -> asyncio.Lock:
return self._write_locks.setdefault(thread_id, asyncio.Lock())
@staticmethod
def _validate_id(value: str, label: str) -> str:
@@ -50,10 +65,8 @@ class JsonlRunEventStore(RunEventStore):
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
return self._seq_counters[thread_id]
def _ensure_seq_loaded(self, thread_id: str) -> None:
"""Load max seq from existing files if not yet cached."""
if thread_id in self._seq_counters:
return
def _compute_max_seq(self, thread_id: str) -> int:
"""Scan all run files for a thread and return the current max seq (blocking I/O)."""
max_seq = 0
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
@@ -64,7 +77,13 @@ class JsonlRunEventStore(RunEventStore):
max_seq = max(max_seq, record.get("seq", 0))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue
return max_seq
async def _ensure_seq_loaded(self, thread_id: str) -> None:
"""Load max seq from existing files into the in-memory counter (non-blocking)."""
if thread_id in self._seq_counters:
return
max_seq = await asyncio.to_thread(self._compute_max_seq, thread_id)
self._seq_counters[thread_id] = max_seq
def _write_record(self, record: dict) -> None:
@@ -74,7 +93,7 @@ class JsonlRunEventStore(RunEventStore):
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
def _read_thread_events(self, thread_id: str) -> list[dict]:
"""Read all events for a thread, sorted by seq."""
"""Read all events for a thread, sorted by seq (blocking I/O)."""
events = []
thread_dir = self._thread_dir(thread_id)
if not thread_dir.exists():
@@ -87,12 +106,11 @@ class JsonlRunEventStore(RunEventStore):
events.append(json.loads(line))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
"""Read events for a specific run file."""
"""Read events for a specific run file (blocking I/O)."""
path = self._run_file(thread_id, run_id)
if not path.exists():
return []
@@ -104,25 +122,36 @@ class JsonlRunEventStore(RunEventStore):
events.append(json.loads(line))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", path)
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
def _delete_thread_files(self, thread_id: str) -> None:
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
f.unlink()
def _delete_run_file(self, thread_id: str, run_id: str) -> None:
path = self._run_file(thread_id, run_id)
if path.exists():
path.unlink()
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
self._ensure_seq_loaded(thread_id)
seq = self._next_seq(thread_id)
record = {
"thread_id": thread_id,
"run_id": run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"seq": seq,
"created_at": created_at or datetime.now(UTC).isoformat(),
}
self._write_record(record)
return record
async with self._get_write_lock(thread_id):
await self._ensure_seq_loaded(thread_id)
seq = self._next_seq(thread_id)
record = {
"thread_id": thread_id,
"run_id": run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"seq": seq,
"created_at": created_at or datetime.now(UTC).isoformat(),
}
await asyncio.to_thread(self._write_record, record)
return record
async def put_batch(self, events):
if not events:
@@ -134,7 +163,7 @@ class JsonlRunEventStore(RunEventStore):
return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._read_thread_events(thread_id)
all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
messages = [e for e in all_events if e.get("category") == "message"]
if before_seq is not None:
@@ -147,13 +176,13 @@ class JsonlRunEventStore(RunEventStore):
return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
events = self._read_run_events(thread_id, run_id)
events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
if event_types is not None:
events = [e for e in events if e.get("event_type") in event_types]
return events[:limit]
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
events = self._read_run_events(thread_id, run_id)
events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
filtered = [e for e in events if e.get("category") == "message"]
if before_seq is not None:
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
@@ -165,23 +194,25 @@ class JsonlRunEventStore(RunEventStore):
return filtered[-limit:] if len(filtered) > limit else filtered
async def count_messages(self, thread_id):
all_events = self._read_thread_events(thread_id)
all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
return sum(1 for e in all_events if e.get("category") == "message")
async def delete_by_thread(self, thread_id):
all_events = self._read_thread_events(thread_id)
count = len(all_events)
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
f.unlink()
self._seq_counters.pop(thread_id, None)
return count
async with self._get_write_lock(thread_id):
all_events = await asyncio.to_thread(self._read_thread_events, thread_id)
count = len(all_events)
await asyncio.to_thread(self._delete_thread_files, thread_id)
self._seq_counters.pop(thread_id, None)
# Drop the lock entry while still holding the lock, so later callers create a fresh one.
# Known narrow race: a coroutine that is already waiting on the old lock object
# acquires it once we release and may then run alongside a caller holding the
# fresh lock. This is an accepted trade-off.
self._write_locks.pop(thread_id, None)
return count
async def delete_by_run(self, thread_id, run_id):
events = self._read_run_events(thread_id, run_id)
count = len(events)
path = self._run_file(thread_id, run_id)
if path.exists():
path.unlink()
return count
async with self._get_write_lock(thread_id):
events = await asyncio.to_thread(self._read_run_events, thread_id, run_id)
count = len(events)
await asyncio.to_thread(self._delete_run_file, thread_id, run_id)
return count
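The combination used by this store, one `asyncio.Lock` per `thread_id` plus `asyncio.to_thread` for the blocking write, can be sketched on its own. `TinyJsonlStore` below is a hypothetical, trimmed-down class for illustration and does not mirror the `JsonlRunEventStore` interface or file layout.

```python
import asyncio
import json
import tempfile
from pathlib import Path


class TinyJsonlStore:
    """Hypothetical minimal store; assumes one JSONL file per thread_id."""

    def __init__(self, base_dir: Path) -> None:
        self._base_dir = base_dir
        self._seq: dict[str, int] = {}
        self._locks: dict[str, asyncio.Lock] = {}

    def _lock_for(self, thread_id: str) -> asyncio.Lock:
        # No await between lookup and insert, so this cannot interleave
        # with another coroutine on the same event loop.
        return self._locks.setdefault(thread_id, asyncio.Lock())

    def _append(self, thread_id: str, record: dict) -> None:
        # Blocking file I/O; only called through asyncio.to_thread.
        path = self._base_dir / f"{thread_id}.jsonl"
        with path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

    async def put(self, thread_id: str, content: str) -> dict:
        async with self._lock_for(thread_id):
            # seq is assigned and written under the same lock, so the order
            # of lines in the file matches the order of seq values.
            seq = self._seq.get(thread_id, 0) + 1
            self._seq[thread_id] = seq
            record = {"thread_id": thread_id, "seq": seq, "content": content}
            await asyncio.to_thread(self._append, thread_id, record)
            return record


async def demo() -> list[int]:
    with tempfile.TemporaryDirectory() as tmp:
        store = TinyJsonlStore(Path(tmp))
        await asyncio.gather(*(store.put("t1", f"msg {i}") for i in range(20)))
        lines = (Path(tmp) / "t1.jsonl").read_text(encoding="utf-8").splitlines()
        return [json.loads(line)["seq"] for line in lines]


seqs = asyncio.run(demo())
```

Holding the lock across the `to_thread` call is the design choice that matters here: releasing it before the write would let a later seq reach the file first.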
@@ -1,16 +1,39 @@
"""Run lifecycle management for LangGraph Platform API compatibility."""
from .domain import (
AssistantId,
CancelAction,
DisconnectMode,
EventSeq,
InvalidRunTransition,
MultitaskStrategy,
Run,
RunId,
RunScope,
RunStatus,
ThreadId,
UserId,
)
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import RunContext, run_agent
__all__ = [
"AssistantId",
"CancelAction",
"ConflictError",
"DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskStrategy",
"Run",
"RunContext",
"RunId",
"RunManager",
"RunRecord",
"RunScope",
"RunStatus",
"ThreadId",
"UnsupportedStrategyError",
"UserId",
"run_agent",
]
@@ -0,0 +1,20 @@
"""Application-layer DTOs and services for run runtime use cases."""
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle, StoredRunEvent
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
from .services import RunsApplicationService
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"GetRunQuery",
"JoinRunStreamCommand",
"ListRunMessagesQuery",
"ListRunsQuery",
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"RunsApplicationService",
"StoredRunEvent",
]
@@ -0,0 +1,46 @@
"""Application command DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
from ..domain import AssistantId, CancelAction, DisconnectMode, MultitaskStrategy, RunId, RunScope, ThreadId
@dataclass(frozen=True)
class CreateRunCommand:
thread_id: ThreadId
assistant_id: AssistantId | None = None
input: dict[str, Any] | None = None
command: dict[str, Any] | None = None
metadata: dict[str, Any] = field(default_factory=dict)
config: dict[str, Any] = field(default_factory=dict)
context: dict[str, Any] = field(default_factory=dict)
scope: RunScope = RunScope.stateful
on_disconnect: DisconnectMode = DisconnectMode.cancel
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
stream_mode: list[str] | str | None = None
stream_subgraphs: bool = False
interrupt_before: list[str] | Literal["*"] | None = None
interrupt_after: list[str] | Literal["*"] | None = None
@dataclass(frozen=True)
class CancelRunCommand:
run_id: RunId
action: CancelAction = CancelAction.interrupt
wait: bool = False
@dataclass(frozen=True)
class JoinRunStreamCommand:
run_id: RunId
last_event_id: str | None = None
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"JoinRunStreamCommand",
]
@@ -0,0 +1,76 @@
"""Application output DTOs for run use cases."""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any
from ..domain import AssistantId, EventSeq, Run, RunId, RunStatus, ThreadId
@dataclass(frozen=True)
class RunSnapshot:
run_id: RunId
thread_id: ThreadId
assistant_id: AssistantId | None = None
status: RunStatus = RunStatus.pending
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
updated_at: str = ""
error: str | None = None
model_name: str | None = None
@classmethod
def from_run(cls, run: Run) -> RunSnapshot:
return cls(
run_id=run.run_id,
thread_id=run.thread_id,
assistant_id=run.assistant_id,
status=run.status,
metadata=dict(run.metadata),
kwargs=dict(run.kwargs),
created_at=run.created_at,
updated_at=run.updated_at,
error=run.error,
model_name=run.model_name,
)
@dataclass(frozen=True)
class RunMessageView:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class StoredRunEvent:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
category: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class RunStreamHandle:
run_id: RunId
thread_id: ThreadId
events: AsyncIterator[Any]
__all__ = [
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"StoredRunEvent",
]
@@ -0,0 +1,37 @@
"""Application query DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..domain import RunId, ThreadId, UserId
@dataclass(frozen=True)
class GetRunQuery:
run_id: RunId
thread_id: ThreadId | None = None
user_id: UserId | None = None
@dataclass(frozen=True)
class ListRunsQuery:
thread_id: ThreadId
user_id: UserId | None = None
limit: int = 100
@dataclass(frozen=True)
class ListRunMessagesQuery:
thread_id: ThreadId
run_id: RunId
limit: int = 50
before_seq: int | None = None
after_seq: int | None = None
__all__ = [
"GetRunQuery",
"ListRunMessagesQuery",
"ListRunsQuery",
]
@@ -0,0 +1,74 @@
"""Application service skeleton for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..execution import RunExecutionScheduler, RunSupervisor
from ..repositories import RunEventLog, RunRepository
from ..streams import RunStreamBroker
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
@dataclass
class RunsApplicationService:
"""Use-case orchestration boundary for run runtime operations.
PR1 only introduces the boundary and dependency shape. Existing Gateway
handlers continue to call the legacy service functions until later PRs move
behavior into this class.
"""
run_repository: RunRepository
run_event_log: RunEventLog
stream_broker: RunStreamBroker
scheduler: RunExecutionScheduler
supervisor: RunSupervisor
async def create_background(self, command: CreateRunCommand) -> RunSnapshot:
# PR1 defines the application boundary; later PRs move Gateway runtime
# behavior behind this method.
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_stream(self, command: CreateRunCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_wait(self, command: CreateRunCommand) -> RunSnapshot:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def join_stream(self, command: JoinRunStreamCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def cancel(self, command: CancelRunCommand) -> bool:
return await self.supervisor.cancel(command.run_id, action=command.action)
async def get_run(self, query: GetRunQuery) -> RunSnapshot | None:
run = await self.run_repository.get(query.run_id, user_id=query.user_id)
if run is None:
return None
if query.thread_id is not None and run.thread_id != query.thread_id:
return None
return RunSnapshot.from_run(run)
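async def list_runs(self, query: ListRunsQuery) -> list[RunSnapshot]:
# Assumes list_by_thread yields Run aggregates, as get() does above.
runs = await self.run_repository.list_by_thread(
query.thread_id,
user_id=query.user_id,
limit=query.limit,
)
return [RunSnapshot.from_run(run) for run in runs]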
async def list_run_messages(self, query: ListRunMessagesQuery) -> list[RunMessageView]:
return await self.run_event_log.list_messages_by_run(
query.thread_id,
query.run_id,
limit=query.limit,
before_seq=query.before_seq,
after_seq=query.after_seq,
)
__all__ = [
"RunsApplicationService",
]
@@ -0,0 +1,33 @@
"""Run runtime domain model."""
from .errors import InvalidRunTransition, RunDomainError
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, UserId
from .model import Run
from .policies import CancelPolicy, MultitaskDecision, MultitaskPolicy
from .value_objects import CancelAction, DisconnectMode, EventSeq, MultitaskStrategy, RunScope, RunStatus
__all__ = [
"AssistantId",
"CancelAction",
"CancelPolicy",
"DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskDecision",
"MultitaskPolicy",
"MultitaskStrategy",
"Run",
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunDomainError",
"RunEvent",
"RunFailed",
"RunId",
"RunScope",
"RunStarted",
"RunStatus",
"ThreadId",
"UserId",
]
@@ -0,0 +1,24 @@
"""Domain-level errors for run lifecycle operations."""
from __future__ import annotations
from .value_objects import RunStatus
class RunDomainError(Exception):
"""Base class for run runtime domain errors."""
class InvalidRunTransition(RunDomainError):
"""Raised when a run status transition violates lifecycle rules."""
def __init__(self, current: RunStatus, target: RunStatus) -> None:
super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
self.current = current
self.target = target
__all__ = [
"InvalidRunTransition",
"RunDomainError",
]
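Because `InvalidRunTransition` keeps both statuses as attributes, callers can react to a rejected transition without parsing the message. A small sketch follows; the `RunStatus` enum here is a stand-in with assumed members and values, since the real one lives in `value_objects` and is not shown in this hunk.

```python
from enum import Enum


class RunStatus(Enum):
    # Stand-in for the real enum; member values are assumptions.
    pending = "pending"
    running = "running"
    success = "success"


class RunDomainError(Exception):
    """Base class for run runtime domain errors."""


class InvalidRunTransition(RunDomainError):
    def __init__(self, current: RunStatus, target: RunStatus) -> None:
        super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
        self.current = current
        self.target = target


try:
    raise InvalidRunTransition(RunStatus.success, RunStatus.running)
except RunDomainError as exc:
    # Catching the base class still gives access to the structured fields.
    message = str(exc)
    pair = (exc.current, exc.target)
```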
@@ -0,0 +1,64 @@
"""Domain events emitted by the run aggregate."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .identifiers import AssistantId, RunId, ThreadId
from .value_objects import CancelAction, RunStatus
@dataclass(frozen=True)
class RunCreated:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
assistant_id: AssistantId | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class RunStarted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunCompleted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunFailed:
run_id: RunId
thread_id: ThreadId
status: RunStatus
occurred_at: str = field(default_factory=now_iso)
error: str | None = None
@dataclass(frozen=True)
class RunCancelled:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
action: CancelAction = CancelAction.interrupt
RunEvent = RunCreated | RunStarted | RunCompleted | RunFailed | RunCancelled
__all__ = [
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunEvent",
"RunFailed",
"RunStarted",
]
@@ -0,0 +1,27 @@
"""Lightweight identifiers for the run runtime domain."""
from __future__ import annotations
from typing import NewType
RunId = NewType("RunId", str)
ThreadId = NewType("ThreadId", str)
AssistantId = NewType("AssistantId", str)
UserId = NewType("UserId", str)
def require_non_empty(value: str, *, field_name: str) -> str:
"""Return the identifier with surrounding whitespace stripped, rejecting empty or whitespace-only values."""
normalized = value.strip()
if not normalized:
raise ValueError(f"{field_name} must not be empty")
return normalized
__all__ = [
"AssistantId",
"RunId",
"ThreadId",
"UserId",
"require_non_empty",
]
@@ -0,0 +1,193 @@
"""Run aggregate root and lifecycle invariants."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .errors import InvalidRunTransition
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
# Keep lifecycle transitions explicit so later application code cannot invent
# ad hoc status moves outside the aggregate.
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
RunStatus.pending: frozenset(
{
RunStatus.running,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.running: frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.success: frozenset(),
RunStatus.error: frozenset(),
RunStatus.timeout: frozenset(),
RunStatus.interrupted: frozenset(),
}
@dataclass
class Run:
"""Run aggregate root.
The aggregate owns lifecycle invariants only. Infrastructure concerns such
as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
this model.
"""
run_id: RunId
thread_id: ThreadId
status: RunStatus
assistant_id: AssistantId | None = None
scope: RunScope = RunScope.stateful
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = field(default_factory=now_iso)
updated_at: str = field(default_factory=now_iso)
error: str | None = None
model_name: str | None = None
_pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)
def __post_init__(self) -> None:
self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
if self.assistant_id is not None:
self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))
@classmethod
def create(
cls,
*,
run_id: RunId,
thread_id: ThreadId,
assistant_id: AssistantId | None = None,
scope: RunScope = RunScope.stateful,
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
metadata: dict[str, Any] | None = None,
kwargs: dict[str, Any] | None = None,
model_name: str | None = None,
created_at: str | None = None,
) -> Run:
timestamp = created_at or now_iso()
run = cls(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
scope=scope,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=timestamp,
updated_at=timestamp,
model_name=model_name,
)
run._record_event(
RunCreated(
run_id=run.run_id,
thread_id=run.thread_id,
occurred_at=timestamp,
assistant_id=run.assistant_id,
metadata=dict(run.metadata),
)
)
return run
@property
def is_terminal(self) -> bool:
return not _ALLOWED_TRANSITIONS[self.status]
def pull_events(self) -> tuple[RunEvent, ...]:
# Domain events are drained by the application layer after the aggregate
# has accepted a state change.
events = tuple(self._pending_events)
self._pending_events.clear()
return events
def mark_started(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.running, at=at)
def mark_completed(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.success, at=at)
def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.error, error=error, at=at)
def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.timeout, error=error, at=at)
def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
self._transition_to(RunStatus.interrupted, action=action, at=at)
def _transition_to(
self,
target: RunStatus,
*,
error: str | None = None,
action: CancelAction = CancelAction.interrupt,
at: str | None = None,
) -> None:
if target == self.status:
return
if target not in _ALLOWED_TRANSITIONS[self.status]:
raise InvalidRunTransition(self.status, target)
timestamp = at or now_iso()
self.status = target
self.updated_at = timestamp
if error is not None:
self.error = error
self._record_event(self._event_for_transition(target, timestamp, error=error, action=action))
def _event_for_transition(
self,
target: RunStatus,
occurred_at: str,
*,
error: str | None,
action: CancelAction,
) -> RunEvent:
# Keep event construction next to the transition rules so a new status
# cannot be added without an explicit durable event shape.
if target == RunStatus.running:
return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target == RunStatus.success:
return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target in (RunStatus.error, RunStatus.timeout):
return RunFailed(
run_id=self.run_id,
thread_id=self.thread_id,
status=target,
occurred_at=occurred_at,
error=error,
)
if target == RunStatus.interrupted:
return RunCancelled(
run_id=self.run_id,
thread_id=self.thread_id,
occurred_at=occurred_at,
action=action,
)
raise InvalidRunTransition(self.status, target)
def _record_event(self, event: RunEvent) -> None:
self._pending_events.append(event)
__all__ = [
"Run",
"RunStatus",
]
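The lifecycle rules above can be illustrated with a minimal, self-contained sketch. It is not the module's code: `Status`, `MiniRun`, and `InvalidTransition` are hypothetical stand-ins for `RunStatus`, `Run`, and `InvalidRunTransition`, trimmed to four statuses, and events are recorded as plain strings rather than `RunEvent` objects.

```python
from enum import Enum


class Status(str, Enum):
    """Hypothetical stand-in for RunStatus, trimmed for the sketch."""

    pending = "pending"
    running = "running"
    success = "success"
    error = "error"


# Same shape as _ALLOWED_TRANSITIONS: each status maps to the statuses it may
# move to. An empty frozenset marks a terminal status.
ALLOWED = {
    Status.pending: frozenset({Status.running, Status.error}),
    Status.running: frozenset({Status.success, Status.error}),
    Status.success: frozenset(),
    Status.error: frozenset(),
}


class InvalidTransition(Exception):
    """Hypothetical stand-in for InvalidRunTransition."""


class MiniRun:
    def __init__(self) -> None:
        self.status = Status.pending
        self._pending_events: list[str] = []

    @property
    def is_terminal(self) -> bool:
        return not ALLOWED[self.status]

    def transition_to(self, target: Status) -> None:
        if target == self.status:
            return  # repeating the current status is a no-op, no event recorded
        if target not in ALLOWED[self.status]:
            raise InvalidTransition(f"{self.status.value} -> {target.value}")
        self.status = target
        self._pending_events.append(target.value)

    def pull_events(self) -> tuple[str, ...]:
        # Drain: return what was recorded and leave the buffer empty.
        events = tuple(self._pending_events)
        self._pending_events.clear()
        return events
```

Deriving `is_terminal` from the transition table, as the aggregate does, keeps the set of terminal statuses from drifting away from the rules that define them.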
@@ -0,0 +1,50 @@
"""Domain policies for run concurrency and cancellation."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from enum import StrEnum
from .model import Run
from .value_objects import CancelAction, MultitaskStrategy, RunStatus
class MultitaskDecision(StrEnum):
"""Application-level decision produced by a multitask policy."""
allow = "allow"
reject = "reject"
cancel_existing = "cancel_existing"
enqueue = "enqueue"
@dataclass(frozen=True)
class MultitaskPolicy:
strategy: MultitaskStrategy = MultitaskStrategy.reject
def decide(self, active_runs: Sequence[Run]) -> MultitaskDecision:
inflight = [run for run in active_runs if run.status in (RunStatus.pending, RunStatus.running)]
if not inflight:
return MultitaskDecision.allow
if self.strategy == MultitaskStrategy.reject:
return MultitaskDecision.reject
if self.strategy in (MultitaskStrategy.interrupt, MultitaskStrategy.rollback):
return MultitaskDecision.cancel_existing
return MultitaskDecision.enqueue
@dataclass(frozen=True)
class CancelPolicy:
action: CancelAction = CancelAction.interrupt
@property
def rolls_back_checkpoint(self) -> bool:
return self.action == CancelAction.rollback
__all__ = [
"CancelPolicy",
"MultitaskDecision",
"MultitaskPolicy",
]
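The decision table in `MultitaskPolicy.decide` can be sketched as a standalone function. This is an illustration under assumptions: `Strategy`, `Decision`, and `decide` are hypothetical names, and run statuses are passed as plain strings instead of `Run` objects.

```python
from enum import Enum


class Strategy(str, Enum):
    """Hypothetical stand-in for MultitaskStrategy."""

    reject = "reject"
    interrupt = "interrupt"
    rollback = "rollback"
    enqueue = "enqueue"


class Decision(str, Enum):
    """Hypothetical stand-in for MultitaskDecision."""

    allow = "allow"
    reject = "reject"
    cancel_existing = "cancel_existing"
    enqueue = "enqueue"


def decide(strategy: Strategy, statuses: list[str]) -> Decision:
    # Only pending and running runs count as in flight.
    inflight = [s for s in statuses if s in ("pending", "running")]
    if not inflight:
        return Decision.allow
    if strategy == Strategy.reject:
        return Decision.reject
    if strategy in (Strategy.interrupt, Strategy.rollback):
        return Decision.cancel_existing
    return Decision.enqueue
```

Note that `interrupt` and `rollback` collapse into one decision here; in the diff, the difference between them is carried separately by `CancelPolicy.rolls_back_checkpoint`.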
@@ -0,0 +1,88 @@
"""Domain value objects for run lifecycle semantics."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
class RunStatus(StrEnum):
"""Lifecycle status of a single run."""
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"
class RunScope(StrEnum):
"""Conversation scope for a run."""
stateful = "stateful"
stateless = "stateless"
temporary_thread = "temporary_thread"
class MultitaskStrategy(StrEnum):
"""Concurrency strategy for a new run on a thread."""
reject = "reject"
interrupt = "interrupt"
rollback = "rollback"
enqueue = "enqueue"
class CancelAction(StrEnum):
"""Cancellation action requested by an API or supervisor."""
interrupt = "interrupt"
rollback = "rollback"
TERMINAL_RUN_STATUSES: frozenset[RunStatus] = frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
)
def is_terminal_status(status: RunStatus) -> bool:
return status in TERMINAL_RUN_STATUSES
@dataclass(frozen=True, order=True)
class EventSeq:
"""Thread-local event sequence number."""
value: int
def __post_init__(self) -> None:
if self.value < 0:
raise ValueError("EventSeq must be non-negative")
def next(self) -> EventSeq:
return EventSeq(self.value + 1)
__all__ = [
"CancelAction",
"DisconnectMode",
"EventSeq",
"MultitaskStrategy",
"RunScope",
"RunStatus",
"TERMINAL_RUN_STATUSES",
"is_terminal_status",
]
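A short sketch of how `EventSeq` behaves as a value object. The class is restated from the hunk above so the snippet runs on its own; the usage lines are illustrative and not taken from the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True, order=True)
class EventSeq:
    """Thread-local event sequence number (restated from the diff)."""

    value: int

    def __post_init__(self) -> None:
        if self.value < 0:
            raise ValueError("EventSeq must be non-negative")

    def next(self) -> "EventSeq":
        return EventSeq(self.value + 1)


first = EventSeq(0)
second = first.next()

# order=True generates comparison methods from the fields, so sequence
# numbers sort naturally; frozen=True makes next() return a new object
# instead of mutating the current one.
assert first < second
assert first.value == 0 and second.value == 1
assert sorted([second, first]) == [first, second]
```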
@@ -0,0 +1,12 @@
"""Execution contracts for run lifecycle orchestration."""
from .executor import RunExecutor
from .scheduler import RunExecutionHandle, RunExecutionScheduler
from .supervisor import RunSupervisor
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
"RunExecutor",
"RunSupervisor",
]
@@ -0,0 +1,19 @@
"""Run executor contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import Run
class RunExecutor(Protocol):
"""Executes one run against the underlying agent or graph runtime."""
async def execute(self, run: Run) -> None:
pass
__all__ = [
"RunExecutor",
]
@@ -0,0 +1,26 @@
"""Run execution scheduler contract."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from ..domain import RunId
@dataclass(frozen=True)
class RunExecutionHandle:
run_id: RunId
class RunExecutionScheduler(Protocol):
"""Starts background execution for an accepted run."""
async def start(self, run_id: RunId) -> RunExecutionHandle:
pass
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
]
@@ -0,0 +1,19 @@
"""Run execution supervision contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import CancelAction, RunId
class RunSupervisor(Protocol):
"""Controls lifecycle operations for already scheduled runs."""
async def cancel(self, run_id: RunId, *, action: CancelAction = CancelAction.interrupt) -> bool:
pass
__all__ = [
"RunSupervisor",
]
@@ -0,0 +1,9 @@
"""Repository contracts for the run runtime application layer."""
from .run_event_log import RunEventLog
from .run_repository import RunRepository
__all__ = [
"RunEventLog",
"RunRepository",
]
@@ -0,0 +1,42 @@
"""Durable run event log contract."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from ..domain import RunEvent, RunId, ThreadId
if TYPE_CHECKING:
from ..application.dto import RunMessageView, StoredRunEvent
class RunEventLog(Protocol):
"""Persistence boundary for run messages and execution trace events."""
async def append(self, events: list[RunEvent]) -> list[StoredRunEvent]:
pass
async def list_messages_by_run(
self,
thread_id: ThreadId,
run_id: RunId,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[RunMessageView]:
pass
async def list_events_by_run(
self,
thread_id: ThreadId,
run_id: RunId,
*,
limit: int = 500,
) -> list[StoredRunEvent]:
pass
__all__ = [
"RunEventLog",
]
@@ -0,0 +1,37 @@
"""Run state repository contract."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from ..domain import Run, RunId, ThreadId, UserId
if TYPE_CHECKING:
from ..application.dto import RunSnapshot
class RunRepository(Protocol):
"""Persistence boundary for run state snapshots."""
async def save(self, run: Run) -> None:
pass
async def get(self, run_id: RunId, *, user_id: UserId | None = None) -> Run | None:
pass
async def list_by_thread(
self,
thread_id: ThreadId,
*,
user_id: UserId | None = None,
limit: int = 100,
) -> list[RunSnapshot]:
pass
async def delete(self, run_id: RunId) -> bool:
pass
__all__ = [
"RunRepository",
]
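Because `RunRepository` is a `typing.Protocol`, an implementation only needs matching method signatures; it does not have to inherit from the contract. A minimal in-memory sketch under assumptions: `SimpleRepository` and `InMemoryRunRepository` are hypothetical names, and runs are reduced to plain `dict` snapshots keyed by a string id.

```python
import asyncio
from typing import Optional, Protocol


class SimpleRepository(Protocol):
    """Hypothetical, trimmed contract in the spirit of RunRepository."""

    async def save(self, run_id: str, state: dict) -> None: ...

    async def get(self, run_id: str) -> Optional[dict]: ...

    async def delete(self, run_id: str) -> bool: ...


class InMemoryRunRepository:
    """Satisfies SimpleRepository structurally, without subclassing it."""

    def __init__(self) -> None:
        self._runs: dict[str, dict] = {}

    async def save(self, run_id: str, state: dict) -> None:
        self._runs[run_id] = dict(state)  # store a copy, not the caller's dict

    async def get(self, run_id: str) -> Optional[dict]:
        state = self._runs.get(run_id)
        return dict(state) if state is not None else None

    async def delete(self, run_id: str) -> bool:
        # True only when something was actually removed.
        return self._runs.pop(run_id, None) is not None


async def use_repository(repo: SimpleRepository) -> Optional[dict]:
    await repo.save("r1", {"status": "pending"})
    return await repo.get("r1")
```

A fake of this kind is one way to test application-layer code against the contract without a database.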
@@ -1,21 +1,10 @@
"""Run status and disconnect mode enums."""
"""Compatibility exports for run status and disconnect mode enums."""
from enum import StrEnum
# Existing callers import these enums from ``runs.schemas``. Re-export the
# domain definitions until all imports move to ``runs.domain``.
from .domain import DisconnectMode, RunStatus
class RunStatus(StrEnum):
"""Lifecycle status of a single run."""
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"
__all__ = [
"DisconnectMode",
"RunStatus",
]
@@ -0,0 +1,8 @@
"""Realtime stream contracts for run application use cases."""
from .run_stream_broker import RunStreamBroker, RunStreamEvent
__all__ = [
"RunStreamBroker",
"RunStreamEvent",
]
@@ -0,0 +1,44 @@
"""Realtime run stream broker contract."""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Protocol
from ..domain import RunId
@dataclass(frozen=True)
class RunStreamEvent:
id: str
event: str
data: Any
class RunStreamBroker(Protocol):
"""Realtime publish/subscribe boundary for run streams."""
async def publish(self, run_id: RunId, event: str, data: Any) -> None:
pass
async def publish_terminal(self, run_id: RunId, *, event: str = "end", data: Any = None) -> None:
pass
def subscribe(
self,
run_id: RunId,
*,
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[RunStreamEvent]:
pass
async def cleanup(self, run_id: RunId, *, delay: float = 0) -> None:
pass
__all__ = [
"RunStreamBroker",
"RunStreamEvent",
]
@@ -13,6 +13,7 @@ import stat
import zipfile
from pathlib import Path, PurePosixPath, PureWindowsPath
from deerflow.skills.permissions import make_skill_tree_sandbox_readable
from deerflow.skills.security_scanner import scan_skill_content
logger = logging.getLogger(__name__)
@@ -139,6 +140,7 @@ def _move_staged_skill_into_reserved_target(staging_target: Path, target: Path)
reserved = True
for child in staging_target.iterdir():
shutil.move(str(child), target / child.name)
make_skill_tree_sandbox_readable(target)
installed = True
except FileExistsError as e:
raise SkillAlreadyExistsError(f"Skill '{target.name}' already exists") from e
@@ -0,0 +1,34 @@
"""Filesystem permission helpers for installed skill trees."""
import stat
from pathlib import Path
def make_skill_path_sandbox_readable(path: Path) -> None:
if path.is_symlink():
return
mode = stat.S_IMODE(path.stat().st_mode)
without_sandbox_write = mode & ~(stat.S_IWGRP | stat.S_IWOTH)
if path.is_dir():
path.chmod(without_sandbox_write | 0o555)
elif path.is_file():
path.chmod(without_sandbox_write | 0o444)
def make_skill_tree_sandbox_readable(target: Path) -> None:
make_skill_path_sandbox_readable(target)
for path in target.rglob("*"):
make_skill_path_sandbox_readable(path)
def make_skill_written_path_sandbox_readable(skill_root: Path, target: Path) -> None:
resolved_root = skill_root.resolve()
resolved_target = target.resolve()
resolved_target.relative_to(resolved_root)
make_skill_path_sandbox_readable(resolved_root)
current = resolved_root
for part in resolved_target.parent.relative_to(resolved_root).parts:
current = current / part
make_skill_path_sandbox_readable(current)
make_skill_path_sandbox_readable(resolved_target)
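The mode arithmetic in `make_skill_path_sandbox_readable` can be checked in isolation. The helper below is a hypothetical pure-function restatement (`sandbox_readable_mode` is not part of the PR) that applies the same two steps without touching the filesystem: clear group/other write bits, then grant read to everyone, plus execute (traverse) for directories.

```python
import stat


def sandbox_readable_mode(mode: int, *, is_dir: bool) -> int:
    """Return the permission bits the helper would set for a given mode."""
    # Step 1: drop write access for group and other.
    without_sandbox_write = mode & ~(stat.S_IWGRP | stat.S_IWOTH)
    # Step 2: add read for all; directories also need execute to be traversed.
    return without_sandbox_write | (0o555 if is_dir else 0o444)


# Owner-only file becomes world-readable; owner write is preserved.
assert sandbox_readable_mode(0o600, is_dir=False) == 0o644
# World-writable file loses group/other write.
assert sandbox_readable_mode(0o666, is_dir=False) == 0o644
# Owner-only directory becomes listable and traversable by everyone.
assert sandbox_readable_mode(0o700, is_dir=True) == 0o755
```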
@@ -13,6 +13,7 @@ from datetime import UTC, datetime
from pathlib import Path
from deerflow.config.runtime_paths import resolve_path
from deerflow.skills.permissions import make_skill_written_path_sandbox_readable
from deerflow.skills.storage.skill_storage import SKILL_MD_FILE, SkillStorage
from deerflow.skills.types import SkillCategory
@@ -90,6 +91,7 @@ class LocalSkillStorage(SkillStorage):
tmp_file.write(content)
tmp_path = Path(tmp_file.name)
tmp_path.replace(target)
make_skill_written_path_sandbox_readable(self.get_custom_skill_dir(name), target)
async def ainstall_skill_from_archive(self, archive_path: str | Path) -> dict:
import zipfile
@@ -0,0 +1,62 @@
"""Regression anchor: JsonlRunEventStore async API must not block the loop.
``JsonlRunEventStore`` is the ``run_events.backend == "jsonl"`` implementation.
Its ``async def`` methods perform synchronous filesystem IO (``Path.glob``,
``read_text``, ``open``, ``unlink``) that must be offloaded with
``asyncio.to_thread`` (fixed in #3084). ``put`` runs on every emitted run event,
so any blocking IO here stalls the event loop on the hot path.
#3084 added a mock-based offload assertion in
``tests/test_jsonl_event_store_async_io.py`` that covers ``put`` only. This
anchor complements it by driving the **full** async surface (``put``,
``put_batch``, ``list_messages``, ``list_events``, ``list_messages_by_run``,
``count_messages``, ``delete_by_run``, ``delete_by_thread``) under the strict
Blockbuster runtime gate, so any blocking IO reintroduced on the event loop in
any of these methods (not just removal of a specific ``to_thread`` call)
fails CI.
"""
from __future__ import annotations
from pathlib import Path
import pytest
pytestmark = pytest.mark.asyncio
async def test_jsonl_run_event_store_async_api_does_not_block_event_loop(tmp_path: Path) -> None:
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
store = JsonlRunEventStore(base_dir=str(tmp_path))
# Seed an existing run file so that put()'s seq load (glob + read) and the
# read/delete paths have files to scan. Test-side IO is invisible to the
# gate (this module is not in scanned_modules).
thread_dir = tmp_path / "threads" / "t1" / "runs"
thread_dir.mkdir(parents=True, exist_ok=True)
(thread_dir / "r0.jsonl").write_text('{"seq": 1, "category": "message", "run_id": "r0"}\n', encoding="utf-8")
# writes: put + put_batch
record = await store.put(thread_id="t1", run_id="r1", event_type="message", category="message", content="hi")
assert record["seq"] >= 2
batch = await store.put_batch(
[
{"thread_id": "t1", "run_id": "r2", "event_type": "message", "category": "message", "content": "a"},
{"thread_id": "t1", "run_id": "r2", "event_type": "trace", "category": "trace", "content": "b"},
]
)
assert len(batch) == 2
# reads: list_messages / list_events / list_messages_by_run / count_messages.
# list_events is exercised both without and with the event_types filter so
# the filter branch runs after _read_run_events' filesystem IO.
assert isinstance(await store.list_messages("t1"), list)
assert isinstance(await store.list_events("t1", "r1"), list)
assert isinstance(await store.list_events("t1", "r1", event_types=["message"]), list)
assert isinstance(await store.list_messages_by_run("t1", "r2"), list)
assert await store.count_messages("t1") >= 1
# deletes: delete_by_run (single file) then delete_by_thread (remaining)
assert await store.delete_by_run("t1", "r2") >= 1
assert await store.delete_by_thread("t1") >= 1
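The offloading pattern this anchor guards can be sketched as follows. It is an illustration, not the store's implementation: `_read_events_sync`, `read_events`, and `demo` are hypothetical names, and the only library call assumed is the standard `asyncio.to_thread`.

```python
import asyncio
import json
import tempfile
from pathlib import Path


def _read_events_sync(path: Path) -> list[dict]:
    # Plain blocking file IO; safe to run in a worker thread.
    lines = path.read_text(encoding="utf-8").splitlines()
    return [json.loads(line) for line in lines if line.strip()]


async def read_events(path: Path) -> list[dict]:
    # The coroutine itself never touches the filesystem: the blocking read
    # is handed to a worker thread, so the event loop stays responsive.
    return await asyncio.to_thread(_read_events_sync, path)


def demo() -> list[dict]:
    # Seed a small JSONL file, then read it back through the async API.
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "r0.jsonl"
        path.write_text('{"seq": 1}\n{"seq": 2}\n', encoding="utf-8")
        return asyncio.run(read_events(path))
```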
@@ -0,0 +1,56 @@
"""Regression anchor: UploadsMiddleware must not block the event loop.
``before_agent`` scans the thread uploads directory (``exists`` / ``iterdir`` /
``stat`` plus reading sibling ``.md`` outlines). LangChain wires a sync-only
``before_agent`` as ``RunnableCallable(before_agent, None)``; langgraph's
``ainvoke`` runs it directly on the event loop when ``afunc is None``. So the
filesystem scan must be offloaded (the middleware provides ``abefore_agent``).
This anchor drives the real ``create_agent`` graph via ``ainvoke`` under the
strict Blockbuster gate. If the scan regresses back onto the event loop,
Blockbuster raises ``BlockingError`` and this test fails.
The graph/middleware construction is offloaded with ``asyncio.to_thread`` only
because ``Paths.__init__`` resolves paths synchronously; the surface under test
(``before_agent``'s directory scan) is exercised on the event loop, not
bypassed.
"""
from __future__ import annotations
import asyncio
from pathlib import Path
import pytest
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
pytestmark = pytest.mark.asyncio
class _FakeModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel with a no-op ``bind_tools`` for create_agent."""
def bind_tools(self, tools, **kwargs): # type: ignore[override]
return self
async def test_before_agent_uploads_scan_does_not_block_event_loop(tmp_path: Path) -> None:
from langchain.agents import create_agent
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
from deerflow.runtime.user_context import get_effective_user_id
mw = await asyncio.to_thread(UploadsMiddleware, str(tmp_path))
uploads_dir = await asyncio.to_thread(mw._paths.sandbox_uploads_dir, "t1", user_id=get_effective_user_id())
uploads_dir.mkdir(parents=True, exist_ok=True) # test-side seeding (not in scanned_modules)
(uploads_dir / "existing.txt").write_text("hello", encoding="utf-8")
agent = await asyncio.to_thread(lambda: create_agent(model=_FakeModel(responses=[AIMessage(content="ok")]), tools=[], middleware=[mw]))
result = await agent.ainvoke(
{"messages": [HumanMessage(content="hi")]},
{"configurable": {"thread_id": "t1"}},
)
assert result["messages"]
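The commit message describes `abefore_agent` as offloading the scan via `run_in_executor` while copying the current context so the `user_id` contextvar stays visible. A minimal sketch of that pattern, assuming hypothetical names (`user_id_var`, `scan_uploads_sync`, `scan_uploads`); the middleware's actual code is not shown in this diff.

```python
import asyncio
import contextvars
from typing import Optional

# Hypothetical stand-in for the contextvar read by get_effective_user_id().
user_id_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "user_id", default=None
)


def scan_uploads_sync() -> Optional[str]:
    # Stand-in for the blocking directory scan; it only reads the contextvar
    # to show that the value survives the hop to the worker thread.
    return user_id_var.get()


async def scan_uploads() -> Optional[str]:
    loop = asyncio.get_running_loop()
    # run_in_executor does not propagate contextvars on its own, so the
    # current context is copied and the function is run inside that copy.
    ctx = contextvars.copy_context()
    return await loop.run_in_executor(None, ctx.run, scan_uploads_sync)


async def main() -> Optional[str]:
    user_id_var.set("u1")
    return await scan_uploads()
```

Without the `copy_context()` step, the worker thread would see the contextvar's default rather than the value set in the calling task.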
@@ -1,210 +0,0 @@
"""Tests for AioSandboxProvider auto-restart of crashed containers."""
import importlib
import threading
from unittest.mock import MagicMock, patch
def _import_provider():
return importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
def _make_provider(*, auto_restart=True, alive=True):
"""Build a minimal AioSandboxProvider with a mock backend.
Args:
auto_restart: Value for the auto_restart config key.
alive: Whether the mock backend reports containers as alive.
"""
mod = _import_provider()
with patch.object(mod.AioSandboxProvider, "_start_idle_checker"):
provider = mod.AioSandboxProvider.__new__(mod.AioSandboxProvider)
provider._config = {"auto_restart": auto_restart}
provider._lock = threading.Lock()
provider._sandboxes = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
provider._warm_pool = {}
provider._shutdown_called = False
provider._idle_checker_stop = threading.Event()
backend = MagicMock()
backend.is_alive.return_value = alive
provider._backend = backend
return provider, backend
def _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1"):
"""Insert a sandbox into the provider's caches as if it were acquired."""
sandbox = MagicMock()
info = MagicMock()
provider._sandboxes[sandbox_id] = sandbox
provider._sandbox_infos[sandbox_id] = info
provider._last_activity[sandbox_id] = 0.0
if thread_id:
provider._thread_sandboxes[thread_id] = sandbox_id
return sandbox, info
# ── get() returns sandbox when container is alive ──────────────────────────
def test_get_returns_sandbox_when_container_alive():
"""When auto_restart is on and the container is alive, get() returns the sandbox."""
provider, backend = _make_provider(auto_restart=True, alive=True)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_called_once()
def test_get_returns_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() skips the health check entirely."""
provider, backend = _make_provider(auto_restart=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
backend.is_alive.assert_not_called()
# ── get() evicts dead sandbox when auto_restart is on ──────────────────────
def test_get_evicts_dead_sandbox_when_auto_restart_enabled():
"""When the container is dead and auto_restart is on, get() returns None and cleans caches."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_, info = _seed_sandbox(provider, sandbox_id="dead-beef", thread_id="thread-1")
result = provider.get("dead-beef")
assert result is None
assert "dead-beef" not in provider._sandboxes
assert "dead-beef" not in provider._sandbox_infos
assert "dead-beef" not in provider._last_activity
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once_with(info)
def test_get_returns_dead_sandbox_when_auto_restart_disabled():
"""When auto_restart is off, get() returns the cached sandbox even if the container is dead."""
provider, backend = _make_provider(auto_restart=False, alive=False)
sandbox, _ = _seed_sandbox(provider)
result = provider.get("dead-beef")
assert result is sandbox
# Caches are untouched
assert "dead-beef" in provider._sandboxes
def test_get_eviction_cleans_multiple_thread_mappings():
"""A sandbox mapped to multiple thread IDs has all mappings cleaned on eviction."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="t-a")
# Manually add a second thread mapping to the same sandbox
provider._thread_sandboxes["t-b"] = "sid-1"
result = provider.get("sid-1")
assert result is None
assert "t-a" not in provider._thread_sandboxes
assert "t-b" not in provider._thread_sandboxes
# ── get() does not check health for unknown sandbox IDs ────────────────────
def test_get_returns_none_for_unknown_id():
"""If the sandbox_id is not in cache, get() returns None without checking health."""
provider, backend = _make_provider(auto_restart=True, alive=True)
result = provider.get("nonexistent")
assert result is None
backend.is_alive.assert_not_called()
# ── get() handles missing sandbox_info gracefully ──────────────────────────
def test_get_handles_missing_info_gracefully():
"""If sandbox is cached but info is missing, get() skips the health check."""
provider, backend = _make_provider(auto_restart=True, alive=False)
sandbox = MagicMock()
provider._sandboxes["sid-x"] = sandbox
provider._sandbox_infos.pop("sid-x", None) # Ensure no info
provider._last_activity["sid-x"] = 0.0
result = provider.get("sid-x")
# No info → cannot call is_alive → sandbox returned as-is
assert result is sandbox
backend.is_alive.assert_not_called()
def test_get_liveness_check_runs_outside_provider_lock():
"""get() should not hold the provider lock while checking backend liveness."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-locked", thread_id="thread-1")
def _assert_lock_not_held(_):
assert not provider._lock.locked()
return False
backend.is_alive.side_effect = _assert_lock_not_held
assert provider.get("sid-locked") is None
def test_get_still_evicts_when_backend_destroy_fails():
"""Cleanup errors should not keep stale sandbox state in memory."""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-fail", thread_id="thread-1")
backend.destroy.side_effect = RuntimeError("boom")
assert provider.get("sid-fail") is None
assert "sid-fail" not in provider._sandboxes
assert "sid-fail" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
backend.destroy.assert_called_once()
# ── Integration: eviction clears caches for recreation ─────────────────────
def test_eviction_clears_all_caches_for_recreation():
"""After eviction, all caches are clean so _acquire_internal can recreate.
This verifies the preconditions for transparent restart: when get() evicts
a dead sandbox, the next _acquire_internal call will find no cached entry,
no warm-pool entry, and fall through to _create_sandbox.
"""
provider, backend = _make_provider(auto_restart=True, alive=False)
_seed_sandbox(provider, sandbox_id="sid-1", thread_id="thread-1")
# Before eviction: caches populated
assert "sid-1" in provider._sandboxes
assert "sid-1" in provider._sandbox_infos
assert "thread-1" in provider._thread_sandboxes
# get() detects the dead container and evicts
assert provider.get("sid-1") is None
# After eviction: all caches clean
assert "sid-1" not in provider._sandboxes
assert "sid-1" not in provider._sandbox_infos
assert "thread-1" not in provider._thread_sandboxes
assert "sid-1" not in provider._warm_pool
# _acquire_internal for the same thread would find nothing cached,
# generate the deterministic ID, fail discovery (the container is
# gone), and fall through to _create_sandbox for a fresh start.
@@ -0,0 +1,166 @@
"""Tests for shared assistant payload replay helpers."""
from __future__ import annotations
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.models.assistant_payload_replay import (
restore_additional_kwargs_field,
restore_assistant_payloads,
restore_reasoning_content,
)
def _restore_reasoning(payload_msg: dict, orig_msg: AIMessage) -> None:
restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
def test_restore_additional_kwargs_field_copies_present_values_only():
payload_message = {"role": "assistant"}
orig_message = AIMessage(
content="answer",
additional_kwargs={
"reasoning_content": "",
"ignored_none": None,
},
)
restore_additional_kwargs_field(payload_message, orig_message, "reasoning_content")
restore_additional_kwargs_field(payload_message, orig_message, "ignored_none")
restore_additional_kwargs_field(payload_message, orig_message, "missing")
assert payload_message == {"role": "assistant", "reasoning_content": ""}
def test_restore_reasoning_content_copies_reasoning_content():
payload_message = {"role": "assistant"}
orig_message = AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"})
restore_reasoning_content(payload_message, orig_message)
assert payload_message["reasoning_content"] == "thought"
def test_restore_assistant_payloads_matches_by_position_when_lengths_match():
original_messages = [
HumanMessage(content="question"),
AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"}),
]
payload_messages = [
{"role": "user", "content": "question"},
{"role": "assistant", "content": "answer"},
]
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
assert payload_messages[1]["reasoning_content"] == "thought"
def test_restore_assistant_payloads_fallback_matches_unique_content_signature():
original_messages = [
AIMessage(content="first", additional_kwargs={"reasoning_content": "first-thought"}),
AIMessage(content="second", additional_kwargs={"reasoning_content": "second-thought"}),
]
payload_messages = [{"role": "assistant", "content": "second"}]
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
assert payload_messages[0]["reasoning_content"] == "second-thought"
def test_restore_assistant_payloads_fallback_matches_unique_tool_call_signature():
original_messages = [
AIMessage(
content="",
additional_kwargs={"reasoning_content": "first-thought"},
tool_calls=[{"id": "call_first", "name": "tool", "args": {}}],
),
AIMessage(
content="",
additional_kwargs={"reasoning_content": "second-thought"},
tool_calls=[{"id": "call_second", "name": "tool", "args": {}}],
),
]
payload_messages = [
{
"role": "assistant",
"content": "",
"tool_calls": [{"id": "call_second", "type": "function", "function": {"name": "tool", "arguments": "{}"}}],
}
]
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
assert payload_messages[0]["reasoning_content"] == "second-thought"
def test_restore_assistant_payloads_fallback_matches_structured_content_signature():
original_messages = [
AIMessage(
content=[{"type": "text", "text": "first"}],
additional_kwargs={"reasoning_content": "first-thought"},
),
AIMessage(
content=[{"type": "text", "text": "second"}],
additional_kwargs={"reasoning_content": "second-thought"},
),
]
payload_messages = [{"role": "assistant", "content": [{"text": "second", "type": "text"}]}]
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
assert payload_messages[0]["reasoning_content"] == "second-thought"
def test_restore_assistant_payloads_fallback_uses_order_when_signature_is_ambiguous():
original_messages = [
AIMessage(content="", additional_kwargs={"reasoning_content": "first-thought"}),
AIMessage(content="", additional_kwargs={"reasoning_content": "second-thought"}),
]
payload_messages = [{"role": "assistant", "content": ""}]
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
assert payload_messages[0]["reasoning_content"] == "first-thought"
def test_restore_assistant_payloads_fallback_uses_next_unused_when_ordinal_taken():
# Serialization dropped a leading empty assistant message, so payload ordinals
# no longer line up with the original AIMessage indices. The first payload
# uniquely matches a non-ordinal index by signature, which leaves the later
# ambiguous payload's exact ordinal index already used. It must still fall
# back to the remaining unused AIMessage (scanning forward from the ordinal)
# instead of silently dropping the field.
original_messages = [
AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-thought"}),
AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
AIMessage(content="", additional_kwargs={"reasoning_content": "trailing-thought"}),
]
payload_messages = [
{"role": "assistant", "content": "unique"},
{"role": "assistant", "content": ""},
]
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
assert payload_messages[0]["reasoning_content"] == "unique-thought"
# Forward scan from the taken ordinal picks the trailing message, not the
# dropped leading one (which a naive min-unused scan would wrongly select).
assert payload_messages[1]["reasoning_content"] == "trailing-thought"
def test_restore_assistant_payloads_does_not_wrap_to_earlier_unused_message():
original_messages = [
HumanMessage(content="leading user"),
AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-leading-thought"}),
AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
]
payload_messages = [
{"role": "assistant", "content": "unique"},
{"role": "assistant", "content": ""},
]
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
assert payload_messages[0]["reasoning_content"] == "unique-thought"
assert "reasoning_content" not in payload_messages[1]
@@ -372,6 +372,25 @@ class TestExtractResponseText:
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
assert _extract_response_text(result) == ""
def test_ignores_hidden_human_control_messages(self):
"""Hidden control messages should not terminate current-turn response extraction."""
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "plan this"},
{"type": "ai", "content": "Here is the plan."},
{
"type": "human",
"name": "todo_reminder",
"content": "keep todos updated",
"additional_kwargs": {"hide_from_ui": True},
},
]
}
assert _extract_response_text(result) == "Here is the plan."
# ---------------------------------------------------------------------------
# ChannelManager tests
@@ -1678,6 +1697,31 @@ class TestExtractArtifacts:
}
assert _extract_artifacts(result) == ["/mnt/user-data/outputs/a.txt", "/mnt/user-data/outputs/b.csv"]
def test_ignores_hidden_human_control_messages(self):
"""Hidden control messages should not hide current-turn present_files artifacts."""
from app.channels.manager import _extract_artifacts
result = {
"messages": [
{"type": "human", "content": "export"},
{
"type": "ai",
"content": "Done.",
"tool_calls": [
{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/plan.md"]}},
],
},
{
"type": "human",
"name": "todo_completion_reminder",
"content": "mark tasks complete",
"additional_kwargs": {"hide_from_ui": True},
},
]
}
assert _extract_artifacts(result) == ["/mnt/user-data/outputs/plan.md"]
class TestFormatArtifactText:
def test_single_artifact(self):
@@ -1790,6 +1834,50 @@ class TestHandleChatWithArtifacts:
_run(go())
def test_hidden_human_control_message_does_not_trigger_no_response_fallback(self):
"""Plan-mode hidden control messages should not mask the final AI response."""
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
manager = ChannelManager(bus=bus, store=store)
run_result = {
"messages": [
{"type": "human", "content": "make a plan"},
{"type": "ai", "content": "Here is a concrete plan."},
{
"type": "human",
"name": "todo_reminder",
"content": "sync todos",
"additional_kwargs": {"hide_from_ui": True},
},
]
}
mock_client = _make_mock_langgraph_client(run_result=run_result)
manager._client = mock_client
outbound_received = []
bus.subscribe_outbound(lambda msg: outbound_received.append(msg))
await manager.start()
await bus.publish_inbound(
InboundMessage(
channel_name="test",
chat_id="c1",
user_id="u1",
text="make a plan",
)
)
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
assert len(outbound_received) == 1
assert outbound_received[0].text == "Here is a concrete plan."
_run(go())
def test_only_last_turn_artifacts_returned(self):
"""Only artifacts from the current turn's present_files calls should be included."""
from app.channels.manager import ChannelManager
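The three hidden-control-message tests above pin one rule: a human message flagged with `additional_kwargs["hide_from_ui"]` must not count as the start of a new turn. A minimal sketch of that rule follows, assuming the simplified dict message shape used in the tests. `is_hidden_control` and `extract_response_text` are hypothetical stand-ins for illustration, not the actual `_extract_response_text` in `app.channels.manager`.

```python
from typing import Any


def is_hidden_control(msg: dict[str, Any]) -> bool:
    """True for human messages flagged hide_from_ui (e.g. todo reminders)."""
    if msg.get("type") != "human":
        return False
    return bool((msg.get("additional_kwargs") or {}).get("hide_from_ui"))


def extract_response_text(result: dict[str, Any]) -> str:
    """Return the last AI text of the current turn, or "" if there is none.

    Hypothetical stand-in: walks the messages backwards and stops at the
    first *visible* human message, which marks the start of the current turn.
    """
    for msg in reversed(result.get("messages", [])):
        if is_hidden_control(msg):
            continue  # control message: not a turn boundary
        if msg.get("type") == "human":
            break  # visible user message: the current turn starts here
        content = msg.get("content")
        if msg.get("type") == "ai" and isinstance(content, str) and content:
            return content
    return ""
```

Skipping the hidden message instead of breaking on it is the whole fix in this sketch: without the `continue`, the trailing reminder would end the scan before the AI reply is reached.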
@@ -333,8 +333,27 @@ class TestBuildPatchedMessagesPatching:
assert patched[1].tool_call_id == "write_file:36"
assert patched[1].name == "write_file"
assert patched[1].status == "error"
assert "write_file failed before execution" in patched[1].content
assert "no file was written" in patched[1].content
assert "very large Markdown file in a single tool call" in patched[1].content
assert "Do not retry the same large `write_file` payload" in patched[1].content
assert "split the file into smaller sections" in patched[1].content
assert "normal assistant text" in patched[1].content
assert "Failed to parse tool arguments" in patched[1].content
assert 'bad {"json"}' not in patched[1].content
def test_non_write_file_invalid_tool_call_uses_generic_recovery_message(self):
mw = DanglingToolCallMiddleware()
msgs = [_ai_with_invalid_tool_calls([_invalid_tc(name="search", tc_id="search:1")])]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1].tool_call_id == "search:1"
assert patched[1].name == "search"
assert "arguments were invalid" in patched[1].content
assert "Failed to parse tool arguments" in patched[1].content
assert "write_file failed before execution" not in patched[1].content
def test_valid_and_invalid_tool_calls_are_both_patched(self):
mw = DanglingToolCallMiddleware()
@@ -83,3 +83,24 @@ def test_frontend_rewrites_langgraph_prefix_to_gateway():
assert "DEER_FLOW_INTERNAL_LANGGRAPH_BASE_URL" not in next_config
assert "http://127.0.0.1:2024" not in next_config
assert "langgraph-compat" not in api_client
def test_smoke_test_docs_do_not_expect_standalone_langgraph_server():
smoke_files = {
".agent/skills/smoke-test/SKILL.md": _read(".agent/skills/smoke-test/SKILL.md"),
".agent/skills/smoke-test/references/SOP.md": _read(".agent/skills/smoke-test/references/SOP.md"),
".agent/skills/smoke-test/references/troubleshooting.md": _read(".agent/skills/smoke-test/references/troubleshooting.md"),
".agent/skills/smoke-test/scripts/check_local_env.sh": _read(".agent/skills/smoke-test/scripts/check_local_env.sh"),
".agent/skills/smoke-test/scripts/deploy_local.sh": _read(".agent/skills/smoke-test/scripts/deploy_local.sh"),
".agent/skills/smoke-test/scripts/health_check.sh": _read(".agent/skills/smoke-test/scripts/health_check.sh"),
".agent/skills/smoke-test/templates/report.local.template.md": _read(".agent/skills/smoke-test/templates/report.local.template.md"),
".agent/skills/smoke-test/templates/report.docker.template.md": _read(".agent/skills/smoke-test/templates/report.docker.template.md"),
}
for path, content in smoke_files.items():
assert "localhost:2024" not in content, path
assert "127.0.0.1:2024" not in content, path
assert "deer-flow-langgraph" not in content, path
assert "langgraph.log" not in content, path
assert "LangGraph service" not in content, path
assert "langgraph dev" not in content, path
@@ -0,0 +1,223 @@
"""Concurrency-safety tests for JsonlRunEventStore async I/O hardening (#2816).
Verifies:
- write-lock serialises concurrent puts within the same thread_id
- put_batch keeps monotonic seq even under concurrent callers
- seq recovery from disk on fresh store init
- DB put_batch rejects mixed-thread batches
"""
from __future__ import annotations
import asyncio
import tempfile
from pathlib import Path
import pytest
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_store(base_dir: Path) -> JsonlRunEventStore:
return JsonlRunEventStore(base_dir=base_dir)
# ---------------------------------------------------------------------------
# Write-lock: per-thread lock exists and is reused
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_get_write_lock_returns_asyncio_lock():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
lock = store._get_write_lock("t1")
assert isinstance(lock, asyncio.Lock)
@pytest.mark.anyio
async def test_get_write_lock_same_thread_reuses_lock():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
lock_a = store._get_write_lock("t1")
lock_b = store._get_write_lock("t1")
assert lock_a is lock_b
@pytest.mark.anyio
async def test_get_write_lock_different_threads_get_different_locks():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
lock_a = store._get_write_lock("t1")
lock_b = store._get_write_lock("t2")
assert lock_a is not lock_b
# ---------------------------------------------------------------------------
# Seq monotonicity under concurrent puts
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_concurrent_puts_produce_unique_monotonic_seqs():
"""10 concurrent puts on the same thread must yield distinct, monotonic seq values."""
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
results = await asyncio.gather(*[store.put(thread_id="t1", run_id=f"r{i}", event_type="trace", category="trace", content=f"msg{i}") for i in range(10)])
seqs = sorted(r["seq"] for r in results)
assert seqs == list(range(1, 11)), f"Expected 1-10, got {seqs}"
@pytest.mark.anyio
async def test_concurrent_puts_different_threads_independent_seqs():
"""Concurrent puts on different threads keep independent seq counters."""
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
t1_results, t2_results = await asyncio.gather(
asyncio.gather(*[store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace") for _ in range(5)]),
asyncio.gather(*[store.put(thread_id="t2", run_id="r2", event_type="trace", category="trace") for _ in range(5)]),
)
t1_seqs = sorted(r["seq"] for r in t1_results)
t2_seqs = sorted(r["seq"] for r in t2_results)
assert t1_seqs == [1, 2, 3, 4, 5]
assert t2_seqs == [1, 2, 3, 4, 5]
# ---------------------------------------------------------------------------
# put_batch: delegates to put() and preserves order
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_put_batch_seqs_are_monotonic():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace", "content": str(i)} for i in range(5)]
results = await store.put_batch(events)
seqs = [r["seq"] for r in results]
assert seqs == sorted(seqs)
assert len(set(seqs)) == 5
# ---------------------------------------------------------------------------
# _ensure_seq_loaded: recovers max_seq from disk after fresh store init
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_ensure_seq_loaded_recovers_from_disk():
"""A fresh JsonlRunEventStore should pick up the max seq written by a previous instance."""
with tempfile.TemporaryDirectory() as tmp:
base = Path(tmp)
store1 = _make_store(base)
for i in range(3):
await store1.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content=str(i))
store2 = _make_store(base)
record = await store2.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content="new")
assert record["seq"] == 4, f"Expected seq=4 after recovery, got {record['seq']}"
# ---------------------------------------------------------------------------
# asyncio.to_thread regression guard
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_put_offloads_write_via_to_thread():
"""Regression guard: put() must call asyncio.to_thread for _write_record."""
original = asyncio.to_thread
calls: list[str] = []
async def spy(*args, **kwargs):
calls.append(args[0].__name__ if callable(args[0]) else repr(args[0]))
return await original(*args, **kwargs)
from unittest.mock import patch
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
with patch("asyncio.to_thread", new=spy):
await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace", content="x")
assert "_write_record" in calls, f"Expected asyncio.to_thread(_write_record, ...) — got: {calls}"
# ---------------------------------------------------------------------------
# Read methods are non-blocking (asyncio.to_thread path exercised)
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_list_messages_reads_written_records():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hello")
await store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="world")
messages = await store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["content"] == "hello"
assert messages[1]["content"] == "world"
@pytest.mark.anyio
async def test_count_messages_accurate_after_concurrent_writes():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
await asyncio.gather(*[store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") for _ in range(7)])
count = await store.count_messages("t1")
assert count == 7
# ---------------------------------------------------------------------------
# delete_by_thread and delete_by_run use the write lock
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_delete_by_thread_clears_seq_counter_and_lock():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace")
await store.delete_by_thread("t1")
assert "t1" not in store._seq_counters
assert "t1" not in store._write_locks
@pytest.mark.anyio
async def test_delete_by_run_removes_run_events():
with tempfile.TemporaryDirectory() as tmp:
store = _make_store(Path(tmp))
await store.put(thread_id="t1", run_id="r1", event_type="trace", category="trace")
await store.put(thread_id="t1", run_id="r2", event_type="trace", category="trace")
await store.delete_by_run("t1", "r1")
events = await store.list_events("t1", "r1")
assert events == []
# ---------------------------------------------------------------------------
# DB put_batch: rejects mixed-thread batches
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_db_put_batch_rejects_mixed_thread_ids():
"""DbRunEventStore.put_batch must raise ValueError for cross-thread batches."""
from unittest.mock import MagicMock
from deerflow.runtime.events.store.db import DbRunEventStore
mock_sf = MagicMock()
store = DbRunEventStore(session_factory=mock_sf)
events = [
{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"},
{"thread_id": "t2", "run_id": "r2", "event_type": "trace", "category": "trace"},
]
with pytest.raises(ValueError, match="same thread"):
await store.put_batch(events)
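The concurrency tests above rely on two properties: one `asyncio.Lock` per `thread_id`, and a sequence counter that is only read and advanced while that lock is held. The following is a minimal in-memory sketch of that pattern. `TinySeqStore` is a hypothetical stand-in and does not reproduce the real `JsonlRunEventStore` (no files, no recovery from disk).

```python
import asyncio
from collections import defaultdict


class TinySeqStore:
    """Hypothetical in-memory stand-in, not the real JsonlRunEventStore."""

    def __init__(self) -> None:
        self._write_locks: dict[str, asyncio.Lock] = {}
        self._seq_counters: dict[str, int] = defaultdict(int)
        self.records: list[dict] = []

    def _get_write_lock(self, thread_id: str) -> asyncio.Lock:
        # Create the lock on first use so each thread_id maps to exactly one lock.
        if thread_id not in self._write_locks:
            self._write_locks[thread_id] = asyncio.Lock()
        return self._write_locks[thread_id]

    async def put(self, *, thread_id: str, content: str = "") -> dict:
        async with self._get_write_lock(thread_id):
            seq = self._seq_counters[thread_id] + 1
            record = {"thread_id": thread_id, "seq": seq, "content": content}
            # Offload the (here trivial) write, as a file-backed store would.
            await asyncio.to_thread(self.records.append, record)
            self._seq_counters[thread_id] = seq
            return record


async def demo() -> list[int]:
    store = TinySeqStore()
    results = await asyncio.gather(
        *[store.put(thread_id="t1", content=f"msg{i}") for i in range(10)]
    )
    return sorted(r["seq"] for r in results)
```

Because the counter is read before the awaited write and stored after it, removing the lock would let several coroutines observe the same counter value and emit duplicate `seq` numbers.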
@@ -476,6 +476,24 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
def test_create_summarization_middleware_uses_frontend_supported_update_key(monkeypatch):
"""LangGraph update keys use the middleware class name plus hook name."""
app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True)
app_config.memory = MemoryConfig(enabled=False)
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: fake_model)
middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)
assert middleware is not None
update_key = f"{type(middleware).__name__}.before_model"
assert update_key == "DeerFlowSummarizationMiddleware.before_model"
def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
@@ -3,6 +3,7 @@
from __future__ import annotations
import os
import stat
import pytest
@@ -43,6 +44,20 @@ def test_write_is_atomic_overwrite(tmp_path, storage):
assert (tmp_path / "custom" / "demo-skill" / "SKILL.md").read_text() == "second"
def test_write_makes_written_path_sandbox_readable(tmp_path, storage):
skill_dir = tmp_path / "custom" / "demo-skill"
skill_dir.mkdir(parents=True)
skill_dir.chmod(0o700)
storage.write_custom_skill("demo-skill", "references/ref.md", "# ref")
ref_dir = skill_dir / "references"
ref_file = ref_dir / "ref.md"
assert stat.S_IMODE(skill_dir.stat().st_mode) & 0o055 == 0o055
assert stat.S_IMODE(ref_dir.stat().st_mode) & 0o055 == 0o055
assert stat.S_IMODE(ref_file.stat().st_mode) & 0o044 == 0o044
# ---------------------------------------------------------------------------
# Empty / blank path
# ---------------------------------------------------------------------------
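The masks in `test_write_makes_written_path_sandbox_readable` read as follows: `0o055` is group/other read plus execute (needed to traverse a directory), and `0o044` is group/other read (enough for a file). A small sketch of adding those bits without clearing existing ones is shown below. `make_sandbox_readable` is a hypothetical helper for illustration, the storage layer under test may do this differently, and the semantics assume a POSIX filesystem.

```python
import os
import stat
import tempfile
from pathlib import Path


def make_sandbox_readable(path: Path) -> None:
    """Add group/other read bits (plus execute for directories).

    Hypothetical helper; existing permission bits are preserved.
    """
    mode = stat.S_IMODE(path.stat().st_mode)
    extra = 0o055 if path.is_dir() else 0o044
    os.chmod(path, mode | extra)


def demo() -> tuple[int, int]:
    with tempfile.TemporaryDirectory() as tmp:
        skill_dir = Path(tmp) / "demo-skill"
        skill_dir.mkdir()
        skill_dir.chmod(0o700)  # owner-only, as in the test setup
        ref_file = skill_dir / "ref.md"
        ref_file.write_text("# ref")
        ref_file.chmod(0o600)

        make_sandbox_readable(skill_dir)
        make_sandbox_readable(ref_file)
        return (
            stat.S_IMODE(skill_dir.stat().st_mode),
            stat.S_IMODE(ref_file.stat().st_mode),
        )
```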
@@ -407,3 +407,80 @@ def test_session_pool_tool_sync_wrapper_path_is_safe():
wrapped.func(url="https://example.com")
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
# ---------------------------------------------------------------------------
# get_mcp_tools: HTTP transport should NOT be pooled
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_http_transport_tools_not_pooled():
"""HTTP/SSE transport tools should NOT be wrapped with the session pool."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import get_mcp_tools
class Args(BaseModel):
query: str = Field(..., description="query")
http_tool = StructuredTool(
name="myserver_search",
description="Search tool",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
stdio_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
extensions_config = MagicMock()
extensions_config.get_enabled_mcp_servers.return_value = {
"myserver": MagicMock(type="http", url="http://localhost:8000/mcp", headers=None, command=None, args=[], env=None),
"playwright": MagicMock(type="stdio", command="npx", args=["-y", "@anthropic/mcp-server-playwright"], env=None, url=None, headers=None),
}
extensions_config.model_extra = {}
servers_config = {
"myserver": {"transport": "http", "url": "http://localhost:8000/mcp"},
"playwright": {"transport": "stdio", "command": "npx", "args": ["-y", "@anthropic/mcp-server-playwright"]},
}
with (
patch("deerflow.mcp.tools.ExtensionsConfig.from_file", return_value=extensions_config),
patch("deerflow.mcp.tools.build_servers_config", return_value=servers_config),
patch("deerflow.mcp.tools.get_initial_oauth_headers", return_value={}),
patch("deerflow.mcp.tools.build_oauth_tool_interceptor", return_value=None),
patch("langchain_mcp_adapters.client.MultiServerMCPClient") as MockClient,
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
):
mock_client_instance = MockClient.return_value
mock_client_instance.get_tools = AsyncMock(return_value=[http_tool, stdio_tool])
tools = await get_mcp_tools()
pool = get_session_pool()
# Tool discovery is lazy: no pooled sessions are created until a wrapped tool is invoked.
assert list(pool._entries.keys()) == []
# Verify the HTTP tool was NOT wrapped with the pool (it's the original tool).
http_tools = [t for t in tools if t.name == "myserver_search"]
assert len(http_tools) == 1
assert http_tools[0].coroutine is http_tool.coroutine
# Verify the stdio tool WAS wrapped with the pool.
stdio_tools = [t for t in tools if t.name == "playwright_navigate"]
assert len(stdio_tools) == 1
assert stdio_tools[0].coroutine is not stdio_tool.coroutine
@@ -563,6 +563,28 @@ class TestUpdateMemoryStructuredResponse:
model.invoke = MagicMock(return_value=response)
return model
def _run_update_with_response(self, content):
updater = MemoryUpdater()
mock_storage = MagicMock()
mock_storage.save = MagicMock(return_value=True)
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(content)),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7, max_facts=100)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Remember that I prefer concise updates."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Got it."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], thread_id="thread-memory")
return result, mock_storage
def test_string_response_parses(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
@@ -609,6 +631,82 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
def test_wrapped_json_responses_parse(self):
"""Memory update should tolerate provider wrappers around valid JSON."""
valid_json = '{"user": {}, "history": {}, "newFacts": [{"content": "User prefers concise updates", "category": "preference", "confidence": 0.9}], "factsToRemove": []}'
response_variants = [
f"<think>Analyze the conversation first.</think>\n{valid_json}",
f"<think>Analyze the conversation first.\n{valid_json}",
f"Here is the memory update:\n{valid_json}",
f"{valid_json}\nDone.",
f"```json\n{valid_json}\n```",
]
for content in response_variants:
result, mock_storage = self._run_update_with_response(content)
assert result is True
saved_memory = mock_storage.save.call_args.args[0]
assert saved_memory["facts"][0]["content"] == "User prefers concise updates"
def test_ignores_unrelated_json_before_memory_update(self):
"""Parser should not select unrelated JSON objects before the memory update."""
valid_json = '{"user": {}, "history": {}, "newFacts": [{"content": "Remember the actual update", "category": "context", "confidence": 0.9}], "factsToRemove": []}'
response = f'Example object: {{"user": "alice"}}\nActual memory update:\n{valid_json}'
result, mock_storage = self._run_update_with_response(response)
assert result is True
saved_memory = mock_storage.save.call_args.args[0]
assert saved_memory["facts"][0]["content"] == "Remember the actual update"
def test_invalid_json_response_is_skipped_without_saving(self):
"""Truncated JSON should remain a safe skipped update, not guessed repair."""
result, mock_storage = self._run_update_with_response('{"user": {}, "history": {}, "newFacts": [')
assert result is False
mock_storage.save.assert_not_called()
def test_schema_guard_ignores_invalid_update_fields(self):
"""Parsed JSON with bad field types should not break the memory update."""
response = '{"user": "bad", "history": [], "newFacts": ["bad", {"content": "User works on DeerFlow", "category": "context", "confidence": 0.91}], "factsToRemove": "bad"}'
result, mock_storage = self._run_update_with_response(response)
assert result is True
saved_memory = mock_storage.save.call_args.args[0]
assert [fact["content"] for fact in saved_memory["facts"]] == ["User works on DeerFlow"]
def test_fact_schema_guard_coerces_and_filters_nested_fields(self):
"""Malformed fact entries should be normalized per fact, not fail the whole update."""
response = (
'{"user": {}, "history": {}, "newFacts": ['
'{"content": " User likes async updates ", "category": 9, "confidence": "0.91", "sourceError": " parse issue "}, '
'{"content": "skip invalid confidence", "category": "context", "confidence": "high"}, '
'{"content": 12, "category": "context", "confidence": 0.9}, '
'{"content": " ", "category": "context", "confidence": 0.9}'
'], "factsToRemove": []}'
)
result, mock_storage = self._run_update_with_response(response)
assert result is True
saved_memory = mock_storage.save.call_args.args[0]
assert len(saved_memory["facts"]) == 1
assert saved_memory["facts"][0]["content"] == "User likes async updates"
assert saved_memory["facts"][0]["category"] == "context"
assert saved_memory["facts"][0]["confidence"] == 0.91
assert saved_memory["facts"][0]["sourceError"] == "parse issue"
def test_malformed_replacement_update_fails_closed(self):
"""Malformed replacement facts should not turn remove+add into delete-only."""
response = '{"user": {}, "history": {}, "newFacts": [{"content": "replacement fact", "category": "context", "confidence": "bad"}], "factsToRemove": ["fact_old"]}'
result, mock_storage = self._run_update_with_response(response)
assert result is False
mock_storage.save.assert_not_called()
def test_async_update_memory_delegates_to_sync(self):
"""aupdate_memory should delegate to sync _do_update_memory_sync via to_thread."""
updater = MemoryUpdater()
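The wrapped-response tests above expect the parser to find the memory update inside `<think>` blocks, prose, or code fences, to skip unrelated JSON objects, and to give up on truncated output instead of guessing. One way to get that behaviour is to try `json.JSONDecoder.raw_decode` at every `{` and accept the first object with the expected keys. The sketch below assumes the four top-level keys seen in the test payloads are required. `extract_memory_update` and `REQUIRED_KEYS` are hypothetical names, not the `MemoryUpdater` implementation.

```python
import json
from typing import Any, Optional

# Assumption for this sketch: a memory update is a JSON object with these keys.
REQUIRED_KEYS = {"user", "history", "newFacts", "factsToRemove"}


def extract_memory_update(text: str) -> Optional[dict[str, Any]]:
    """Return the first embedded JSON object that looks like a memory update.

    Hypothetical helper. Returns None when no complete candidate exists, so
    truncated output is skipped and never repaired.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict) and REQUIRED_KEYS.issubset(candidate):
            return candidate
        start = text.find("{", start + 1)
    return None
```

Scanning from each opening brace keeps the helper independent of the wrapper format, at the cost of re-parsing nested objects when the outer one is incomplete.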
@@ -995,6 +995,41 @@ def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
assert captured.get("output_version") == "responses/v1"
# ---------------------------------------------------------------------------
# Provider class path resolution
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("model_id", ["mimo-v2.5-pro", "mimo-v2.5", "mimo-v2-flash"])
def test_create_chat_model_resolves_patched_mimo_provider(model_id):
from deerflow.models.patched_mimo import PatchedChatMiMo
model = ModelConfig(
name=f"{model_id}-thinking",
display_name=f"{model_id} Thinking",
description=None,
use="deerflow.models.patched_mimo:PatchedChatMiMo",
model=model_id,
api_key="test-key",
base_url="https://api.xiaomimimo.com/v1",
supports_thinking=True,
when_thinking_enabled={"extra_body": {"thinking": {"type": "enabled"}}},
supports_vision=False,
)
cfg = _make_app_config([model])
chat_model = factory_module.create_chat_model(
name=f"{model_id}-thinking",
thinking_enabled=True,
app_config=cfg,
attach_tracing=False,
)
assert isinstance(chat_model, PatchedChatMiMo)
assert chat_model.model_name == model_id
assert chat_model.extra_body["thinking"]["type"] == "enabled"
# ---------------------------------------------------------------------------
# Duplicate keyword argument collision (issue #1977)
# ---------------------------------------------------------------------------
@@ -0,0 +1,79 @@
"""Regression tests for the generated OpenAPI spec.
The Gateway exposes its FastAPI ``app.openapi()`` schema at ``/openapi.json``
and downstream tooling (SDK codegen, schema validators, client generators)
relies on ``operationId`` values being globally unique. FastAPI emits a
``UserWarning`` during spec generation when two routes share the same
``operationId``. Concretely, this happens when ``@router.api_route`` registers
one route for multiple HTTP methods, because the auto-generated unique id is
computed from a single method picked out of ``route.methods`` while OpenAPI
generation iterates over every method on that route.
These tests pin that invariant so the warning cannot silently come back.
"""
from __future__ import annotations
import warnings
import pytest
@pytest.fixture(scope="module")
def openapi_spec() -> dict:
"""Build the OpenAPI spec for the Gateway app once per module."""
from app.gateway.app import app
# ``app.openapi()`` caches the result on the FastAPI instance, so reset to
# force a fresh generation pass that triggers any duplicate-id warnings.
app.openapi_schema = None
return app.openapi()
def test_openapi_spec_has_no_duplicate_operation_warnings() -> None:
"""Generating the OpenAPI schema must not emit any ``Duplicate Operation ID`` UserWarning."""
from app.gateway.app import app
app.openapi_schema = None
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
app.openapi()
dup_messages = [str(item.message) for item in caught if "Duplicate Operation ID" in str(item.message)]
assert dup_messages == [], f"OpenAPI generation emitted duplicate operation id warnings: {dup_messages}"
def test_openapi_operation_ids_are_unique(openapi_spec: dict) -> None:
"""Every (path, method) operation in the spec must carry a unique ``operationId``."""
op_id_to_locations: dict[str, list[tuple[str, str]]] = {}
for path, path_item in openapi_spec.get("paths", {}).items():
for method, operation in path_item.items():
if not isinstance(operation, dict):
continue
op_id = operation.get("operationId")
if op_id is None:
continue
op_id_to_locations.setdefault(op_id, []).append((path, method))
duplicates = {op_id: locations for op_id, locations in op_id_to_locations.items() if len(locations) > 1}
assert not duplicates, f"Duplicate operationIds in OpenAPI spec: {duplicates}"
def test_stream_existing_run_exposes_distinct_get_and_post(openapi_spec: dict) -> None:
"""The ``/runs/{run_id}/stream`` endpoint must expose GET and POST as distinct operations.
LangGraph SDK ``joinStream`` uses GET while ``useStream``'s stop button uses POST, so
both methods must remain registered with their own ``operationId``.
"""
path = "/api/threads/{thread_id}/runs/{run_id}/stream"
path_item = openapi_spec["paths"].get(path)
assert path_item is not None, f"Expected {path} to be present in the OpenAPI spec"
assert "get" in path_item, f"Expected GET handler on {path}"
assert "post" in path_item, f"Expected POST handler on {path}"
get_op_id = path_item["get"].get("operationId")
post_op_id = path_item["post"].get("operationId")
assert get_op_id and post_op_id, "Both GET and POST must have operationIds"
assert get_op_id != post_op_id, f"GET and POST share operationId {get_op_id!r}, which breaks OpenAPI codegen"
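The module docstring above attributes the duplicates to one id being derived for a route that serves several methods. A common remedy is to derive the id per method instead. The sketch below illustrates that idea in plain Python. `per_method_operation_ids` and its naming scheme are hypothetical; this is neither FastAPI's own id generator nor necessarily the scheme the Gateway uses.

```python
import re


def per_method_operation_ids(name: str, path: str, methods: list[str]) -> dict[str, str]:
    """Build one operationId per HTTP method for a multi-method route.

    Hypothetical naming scheme: sanitised "<name><path>" plus the method.
    """
    base = re.sub(r"\W", "_", f"{name}{path}")
    return {method.lower(): f"{base}_{method.lower()}" for method in methods}


ids = per_method_operation_ids(
    "stream_existing_run",
    "/api/threads/{thread_id}/runs/{run_id}/stream",
    ["GET", "POST"],
)
```

With ids keyed by method, the GET and POST operations on the same path can never collide, which is the invariant `test_stream_existing_run_exposes_distinct_get_and_post` checks.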
@@ -0,0 +1,169 @@
"""Tests for deerflow.models.patched_mimo.PatchedChatMiMo."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
def _make_model(**kwargs):
from deerflow.models.patched_mimo import PatchedChatMiMo
return PatchedChatMiMo(
model="mimo-v2.5-pro",
api_key="test-key",
base_url="https://api.xiaomimimo.com/v1",
**kwargs,
)
def test_is_lc_serializable_returns_true():
from deerflow.models.patched_mimo import PatchedChatMiMo
assert PatchedChatMiMo.is_lc_serializable() is True
def test_lc_secrets_contains_mimo_api_key_mapping():
model = _make_model()
assert model.lc_secrets["api_key"] == "MIMO_API_KEY"
assert model.lc_secrets["openai_api_key"] == "MIMO_API_KEY"
def test_reasoning_content_injected_into_assistant_tool_call_message():
model = _make_model()
human = HumanMessage(content="Check Beijing weather.")
ai = AIMessage(
content="",
additional_kwargs={"reasoning_content": "I need to call the weather tool."},
)
payload_message = {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_weather",
"type": "function",
"function": {"name": "get_weather", "arguments": '{"location":"Beijing"}'},
}
],
}
base_payload = {
"messages": [
{"role": "user", "content": "Check Beijing weather."},
payload_message,
]
}
with patch.object(type(model).__bases__[0], "_get_request_payload", return_value=base_payload):
with patch.object(model, "_convert_input") as mock_convert:
mock_convert.return_value = MagicMock(to_messages=lambda: [human, ai])
payload = model._get_request_payload([human, ai])
assert payload["messages"][1]["reasoning_content"] == "I need to call the weather tool."
def test_reasoning_content_is_noop_when_missing():
model = _make_model()
human = HumanMessage(content="hello")
ai = AIMessage(content="hi", additional_kwargs={})
base_payload = {
"messages": [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
]
}
with patch.object(type(model).__bases__[0], "_get_request_payload", return_value=base_payload):
with patch.object(model, "_convert_input") as mock_convert:
mock_convert.return_value = MagicMock(to_messages=lambda: [human, ai])
payload = model._get_request_payload([human, ai])
assert "reasoning_content" not in payload["messages"][1]
def test_create_chat_result_maps_message_reasoning_content():
model = _make_model()
response = {
"choices": [
{
"message": {
"role": "assistant",
"content": "The weather is sunny.",
"reasoning_content": "The tool returned sunny weather, so answer directly.",
"tool_calls": None,
},
"finish_reason": "stop",
}
],
"model": "mimo-v2.5-pro",
}
result = model._create_chat_result(response)
message = result.generations[0].message
assert message.content == "The weather is sunny."
assert message.additional_kwargs["reasoning_content"] == "The tool returned sunny weather, so answer directly."
def test_create_chat_result_reads_reasoning_content_from_message_attribute():
model = _make_model()
class FakeMessage:
reasoning_content = "Reasoning stored on the SDK message object."
class FakeChoice:
message = FakeMessage()
class FakeResponse:
choices = [FakeChoice()]
def model_dump(self, **kwargs):
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "Answer.",
},
"finish_reason": "stop",
}
],
"model": "mimo-v2.5-pro",
}
result = model._create_chat_result(FakeResponse())
assert result.generations[0].message.additional_kwargs["reasoning_content"] == "Reasoning stored on the SDK message object."
def test_convert_chunk_to_generation_chunk_preserves_reasoning_deltas():
model = _make_model()
first = model._convert_chunk_to_generation_chunk(
{"choices": [{"delta": {"role": "assistant", "reasoning_content": "I need "}}]},
AIMessageChunk,
{},
)
second = model._convert_chunk_to_generation_chunk(
{"choices": [{"delta": {"reasoning_content": "a tool."}}]},
AIMessageChunk,
{},
)
answer = model._convert_chunk_to_generation_chunk(
{"choices": [{"delta": {"content": "Done."}, "finish_reason": "stop"}], "model": "mimo-v2.5-pro"},
AIMessageChunk,
{},
)
assert first is not None
assert second is not None
assert answer is not None
combined = first.message + second.message + answer.message
assert combined.additional_kwargs["reasoning_content"] == "I need a tool."
assert combined.content == "Done."
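The streaming test above depends on chunk addition concatenating string values that share a key, so that `"I need "` and `"a tool."` end up as one `reasoning_content` string. A minimal sketch of that merge rule using plain dictionaries follows; the `merge_deltas` helper is hypothetical and is not LangChain's implementation.

```python
def merge_deltas(deltas):
    """Concatenate string fields across streamed deltas, in arrival order."""
    merged = {}
    for delta in deltas:
        for key, value in delta.items():
            if isinstance(value, str):
                previous = merged.get(key)
                # Only append to an earlier value if it was also a string.
                merged[key] = (previous if isinstance(previous, str) else "") + value
            else:
                # Non-string values: keep the first one seen (an assumption).
                merged.setdefault(key, value)
    return merged


combined = merge_deltas(
    [
        {"role": "assistant", "reasoning_content": "I need "},
        {"reasoning_content": "a tool."},
        {"content": "Done."},
    ]
)
```

Keeping the reasoning text in its own key lets the visible answer (`content`) stay separate from the reasoning that accumulated before it.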
+109
@@ -0,0 +1,109 @@
"""Tests for the DDD run domain skeleton."""

import pytest

from deerflow.runtime.runs import DisconnectMode, RunStatus
from deerflow.runtime.runs.domain import (
    AssistantId,
    CancelAction,
    EventSeq,
    InvalidRunTransition,
    MultitaskStrategy,
    Run,
    RunCancelled,
    RunCompleted,
    RunCreated,
    RunFailed,
    RunId,
    RunScope,
    RunStarted,
    ThreadId,
)
from deerflow.runtime.runs.schemas import DisconnectMode as CompatDisconnectMode
from deerflow.runtime.runs.schemas import RunStatus as CompatRunStatus


def test_compat_schema_exports_use_domain_enums() -> None:
    assert CompatRunStatus is RunStatus
    assert CompatDisconnectMode is DisconnectMode


def test_create_run_records_pending_state_and_created_event() -> None:
    run = Run.create(
        run_id=RunId("run-1"),
        thread_id=ThreadId("thread-1"),
        assistant_id=AssistantId("lead_agent"),
        scope=RunScope.stateful,
        multitask_strategy=MultitaskStrategy.reject,
        metadata={"source": "test"},
        kwargs={"input": {"messages": []}},
        created_at="2026-01-01T00:00:00+00:00",
    )

    assert run.status == RunStatus.pending
    assert run.run_id == "run-1"
    assert run.thread_id == "thread-1"
    assert run.assistant_id == "lead_agent"
    assert run.created_at == "2026-01-01T00:00:00+00:00"
    assert run.updated_at == "2026-01-01T00:00:00+00:00"

    events = run.pull_events()
    assert len(events) == 1
    assert isinstance(events[0], RunCreated)
    assert events[0].metadata == {"source": "test"}
    assert run.pull_events() == ()


def test_run_allows_pending_running_success_transition() -> None:
    run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
    run.pull_events()

    run.mark_started(at="2026-01-01T00:00:01+00:00")
    run.mark_completed(at="2026-01-01T00:00:02+00:00")

    assert run.status == RunStatus.success
    assert run.updated_at == "2026-01-01T00:00:02+00:00"
    events = run.pull_events()
    assert [type(event) for event in events] == [RunStarted, RunCompleted]


def test_run_records_failed_and_cancelled_domain_events() -> None:
    failed = Run.create(run_id=RunId("run-failed"), thread_id=ThreadId("thread-1"))
    failed.pull_events()
    failed.mark_started()
    failed.mark_failed("boom", at="2026-01-01T00:00:03+00:00")

    failed_events = failed.pull_events()
    assert failed.status == RunStatus.error
    assert isinstance(failed_events[-1], RunFailed)
    assert failed_events[-1].status == RunStatus.error
    assert failed_events[-1].error == "boom"

    cancelled = Run.create(run_id=RunId("run-cancelled"), thread_id=ThreadId("thread-1"))
    cancelled.pull_events()
    cancelled.mark_cancelled(action=CancelAction.rollback)

    cancelled_events = cancelled.pull_events()
    assert cancelled.status == RunStatus.interrupted
    assert isinstance(cancelled_events[-1], RunCancelled)
    assert cancelled_events[-1].action == CancelAction.rollback


def test_terminal_run_cannot_transition_again() -> None:
    run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
    run.mark_started()
    run.mark_completed()

    with pytest.raises(InvalidRunTransition) as exc:
        run.mark_failed("too late")

    assert exc.value.current == RunStatus.success
    assert exc.value.target == RunStatus.error


def test_domain_value_objects_validate_minimal_invariants() -> None:
    assert EventSeq(1).next() == EventSeq(2)

    with pytest.raises(ValueError, match="EventSeq"):
        EventSeq(-1)

    with pytest.raises(ValueError, match="run_id"):
        Run.create(run_id=RunId(" "), thread_id=ThreadId("thread-1"))
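
The tests above describe an aggregate that guards its status transitions and buffers domain events until they are pulled. The following is a minimal sketch of that pattern, not the `deerflow.runtime.runs.domain` implementation: `Status`, `ALLOWED`, `InvalidTransition`, and `MiniRun` are hypothetical names, events are plain strings, and the `running -> interrupted` edge is an assumption the tests do not exercise.

```python
from dataclasses import dataclass, field
from enum import Enum


class Status(str, Enum):
    pending = "pending"
    running = "running"
    success = "success"
    error = "error"
    interrupted = "interrupted"


# Transition table inferred from the tests; terminal states have no entry.
ALLOWED = {
    Status.pending: {Status.running, Status.interrupted},
    Status.running: {Status.success, Status.error, Status.interrupted},
}


class InvalidTransition(Exception):
    def __init__(self, current: Status, target: Status) -> None:
        super().__init__(f"cannot move from {current.value} to {target.value}")
        self.current = current
        self.target = target


@dataclass
class MiniRun:
    status: Status = Status.pending
    _events: list = field(default_factory=list)

    def _transition(self, target: Status, event: str) -> None:
        if target not in ALLOWED.get(self.status, set()):
            raise InvalidTransition(self.status, target)
        self.status = target
        self._events.append(event)

    def mark_started(self) -> None:
        self._transition(Status.running, "started")

    def mark_completed(self) -> None:
        self._transition(Status.success, "completed")

    def mark_failed(self, error: str) -> None:
        self._transition(Status.error, f"failed: {error}")

    def pull_events(self) -> tuple:
        # Hand back the buffered events and clear the buffer in one step.
        events, self._events = tuple(self._events), []
        return events
```

Because `pull_events` drains the buffer, a second call returns an empty tuple, which matches the `run.pull_events() == ()` assertion in the test file.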
+19 -14
@@ -96,25 +96,30 @@ class _ScriptedAgent:
         del subgraphs
         self.controller.started.set()
-        thread_id = _thread_id_from_config(config)
-        human_text = _last_human_text(graph_input)
-        human = HumanMessage(content=human_text)
-        ai = await self.model.ainvoke([human], config=config)
-        state = {"messages": [human.model_dump(), ai.model_dump()], "title": self.title}
+        try:
+            thread_id = _thread_id_from_config(config)
+            human_text = _last_human_text(graph_input)
+            human = HumanMessage(content=human_text)
+            ai = await self.model.ainvoke([human], config=config)
+            state = {"messages": [human.model_dump(), ai.model_dump()], "title": self.title}
-        if self.checkpointer is not None:
-            await _write_checkpoint(self.checkpointer, thread_id=thread_id, state=state)
-            self.controller.checkpoint_written.set()
+            if self.checkpointer is not None:
+                await _write_checkpoint(self.checkpointer, thread_id=thread_id, state=state)
+                self.controller.checkpoint_written.set()
-        yield _stream_item_for_mode(stream_mode, state)
+            yield _stream_item_for_mode(stream_mode, state)
-        if self.block_after_first_chunk:
-            try:
+            if self.block_after_first_chunk:
                 while not self.controller.release.is_set():
                     await asyncio.sleep(0.05)
-            except asyncio.CancelledError:
-                self.controller.cancelled.set()
-                raise
+        except asyncio.CancelledError:
+            # Catch cancellation arriving anywhere in the body — including the
+            # `await ainvoke()` / `_write_checkpoint()` / `yield` points between
+            # ``started.set()`` and the original inner ``try`` — so tests that
+            # wait for ``cancelled`` after issuing ``POST /cancel`` no longer
+            # race with cancellation arriving early.
+            self.controller.cancelled.set()
+            raise
 def _make_agent_factory(controller: _RunController, **agent_kwargs):
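
The change in this hunk widens the `try` so that a `CancelledError` raised at any suspension point of the async generator sets the `cancelled` flag. A small self-contained sketch of why that matters is below. It assumes only the standard `asyncio` API; `scripted_stream`, `consume`, and `main` are hypothetical names, and `asyncio.sleep` stands in for the model call and the release wait.

```python
import asyncio


async def scripted_stream(cancelled: asyncio.Event, started: asyncio.Event):
    started.set()
    try:
        await asyncio.sleep(0)  # stands in for the model call
        yield "first chunk"
        await asyncio.sleep(3600)  # stands in for blocking until released
    except asyncio.CancelledError:
        # Reached no matter which await/yield inside the try was interrupted.
        cancelled.set()
        raise


async def main() -> bool:
    cancelled, started = asyncio.Event(), asyncio.Event()

    async def consume() -> None:
        async for _ in scripted_stream(cancelled, started):
            pass

    task = asyncio.create_task(consume())
    await started.wait()
    task.cancel()  # may land before the first chunk is ever yielded
    try:
        await task
    except asyncio.CancelledError:
        pass
    return cancelled.is_set()


result = asyncio.run(main())
```

If the `try` only wrapped the final wait, a cancellation delivered during the first `await` would propagate without setting the flag, and a test waiting on that flag would hang or time out.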
+63
@@ -0,0 +1,63 @@
import stat

from deerflow.skills.permissions import make_skill_tree_sandbox_readable, make_skill_written_path_sandbox_readable


def _mode(path):
    return stat.S_IMODE(path.stat().st_mode)


def test_skill_tree_readability_includes_hidden_paths_and_removes_sandbox_write(tmp_path):
    root = tmp_path / "demo-skill"
    hidden_dir = root / ".hidden"
    scripts_dir = root / "scripts"
    hidden_dir.mkdir(parents=True)
    scripts_dir.mkdir()
    env_file = root / ".env"
    hidden_file = hidden_dir / ".secret"
    script_file = scripts_dir / "run.sh"
    env_file.write_text("secret", encoding="utf-8")
    hidden_file.write_text("secret", encoding="utf-8")
    script_file.write_text("#!/bin/sh\n", encoding="utf-8")

    root.chmod(0o777)
    hidden_dir.chmod(0o777)
    scripts_dir.chmod(0o777)
    env_file.chmod(0o666)
    hidden_file.chmod(0o600)
    script_file.chmod(0o777)

    make_skill_tree_sandbox_readable(root)

    assert _mode(root) == 0o755
    assert _mode(hidden_dir) == 0o755
    assert _mode(scripts_dir) == 0o755
    assert _mode(env_file) == 0o644
    assert _mode(hidden_file) == 0o644
    assert _mode(script_file) == 0o755


def test_written_path_readability_is_limited_to_written_path(tmp_path):
    root = tmp_path / "demo-skill"
    ref_dir = root / "references"
    sibling_dir = root / "templates"
    ref_dir.mkdir(parents=True)
    sibling_dir.mkdir()
    target = ref_dir / "guide.md"
    sibling = sibling_dir / "note.md"
    target.write_text("guide", encoding="utf-8")
    sibling.write_text("note", encoding="utf-8")

    root.chmod(0o700)
    ref_dir.chmod(0o700)
    target.chmod(0o600)
    sibling_dir.chmod(0o700)
    sibling.chmod(0o600)

    make_skill_written_path_sandbox_readable(root, target)

    assert _mode(root) == 0o755
    assert _mode(ref_dir) == 0o755
    assert _mode(target) == 0o644
    assert _mode(sibling_dir) == 0o700
    assert _mode(sibling) == 0o600
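
The expected modes in the first test suggest a simple normalization rule: directories become `0o755`, regular files become `0o644`, and files that were owner-executable become `0o755`. A hedged sketch of such a walk is shown below. The `make_tree_readable` helper is hypothetical and only mirrors what the assertions imply; the real `deerflow.skills.permissions` functions may handle symlinks, errors, and ownership differently. It assumes a POSIX filesystem.

```python
import os
import stat
from pathlib import Path


def make_tree_readable(root: Path) -> None:
    """Set dirs to 0o755 and files to 0o644 (0o755 if owner-executable)."""
    # Fix the root first so the walk below can always list it.
    root.chmod(0o755)
    for dirpath, dirnames, filenames in os.walk(root):
        # os.walk does not skip dotfiles, so hidden paths are included.
        for name in dirnames:
            (Path(dirpath) / name).chmod(0o755)
        for name in filenames:
            path = Path(dirpath) / name
            executable = path.stat().st_mode & stat.S_IXUSR
            path.chmod(0o755 if executable else 0o644)
```

Setting an exact mode, rather than OR-ing in read bits, is what removes group and world write access from paths that started at `0o777` or `0o666`.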
@@ -1,14 +1,18 @@
import errno
import json
import stat
import zipfile
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace
from _router_auth_helpers import make_authed_test_app
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from app.gateway.routers import skills as skills_router
from app.gateway.routers import uploads as uploads_router
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import Skill
@@ -53,6 +57,15 @@ def _make_skill_archive(tmp_path: Path, name: str, content: str | None = None) -
return archive
def _make_skill_archive_bytes(name: str, content: str | None = None) -> bytes:
buffer = BytesIO()
skill_content = content or _skill_content(name)
with zipfile.ZipFile(buffer, "w") as zf:
zf.writestr(f"{name}/SKILL.md", skill_content)
zf.writestr(f"{name}/references/guide.md", "# Guide\n")
return buffer.getvalue()
def test_install_skill_archive_runs_security_scan(monkeypatch, tmp_path):
skills_root = tmp_path / "skills"
(skills_root / "custom").mkdir(parents=True)
@@ -101,6 +114,65 @@ def test_install_skill_archive_runs_security_scan(monkeypatch, tmp_path):
assert refresh_calls == ["refresh"]
def test_uploaded_skill_archive_installs_sandbox_readable_tree(monkeypatch, tmp_path):
home = tmp_path / "home"
skills_root = tmp_path / "skills"
skills_root.mkdir()
refresh_calls = []
async def _scan(*args, **kwargs):
from deerflow.skills.security_scanner import ScanResult
return ScanResult(decision="allow", reason="ok")
async def _refresh():
refresh_calls.append("refresh")
config = SimpleNamespace(
skills=SimpleNamespace(get_skills_path=lambda: skills_root, container_path="/mnt/skills", use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage"),
skill_evolution=SimpleNamespace(enabled=True, moderation_model_name=None),
uploads=SimpleNamespace(auto_convert_documents=False),
)
provider = SimpleNamespace(uses_thread_data_mounts=True)
monkeypatch.setenv("DEER_FLOW_HOME", str(home))
monkeypatch.setattr("deerflow.config.paths._paths", None)
monkeypatch.setattr(uploads_router, "get_sandbox_provider", lambda: provider)
monkeypatch.setattr("deerflow.skills.installer.scan_skill_content", _scan)
monkeypatch.setattr(skills_router, "refresh_skills_system_prompt_cache_async", _refresh)
app = make_authed_test_app()
app.state.config = config
app.dependency_overrides[get_config] = lambda: config
app.include_router(uploads_router.router)
app.include_router(skills_router.router)
thread_id = "thread-uploaded-skill"
archive_bytes = _make_skill_archive_bytes("uploaded-skill")
with TestClient(app) as client:
upload_response = client.post(
f"/api/threads/{thread_id}/uploads",
files=[("files", ("uploaded-skill.skill", archive_bytes, "application/octet-stream"))],
)
assert upload_response.status_code == 200
uploaded_file = upload_response.json()["files"][0]
uploaded_path = Path(uploaded_file["path"])
assert uploaded_path.is_file()
install_response = client.post("/api/skills/install", json={"thread_id": thread_id, "path": uploaded_file["virtual_path"]})
assert install_response.status_code == 200
assert install_response.json()["skill_name"] == "uploaded-skill"
installed_dir = skills_root / "custom" / "uploaded-skill"
nested_dir = installed_dir / "references"
assert stat.S_IMODE(installed_dir.stat().st_mode) & 0o055 == 0o055
assert stat.S_IMODE(nested_dir.stat().st_mode) & 0o055 == 0o055
assert stat.S_IMODE((installed_dir / "SKILL.md").stat().st_mode) & 0o044 == 0o044
assert stat.S_IMODE((nested_dir / "guide.md").stat().st_mode) & 0o044 == 0o044
assert refresh_calls == ["refresh"]
def test_install_skill_archive_security_scan_block_returns_400(monkeypatch, tmp_path):
skills_root = tmp_path / "skills"
(skills_root / "custom").mkdir(parents=True)
@@ -175,6 +247,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
)
assert update_response.status_code == 200
assert update_response.json()["description"] == "Edited skill"
assert stat.S_IMODE((custom_dir / "SKILL.md").stat().st_mode) & 0o044 == 0o044
history_response = client.get("/api/skills/custom/demo-skill/history")
assert history_response.status_code == 200
@@ -183,6 +256,7 @@ def test_custom_skills_router_lifecycle(monkeypatch, tmp_path):
rollback_response = client.post("/api/skills/custom/demo-skill/rollback", json={"history_index": -1})
assert rollback_response.status_code == 200
assert rollback_response.json()["description"] == "Demo skill"
assert stat.S_IMODE((custom_dir / "SKILL.md").stat().st_mode) & 0o044 == 0o044
assert refresh_calls == ["refresh", "refresh"]
+20
@@ -198,6 +198,26 @@ class TestInstallSkillFromArchive:
assert result["skill_name"] == "test-skill"
assert (skills_root / "custom" / "test-skill" / "SKILL.md").exists()
def test_installed_skill_tree_is_readable_by_sandbox_mount(self, tmp_path):
zip_path = tmp_path / "test-skill.skill"
with zipfile.ZipFile(zip_path, "w") as zf:
zf.writestr("test-skill/SKILL.md", "---\nname: test-skill\ndescription: A test skill\n---\n\n# test-skill\n")
zf.writestr("test-skill/references/guide.md", "# Guide\n")
skills_root = tmp_path / "skills"
skills_root.mkdir()
get_or_new_skill_storage(skills_path=skills_root).install_skill_from_archive(zip_path)
installed_dir = skills_root / "custom" / "test-skill"
nested_dir = installed_dir / "references"
skill_file = installed_dir / "SKILL.md"
guide_file = nested_dir / "guide.md"
assert stat.S_IMODE(installed_dir.stat().st_mode) & 0o055 == 0o055
assert stat.S_IMODE(nested_dir.stat().st_mode) & 0o055 == 0o055
assert stat.S_IMODE(skill_file.stat().st_mode) & 0o044 == 0o044
assert stat.S_IMODE(guide_file.stat().st_mode) & 0o044 == 0o044
def test_scans_skill_markdown_before_install(self, tmp_path, monkeypatch):
zip_path = self._make_skill_zip(tmp_path)
skills_root = tmp_path / "skills"
@@ -5,7 +5,10 @@ from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain.agents import create_agent
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
@@ -22,6 +25,23 @@ def _messages() -> list:
]
class _StaticChatModel(BaseChatModel):
text: str = "ok"
@property
def _llm_type(self) -> str:
return "static-test-chat-model"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
return HumanMessage(
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
@@ -114,6 +134,32 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
assert result["messages"][1].content.startswith("Here is a summary")
def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None:
middleware = DeerFlowSummarizationMiddleware(
model=_StaticChatModel(text="compressed summary"),
trigger=("messages", 4),
keep=("messages", 2),
token_counter=len,
)
agent = create_agent(
model=_StaticChatModel(text="done"),
tools=[],
middleware=[middleware],
)
chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates"))
update = next(
(chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk),
None,
)
assert update is not None
emitted = update["messages"]
assert isinstance(emitted[0], RemoveMessage)
assert emitted[1].name == "summary"
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
@@ -134,12 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
     middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
     assert captured["app_config"] is app_config
-    # 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
-    # SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
-    # (enabled by default — see SafetyFinishReasonConfig).
+    # 7 baseline (ToolOutputBudget, ThreadData, Sandbox, DanglingToolCall,
+    # LLMErrorHandling, SandboxAudit, ToolErrorHandling)
+    # + 1 SafetyFinishReasonMiddleware (enabled by default).
     from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
+    from deerflow.agents.middlewares.tool_output_budget_middleware import ToolOutputBudgetMiddleware
-    assert len(middlewares) == 7
+    assert len(middlewares) == 8
+    assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
     assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
     assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
@@ -0,0 +1,890 @@
"""Comprehensive tests for ToolOutputBudgetMiddleware.
Covers: pass-through, disk externalization, fallback truncation, UTF-8
boundaries, Command results, model-request history patching, config
variations, exempt tools, per-tool overrides, edge cases, and both
sync/async code paths.
"""
from __future__ import annotations
import os
import tempfile
from types import SimpleNamespace
import pytest
from langchain.agents.middleware.types import ModelRequest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.types import Command
from deerflow.agents.middlewares.tool_output_budget_middleware import (
ToolOutputBudgetMiddleware,
_build_fallback,
_build_preview,
_effective_trigger,
_externalize,
_message_text,
_needs_budget,
_patch_model_messages,
_sanitize_tool_name,
_snap_to_line_boundary,
_tool_message_over_budget,
)
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.tool_output_config import ToolOutputConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(tool_name: str = "remote_executor", tool_call_id: str = "tc-1", outputs_path: str | None = None) -> SimpleNamespace:
thread_data = {"outputs_path": outputs_path} if outputs_path else None
state = {"thread_data": thread_data} if thread_data else {}
runtime = SimpleNamespace(state=state)
return SimpleNamespace(
tool_call={"name": tool_name, "id": tool_call_id},
runtime=runtime,
)
def _tm(content: str = "ok", name: str = "tool", tool_call_id: str = "tc-1") -> ToolMessage:
return ToolMessage(content=content, name=name, tool_call_id=tool_call_id)
# ===========================================================================
# Unit tests for helper functions
# ===========================================================================
class TestMessageText:
def test_string_content(self):
assert _message_text("hello") == "hello"
def test_none_content(self):
assert _message_text(None) is None
def test_list_of_strings(self):
assert _message_text(["a", "b"]) == "a\nb"
def test_list_of_text_dicts(self):
assert _message_text([{"text": "x"}, {"text": "y"}]) == "x\ny"
def test_list_with_image_returns_none(self):
assert _message_text([{"type": "image", "data": "..."}]) is None
def test_empty_list(self):
assert _message_text([]) is None
def test_non_string_non_list(self):
assert _message_text(42) is None
class TestSnapToLineBoundary:
def test_snaps_to_newline(self):
text = "line1\nline2\nline3"
pos = 14 # inside "line3"
result = _snap_to_line_boundary(text, pos)
assert text[result - 1] == "\n"
def test_no_snap_when_no_newline_in_range(self):
text = "abcdefghij"
assert _snap_to_line_boundary(text, 8) == 8
def test_zero_pos(self):
assert _snap_to_line_boundary("abc", 0) == 0
def test_pos_beyond_length(self):
assert _snap_to_line_boundary("abc", 10) == 10
class TestExternalize:
def test_writes_file_and_returns_virtual_path(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = _externalize(
"full content here",
tool_name="bash",
tool_call_id="tc-1",
outputs_path=tmpdir,
storage_subdir=".tool-results",
)
assert path is not None
assert path.startswith("/mnt/user-data/outputs/.tool-results/bash-")
assert path.endswith(".log")
# Verify actual file on disk
storage_dir = os.path.join(tmpdir, ".tool-results")
files = os.listdir(storage_dir)
assert len(files) == 1
with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f:
assert f.read() == "full content here"
def test_returns_none_on_invalid_path(self):
path = _externalize(
"data",
tool_name="test",
tool_call_id="tc-1",
outputs_path="/nonexistent/path/that/should/not/exist",
storage_subdir=".tool-results",
)
assert path is None
def test_txt_extension_for_unknown_tool(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = _externalize(
"data",
tool_name="unknown_tool",
tool_call_id="tc-1",
outputs_path=tmpdir,
storage_subdir=".tool-results",
)
assert path is not None
assert path.endswith(".txt")
class TestSanitizeToolName:
def test_strips_path_separators(self):
assert _sanitize_tool_name("../../etc/passwd") == "passwd"
def test_strips_backslashes(self):
result = _sanitize_tool_name("..\\..\\windows\\system32")
assert ".." not in result
assert "/" not in result
def test_normal_name_unchanged(self):
assert _sanitize_tool_name("bash") == "bash"
def test_empty_becomes_unknown(self):
assert _sanitize_tool_name("") == "unknown"
def test_dots_only_becomes_unknown(self):
assert _sanitize_tool_name("..") == "unknown"
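
The sanitizer tests above pin down the observable behavior: path components are dropped, backslashes count as separators, and names that reduce to nothing become `"unknown"`. One way to satisfy those cases is sketched here. The `sanitize_tool_name` function is a hypothetical stand-in, and the character whitelist is an assumption that the tests do not specify.

```python
import re


def sanitize_tool_name(name: str) -> str:
    """Keep only the final path component and make it filename-safe."""
    # Treat both separator styles alike, then keep the last component.
    last = name.replace("\\", "/").rsplit("/", 1)[-1]
    # Replace anything outside a conservative whitelist; strip edge dots
    # so "." and ".." cannot survive as a file name.
    cleaned = re.sub(r"[^A-Za-z0-9_.-]", "_", last).strip(".")
    return cleaned or "unknown"
```

Reducing the name to a single safe component before it is joined into a storage path is what keeps a tool name such as `../../etc/passwd` from escaping the outputs directory.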
class TestExternalizePathTraversal:
def test_traversal_tool_name_is_sanitized(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = _externalize(
"data",
tool_name="../../etc/passwd",
tool_call_id="tc-1",
outputs_path=tmpdir,
storage_subdir=".tool-results",
)
assert path is not None
assert "passwd-" in path
assert "../" not in path
def test_absolute_storage_subdir_rejected(self):
path = _externalize(
"data",
tool_name="tool",
tool_call_id="tc-1",
outputs_path="/tmp",
storage_subdir="/etc/evil",
)
assert path is None
def test_traversal_storage_subdir_rejected(self):
path = _externalize(
"data",
tool_name="tool",
tool_call_id="tc-1",
outputs_path="/tmp",
storage_subdir="../../../etc",
)
assert path is None
class TestNeedsBudget:
def test_small_output_does_not_need_budget(self):
config = ToolOutputConfig(externalize_min_chars=1000)
msg = _tm("small", name="tool")
assert _needs_budget(msg, config) is False
def test_large_output_needs_budget(self):
config = ToolOutputConfig(externalize_min_chars=50)
msg = _tm("x" * 100, name="tool")
assert _needs_budget(msg, config) is True
def test_exempt_tool_does_not_need_budget(self):
config = ToolOutputConfig(externalize_min_chars=10)
msg = _tm("x" * 100, name="read_file")
assert _needs_budget(msg, config) is False
def test_multimodal_does_not_need_budget(self):
config = ToolOutputConfig(externalize_min_chars=10)
msg = ToolMessage(content=[{"type": "image", "data": "x" * 100}], name="tool", tool_call_id="tc-1")
assert _needs_budget(msg, config) is False
class TestBuildPreview:
def test_contains_head_and_tail_and_reference(self):
content = "HEAD_" + "x" * 5000 + "_TAIL"
preview = _build_preview(
content,
tool_name="bash",
virtual_path="/mnt/test/bash-abc.log",
head_chars=100,
tail_chars=50,
)
assert preview.startswith("HEAD_")
assert "_TAIL" in preview
assert "/mnt/test/bash-abc.log" in preview
assert "read_file" in preview
assert "start_line and end_line" in preview
def test_reports_total_chars(self):
content = "a" * 10000
preview = _build_preview(
content,
tool_name="web_search",
virtual_path="/mnt/test/file.txt",
head_chars=200,
tail_chars=100,
)
assert "10000 chars" in preview
class TestBuildFallback:
def test_short_content_unchanged(self):
assert _build_fallback("short", tool_name="t", max_chars=100, head_chars=50, tail_chars=50) == "short"
def test_zero_max_disables(self):
content = "a" * 1000
assert _build_fallback(content, tool_name="t", max_chars=0, head_chars=50, tail_chars=50) == content
def test_truncates_long_content(self):
content = "H" * 5000 + "M" * 20000 + "T" * 5000
result = _build_fallback(content, tool_name="bash", max_chars=12000, head_chars=6000, tail_chars=3000)
assert len(result) < len(content)
assert "omitted from bash output" in result
assert "Persistent storage unavailable" in result
def test_preserves_head_and_tail(self):
content = "HEADSTART" + "x" * 50000 + "TAILEND"
result = _build_fallback(content, tool_name="t", max_chars=20000, head_chars=10000, tail_chars=5000)
assert result.startswith("HEADSTART")
assert "TAILEND" in result
def test_result_never_exceeds_max_chars(self):
"""The marker itself has non-zero length; total must still respect max_chars."""
for max_chars in [200, 500, 1000, 5000, 20000]:
content = "x" * 50000
result = _build_fallback(content, tool_name="long_tool_name", max_chars=max_chars, head_chars=max_chars // 2, tail_chars=max_chars // 4)
assert len(result) <= max_chars, f"max_chars={max_chars}: got {len(result)}"
def test_very_small_max_chars_does_not_crash(self):
content = "x" * 1000
result = _build_fallback(content, tool_name="t", max_chars=50, head_chars=20, tail_chars=10)
assert len(result) <= 50
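
`test_result_never_exceeds_max_chars` notes that the omission marker has its own length, so the head and tail budgets have to shrink to make room for it. A minimal sketch of a head-plus-tail truncation that respects this is given below. `head_tail_truncate` and its fixed marker text are hypothetical; the real `_build_fallback` also names the tool and reports storage status in its marker.

```python
def head_tail_truncate(content: str, *, max_chars: int, head_chars: int, tail_chars: int) -> str:
    """Keep a head and a tail of content; the result never exceeds max_chars."""
    if max_chars <= 0 or len(content) <= max_chars:
        # A non-positive limit disables truncation; short content is untouched.
        return content
    marker = "\n[... output truncated ...]\n"
    if len(marker) >= max_chars:
        # No room for the marker at all: fall back to a plain head cut.
        return content[:max_chars]
    budget = max_chars - len(marker)
    head = min(head_chars, budget)
    tail = min(tail_chars, budget - head)
    # Slicing from len(content) - tail stays correct when tail is 0.
    return content[:head] + marker + content[len(content) - tail:]
```

Clamping the head first and giving the tail whatever budget remains keeps the start of the output intact when the limit is tight, which is usually where a command's most identifying lines are.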
# ===========================================================================
# Middleware integration tests — wrap_tool_call
# ===========================================================================
class TestWrapToolCallPassThrough:
def test_small_output_passes_through(self):
mw = ToolOutputBudgetMiddleware(config=ToolOutputConfig(externalize_min_chars=1000))
msg = _tm("small output", name="bash")
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
assert result is msg
def test_disabled_middleware_passes_through(self):
mw = ToolOutputBudgetMiddleware(config=ToolOutputConfig(enabled=False, externalize_min_chars=10, fallback_max_chars=20))
msg = _tm("x" * 50000, name="bash")
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
assert result is msg
class TestWrapToolCallExternalize:
def test_oversized_output_externalized_to_disk(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(externalize_min_chars=100, preview_head_chars=50, preview_tail_chars=30)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 500
msg = _tm(content, name="remote_executor")
req = _make_request(outputs_path=tmpdir)
result = mw.wrap_tool_call(req, lambda _: msg)
assert isinstance(result, ToolMessage)
assert result is not msg
assert "Full remote_executor output saved to" in result.content
assert "read_file" in result.content
assert result.tool_call_id == "tc-1"
# Verify file was written
storage_dir = os.path.join(tmpdir, ".tool-results")
assert os.path.isdir(storage_dir)
files = os.listdir(storage_dir)
assert len(files) == 1
with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f:
assert f.read() == content
def test_preview_contains_head_and_tail(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
mw = ToolOutputBudgetMiddleware(config=config)
content = "HEADPART_" + "m" * 200 + "_TAILPART"
msg = _tm(content, name="web_search")
req = _make_request(outputs_path=tmpdir)
result = mw.wrap_tool_call(req, lambda _: msg)
assert result.content.startswith("HEADPART_")
assert "_TAILPART" in result.content
class TestWrapToolCallFallback:
def test_fallback_when_no_outputs_path(self):
config = ToolOutputConfig(
externalize_min_chars=50,
fallback_max_chars=200,
fallback_head_chars=80,
fallback_tail_chars=40,
)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 500
msg = _tm(content, name="mcp_tool")
req = _make_request(outputs_path=None)
result = mw.wrap_tool_call(req, lambda _: msg)
assert isinstance(result, ToolMessage)
assert result is not msg
assert "omitted from mcp_tool output" in result.content
assert "Persistent storage unavailable" in result.content
assert len(result.content) < len(content)
def test_fallback_when_disk_write_fails(self):
config = ToolOutputConfig(
externalize_min_chars=50,
fallback_max_chars=200,
fallback_head_chars=80,
fallback_tail_chars=40,
)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 500
msg = _tm(content, name="tool")
req = _make_request(outputs_path="/nonexistent/impossible/path")
result = mw.wrap_tool_call(req, lambda _: msg)
assert isinstance(result, ToolMessage)
assert "omitted from tool output" in result.content
class TestWrapToolCallExemption:
def test_read_file_exempt(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 100
msg = _tm(content, name="read_file")
result = mw.wrap_tool_call(_make_request(tool_name="read_file"), lambda _: msg)
assert result is msg
def test_read_file_tool_exempt(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 100
msg = _tm(content, name="read_file_tool")
result = mw.wrap_tool_call(_make_request(tool_name="read_file_tool"), lambda _: msg)
assert result is msg
def test_custom_exempt_tool(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=50, exempt_tools=["my_tool"])
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 100
msg = _tm(content, name="my_tool")
result = mw.wrap_tool_call(_make_request(tool_name="my_tool"), lambda _: msg)
assert result is msg
class TestWrapToolCallPerToolOverride:
def test_per_tool_threshold(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(
externalize_min_chars=50000, # global: high
tool_overrides={"sensitive_tool": 100}, # override: low
)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 500
msg = _tm(content, name="sensitive_tool")
req = _make_request(tool_name="sensitive_tool", outputs_path=tmpdir)
result = mw.wrap_tool_call(req, lambda _: msg)
assert result is not msg
assert "Full sensitive_tool output saved to" in result.content
def test_per_tool_zero_disables_externalization(self):
config = ToolOutputConfig(
externalize_min_chars=50,
tool_overrides={"bash": 0},
fallback_max_chars=200,
fallback_head_chars=80,
fallback_tail_chars=40,
)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 500
msg = _tm(content, name="bash")
# Even with outputs_path, externalization disabled for bash
req = _make_request(tool_name="bash", outputs_path="/tmp/test")
result = mw.wrap_tool_call(req, lambda _: msg)
assert isinstance(result, ToolMessage)
# Should use fallback instead of externalization
assert "Persistent storage unavailable" in result.content or "omitted" in result.content
class TestWrapToolCallCommand:
def test_command_messages_are_patched(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
mw = ToolOutputBudgetMiddleware(config=config)
tool_msg = _tm("x" * 200, name="present_files")
command = Command(update={"messages": [tool_msg], "artifacts": ["/mnt/report.html"]})
req = _make_request(tool_name="present_files", outputs_path=tmpdir)
result = mw.wrap_tool_call(req, lambda _: command)
assert isinstance(result, Command)
assert result is not command
assert result.update["artifacts"] == ["/mnt/report.html"]
new_msg = result.update["messages"][0]
assert isinstance(new_msg, ToolMessage)
assert "Full present_files output saved to" in new_msg.content
def test_command_without_messages_unchanged(self):
config = ToolOutputConfig(externalize_min_chars=10)
mw = ToolOutputBudgetMiddleware(config=config)
command = Command(update={"key": "value"})
result = mw.wrap_tool_call(_make_request(), lambda _: command)
assert result is command
class TestWrapToolCallEdgeCases:
def test_none_content_passes_through(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
mw = ToolOutputBudgetMiddleware(config=config)
msg = ToolMessage(content=None, name="tool", tool_call_id="tc-1")
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
assert result is msg
def test_empty_string_passes_through(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
mw = ToolOutputBudgetMiddleware(config=config)
msg = _tm("", name="tool")
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
assert result is msg
def test_multimodal_content_skipped(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
mw = ToolOutputBudgetMiddleware(config=config)
content = [{"type": "image", "data": "x" * 100}]
msg = ToolMessage(content=content, name="tool", tool_call_id="tc-1")
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
assert result is msg
def test_exactly_at_threshold_passes_through(self):
config = ToolOutputConfig(externalize_min_chars=100, fallback_max_chars=100)
mw = ToolOutputBudgetMiddleware(config=config)
msg = _tm("x" * 100, name="tool")
result = mw.wrap_tool_call(_make_request(), lambda _: msg)
assert result is msg
def test_one_char_over_threshold_triggers(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(externalize_min_chars=100)
mw = ToolOutputBudgetMiddleware(config=config)
msg = _tm("x" * 101, name="tool")
req = _make_request(outputs_path=tmpdir)
result = mw.wrap_tool_call(req, lambda _: msg)
assert result is not msg
def test_chinese_content_preserved(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
mw = ToolOutputBudgetMiddleware(config=config)
content = "你好世界" * 50
msg = _tm(content, name="tool")
req = _make_request(outputs_path=tmpdir)
result = mw.wrap_tool_call(req, lambda _: msg)
assert isinstance(result, ToolMessage)
# File should contain the full Chinese content
storage_dir = os.path.join(tmpdir, ".tool-results")
files = os.listdir(storage_dir)
with open(os.path.join(storage_dir, files[0]), encoding="utf-8") as f:
assert f.read() == content
def test_no_runtime_state_uses_fallback(self):
config = ToolOutputConfig(
externalize_min_chars=50,
fallback_max_chars=500,
fallback_head_chars=100,
fallback_tail_chars=50,
)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 1000
msg = _tm(content, name="tool")
req = SimpleNamespace(
tool_call={"name": "tool", "id": "tc-1"},
runtime=None,
)
result = mw.wrap_tool_call(req, lambda _: msg)
assert isinstance(result, ToolMessage)
assert "omitted" in result.content
assert len(result.content) <= 500
# ===========================================================================
# MCP content_and_artifact format tests
# ===========================================================================
class TestMCPContentAndArtifact:
"""MCP tools return content as list of content blocks, not plain strings."""
def test_text_content_blocks_are_budgeted(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
mw = ToolOutputBudgetMiddleware(config=config)
content = [{"type": "text", "text": "x" * 200}]
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mcp")
req = _make_request(tool_name="mcp_tool", outputs_path=tmpdir)
result = mw.wrap_tool_call(req, lambda _: msg)
assert result is not msg
assert isinstance(result.content, str)
assert "Full mcp_tool output saved to" in result.content
assert result.tool_call_id == "tc-mcp"
def test_multiple_text_blocks_joined_and_budgeted(self):
config = ToolOutputConfig(externalize_min_chars=50, fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
mw = ToolOutputBudgetMiddleware(config=config)
content = [{"type": "text", "text": "a" * 300}, {"type": "text", "text": "b" * 300}]
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mcp2")
req = _make_request(tool_name="mcp_tool")
result = mw.wrap_tool_call(req, lambda _: msg)
assert result is not msg
assert "omitted" in result.content
def test_image_content_blocks_are_skipped(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20)
mw = ToolOutputBudgetMiddleware(config=config)
content = [{"type": "image", "data": "base64data" * 100}]
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-img")
req = _make_request(tool_name="mcp_tool")
result = mw.wrap_tool_call(req, lambda _: msg)
assert result is msg
def test_mixed_text_and_image_blocks_are_skipped(self):
config = ToolOutputConfig(externalize_min_chars=10)
mw = ToolOutputBudgetMiddleware(config=config)
content = [{"type": "text", "text": "x" * 100}, {"type": "image", "data": "base64"}]
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-mix")
req = _make_request(tool_name="mcp_tool")
result = mw.wrap_tool_call(req, lambda _: msg)
assert result is msg
def test_small_text_blocks_pass_through(self):
config = ToolOutputConfig(externalize_min_chars=1000)
mw = ToolOutputBudgetMiddleware(config=config)
content = [{"type": "text", "text": "small result"}]
msg = ToolMessage(content=content, name="mcp_tool", tool_call_id="tc-sm")
req = _make_request(tool_name="mcp_tool")
result = mw.wrap_tool_call(req, lambda _: msg)
assert result is msg
# ===========================================================================
# Async path tests
# ===========================================================================
class TestAsyncPaths:
@pytest.mark.anyio
async def test_async_tool_call_externalized(self):
with tempfile.TemporaryDirectory() as tmpdir:
config = ToolOutputConfig(externalize_min_chars=50, preview_head_chars=20, preview_tail_chars=10)
mw = ToolOutputBudgetMiddleware(config=config)
content = "x" * 200
msg = _tm(content, name="async_tool")
req = _make_request(tool_name="async_tool", outputs_path=tmpdir)
async def handler(_):
return msg
result = await mw.awrap_tool_call(req, handler)
assert isinstance(result, ToolMessage)
assert result is not msg
assert "Full async_tool output saved to" in result.content
@pytest.mark.anyio
async def test_async_model_call_patches_history(self):
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
mw = ToolOutputBudgetMiddleware(config=config)
oversized = _tm("h" * 1000, name="tool", tool_call_id="tc-h")
request = ModelRequest(model=None, messages=[oversized], tools=[], state={})
captured: dict[str, ModelRequest] = {}
async def handler(req):
captured["request"] = req
return []
await mw.awrap_model_call(request, handler)
forwarded = captured["request"]
assert forwarded is not request
msg = forwarded.messages[0]
assert isinstance(msg, ToolMessage)
assert "omitted" in msg.content
# ===========================================================================
# wrap_model_call — historical message patching
# ===========================================================================
class TestWrapModelCall:
def test_oversized_historical_messages_truncated(self):
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
mw = ToolOutputBudgetMiddleware(config=config)
oversized = _tm("q" * 1000, name="tool", tool_call_id="tc-q")
request = ModelRequest(model=None, messages=[oversized], tools=[], state={})
captured: dict[str, ModelRequest] = {}
def handler(req):
captured["request"] = req
return []
mw.wrap_model_call(request, handler)
forwarded = captured["request"]
assert forwarded is not request
msg = forwarded.messages[0]
assert isinstance(msg, ToolMessage)
assert "omitted" in msg.content
assert len(msg.content) < len(oversized.content) + 150
def test_small_historical_messages_unchanged(self):
config = ToolOutputConfig(fallback_max_chars=1000)
mw = ToolOutputBudgetMiddleware(config=config)
small = _tm("small", name="tool")
request = ModelRequest(model=None, messages=[small], tools=[], state={})
captured: dict[str, ModelRequest] = {}
def handler(req):
captured["request"] = req
return []
mw.wrap_model_call(request, handler)
assert captured["request"] is request
def test_exempt_tools_in_history_unchanged(self):
config = ToolOutputConfig(fallback_max_chars=50)
mw = ToolOutputBudgetMiddleware(config=config)
read_msg = _tm("x" * 200, name="read_file", tool_call_id="tc-r")
request = ModelRequest(model=None, messages=[read_msg], tools=[], state={})
captured: dict[str, ModelRequest] = {}
def handler(req):
captured["request"] = req
return []
mw.wrap_model_call(request, handler)
assert captured["request"] is request
def test_non_tool_messages_preserved(self):
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
mw = ToolOutputBudgetMiddleware(config=config)
human = HumanMessage(content="x" * 200)
ai = AIMessage(content="y" * 200)
oversized_tool = _tm("z" * 1000, name="tool")
request = ModelRequest(model=None, messages=[human, ai, oversized_tool], tools=[], state={})
captured: dict[str, ModelRequest] = {}
def handler(req):
captured["request"] = req
return []
mw.wrap_model_call(request, handler)
msgs = captured["request"].messages
assert msgs[0] is human
assert msgs[1] is ai
assert isinstance(msgs[2], ToolMessage)
assert "omitted" in msgs[2].content
# ===========================================================================
# Config integration
# ===========================================================================
class TestFromAppConfig:
def test_from_app_config_with_tool_output(self):
config = AppConfig(
sandbox=SandboxConfig(use="test"),
tool_output={"externalize_min_chars": 5000, "preview_head_chars": 500},
)
mw = ToolOutputBudgetMiddleware.from_app_config(config)
assert mw._config.externalize_min_chars == 5000
assert mw._config.preview_head_chars == 500
def test_from_app_config_defaults(self):
config = AppConfig(sandbox=SandboxConfig(use="test"))
mw = ToolOutputBudgetMiddleware.from_app_config(config)
assert mw._config.externalize_min_chars == 12000
class TestPatchModelMessages:
def test_returns_none_when_no_changes(self):
config = ToolOutputConfig(fallback_max_chars=1000)
messages = [_tm("short", name="tool")]
assert _patch_model_messages(messages, config) is None
def test_patches_oversized_messages(self):
config = ToolOutputConfig(fallback_max_chars=500, fallback_head_chars=100, fallback_tail_chars=50)
messages = [_tm("x" * 1000, name="tool")]
result = _patch_model_messages(messages, config)
assert result is not None
assert len(result) == 1
assert "omitted" in result[0].content
# ===========================================================================
# Pre-scan helpers (_effective_trigger / _tool_message_over_budget / _needs_budget)
# These guard the fast-path optimization — a false negative here is a real bug
# (budgeting silently skipped), so per-tool overrides must be honored.
# ===========================================================================
class TestPreScanHelpers:
def test_effective_trigger_uses_global_externalize(self):
config = ToolOutputConfig(externalize_min_chars=12000, fallback_max_chars=30000)
# the smaller of the two thresholds wins
assert _effective_trigger("any_tool", config) == 12000
def test_effective_trigger_respects_per_tool_override(self):
config = ToolOutputConfig(externalize_min_chars=50000, fallback_max_chars=0, tool_overrides={"sensitive": 100})
assert _effective_trigger("sensitive", config) == 100
# other tools fall back to the (high) global
assert _effective_trigger("other", config) == 50000
def test_effective_trigger_per_tool_zero_falls_to_fallback(self):
config = ToolOutputConfig(externalize_min_chars=50, tool_overrides={"bash": 0}, fallback_max_chars=200)
# externalize disabled for bash → only fallback can trigger
assert _effective_trigger("bash", config) == 200
def test_effective_trigger_returns_negative_when_fully_disabled(self):
config = ToolOutputConfig(externalize_min_chars=0, fallback_max_chars=0)
assert _effective_trigger("any", config) == -1
def test_pre_scan_does_not_short_circuit_per_tool_override(self):
"""Regression: pre-scan must honor per-tool overrides, not just global threshold."""
config = ToolOutputConfig(externalize_min_chars=50000, fallback_max_chars=0, tool_overrides={"sensitive": 100})
msg = _tm("x" * 500, name="sensitive")
# 500 < global 50000 but > per-tool 100 → must still be flagged
assert _tool_message_over_budget(msg, config) is True
assert _needs_budget(msg, config) is True
def test_exempt_tool_never_over_budget(self):
config = ToolOutputConfig(externalize_min_chars=10, fallback_max_chars=20, exempt_tools=["read_file"])
msg = _tm("x" * 1000, name="read_file")
assert _tool_message_over_budget(msg, config) is False
def test_model_call_pre_scan_skips_when_nothing_oversized(self):
"""_patch_model_messages returns None (no list rebuild) when all messages are small."""
config = ToolOutputConfig(externalize_min_chars=12000, fallback_max_chars=30000)
messages = [_tm("small", name="tool"), HumanMessage(content="hi"), _tm("also small", name="bash")]
assert _patch_model_messages(messages, config) is None
# ===========================================================================
# Middleware ordering in the chain
# ===========================================================================
class TestMiddlewareChainIntegration:
def test_budget_middleware_is_first_in_chain(self):
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
app_config = AppConfig(sandbox=SandboxConfig(use="test"))
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
def test_budget_middleware_in_lead_chain(self):
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
app_config = AppConfig(sandbox=SandboxConfig(use="test"))
middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=False)
assert isinstance(middlewares[0], ToolOutputBudgetMiddleware)
# ===========================================================================
# Config version bump
# ===========================================================================
class TestConfigVersion:
def test_config_version_bumped(self):
import yaml
example_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.example.yaml")
if os.path.exists(example_path):
with open(example_path, encoding="utf-8") as f:
data = yaml.safe_load(f)
assert data.get("config_version", 0) >= 11
def test_config_example_has_tool_output_section(self):
import yaml
example_path = os.path.join(os.path.dirname(__file__), "..", "..", "config.example.yaml")
if os.path.exists(example_path):
with open(example_path, encoding="utf-8") as f:
data = yaml.safe_load(f)
assert "tool_output" in data
tool_output = data["tool_output"]
assert tool_output["enabled"] is True
assert tool_output["externalize_min_chars"] == 12000
assert "read_file" in tool_output["exempt_tools"]
@@ -0,0 +1,177 @@
"""Regression tests for issue #3265.
The non-streaming ``/wait`` endpoints used to ``await record.task`` with no
disconnect handling and silently swallow ``CancelledError``. When a long
tool call (e.g. ``pip install`` inside a custom skill) kept the connection
idle long enough for an intermediate HTTP layer to time out, the handler
would return a stale checkpoint that looked like a normal completion.
The fix introduces ``wait_for_run_completion`` in ``app.gateway.services``:
it subscribes to the stream bridge until ``END_SENTINEL``, polls
``request.is_disconnected()`` on every wake-up, and honours the record's
``on_disconnect`` mode by cancelling the background run on real client
disconnect.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any
from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.schemas import DisconnectMode
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
THREAD_ID = "thread-wait-3265"
@dataclass
class _FakeRequest:
"""Minimal stand-in for FastAPI ``Request`` with controllable disconnect.
``is_disconnected`` is awaited each iteration of the helper's loop, so the
counter lets a test transition from "still connected" to "disconnected"
after N polls without racing the event loop.
"""
disconnect_after: int = 10**9 # effectively "never" by default
_polls: int = 0
async def is_disconnected(self) -> bool:
self._polls += 1
return self._polls > self.disconnect_after
async def _create_running_record(mgr: RunManager, *, on_disconnect: DisconnectMode) -> Any:
record = await mgr.create_or_reject(
THREAD_ID,
assistant_id=None,
on_disconnect=on_disconnect,
)
await mgr.set_status(record.run_id, RunStatus.running)
return record
# ---------------------------------------------------------------------------
# Helper-level unit tests
# ---------------------------------------------------------------------------
class TestWaitForRunCompletion:
def test_returns_when_run_publishes_end(self) -> None:
"""Happy path: helper returns once the bridge publishes END_SENTINEL."""
from app.gateway.services import wait_for_run_completion
async def run() -> None:
mgr = RunManager()
bridge = MemoryStreamBridge()
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
request = _FakeRequest()
async def finish_soon() -> None:
await asyncio.sleep(0)
await bridge.publish(record.run_id, "values", {"messages": []})
await mgr.set_status(record.run_id, RunStatus.success)
await bridge.publish_end(record.run_id)
asyncio.create_task(finish_soon())
completed = await asyncio.wait_for(
wait_for_run_completion(bridge, record, request, mgr),
timeout=2.0,
)
assert completed is True
assert record.status == RunStatus.success
asyncio.run(run())
def test_cancels_run_on_disconnect_when_cancel_mode(self) -> None:
"""on_disconnect=cancel: real disconnect must call run_mgr.cancel()."""
from app.gateway.services import wait_for_run_completion
async def run() -> None:
mgr = RunManager()
bridge = MemoryStreamBridge()
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
# Attach a real (idle) task so cancel() actually has something to cancel.
sleeper = asyncio.create_task(asyncio.sleep(30))
record.task = sleeper
request = _FakeRequest(disconnect_after=0) # disconnected on first poll
async def publish_until_cancel() -> None:
# Emit one event so subscribe wakes up immediately; helper polls
# is_disconnected after each yield.
await asyncio.sleep(0)
await bridge.publish(record.run_id, "values", {"step": 1})
asyncio.create_task(publish_until_cancel())
completed = await asyncio.wait_for(
wait_for_run_completion(bridge, record, request, mgr),
timeout=2.0,
)
assert completed is False
assert record.status == RunStatus.interrupted
# Drain the cancelled sleeper so it does not linger past the test.
try:
await asyncio.wait_for(sleeper, timeout=1.0)
except asyncio.CancelledError:
pass
assert sleeper.done()
asyncio.run(run())
def test_does_not_cancel_when_continue_mode(self) -> None:
"""on_disconnect=continue: disconnect must NOT cancel the run."""
from app.gateway.services import wait_for_run_completion
async def run() -> None:
mgr = RunManager()
bridge = MemoryStreamBridge()
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.continue_)
sleeper = asyncio.create_task(asyncio.sleep(30))
record.task = sleeper
request = _FakeRequest(disconnect_after=0)
async def publish_then_end() -> None:
await asyncio.sleep(0)
await bridge.publish(record.run_id, "values", {"step": 1})
asyncio.create_task(publish_then_end())
completed = await asyncio.wait_for(
wait_for_run_completion(bridge, record, request, mgr),
timeout=2.0,
)
# Disconnected before END — helper still reports incomplete so the
# caller skips checkpoint serialization, but the run keeps going.
assert completed is False
assert record.status == RunStatus.running
sleeper.cancel()
asyncio.run(run())
def test_no_cancel_when_run_already_finished(self) -> None:
"""If the run ended (END_SENTINEL) before disconnect is observed, the
finally block must not call cancel, because the run is already terminal."""
from app.gateway.services import wait_for_run_completion
async def run() -> None:
mgr = RunManager()
bridge = MemoryStreamBridge()
record = await _create_running_record(mgr, on_disconnect=DisconnectMode.cancel)
# Publish END before subscribing: the helper should see ended=True on the
# first poll and return without ever observing the "disconnect".
await mgr.set_status(record.run_id, RunStatus.success)
await bridge.publish_end(record.run_id)
request = _FakeRequest(disconnect_after=0)
completed = await asyncio.wait_for(
wait_for_run_completion(bridge, record, request, mgr),
timeout=2.0,
)
assert completed is True
assert record.status == RunStatus.success
asyncio.run(run())
+61 -6
@@ -15,7 +15,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 10
config_version: 11
# ============================================================================
# Logging
@@ -177,6 +177,38 @@ models:
# thinking:
# type: disabled
# Example: Xiaomi MiMo model (with thinking support)
# MiMo thinking mode returns reasoning_content and requires that field to be
# replayed on historical assistant messages in multi-turn agent/tool-call
# conversations. Use PatchedChatMiMo instead of plain ChatOpenAI.
# Use https://api.xiaomimimo.com/v1 with pay-as-you-go `sk-...` keys.
# Use your Token Plan regional URL (for example
# https://token-plan-cn.xiaomimimo.com/v1) with Token Plan `tp-...` keys.
# PatchedChatMiMo is model-id agnostic; use it for every MiMo thinking model
# entry you configure (for example mimo-v2.5-pro, mimo-v2.5, mimo-v2-pro,
# mimo-v2-omni, or mimo-v2-flash), including models referenced by subagent
# model overrides.
# See: https://platform.xiaomimimo.com/docs/en-US/usage-guide/passing-back-reasoning_content
# - name: mimo-v2.5-pro
# display_name: MiMo V2.5 Pro
# use: deerflow.models.patched_mimo:PatchedChatMiMo
# model: mimo-v2.5-pro
# api_key: $MIMO_API_KEY
# base_url: https://api.xiaomimimo.com/v1
# request_timeout: 600.0
# max_retries: 2
# max_tokens: 8192
# supports_thinking: true
# supports_vision: false
# when_thinking_enabled:
# extra_body:
# thinking:
# type: enabled
# when_thinking_disabled:
# extra_body:
# thinking:
# type: disabled
# Example: DeepSeek model (with thinking support)
# - name: deepseek-v3
# display_name: DeepSeek V3 (Thinking)
@@ -512,6 +544,34 @@ tools:
tool_search:
enabled: false
# ============================================================================
# Tool Output Budget Protection
# ============================================================================
# Prevents oversized tool results from blowing the model context window.
# Outputs exceeding `externalize_min_chars` are persisted to disk and replaced
# with a compact preview + file reference. The model can read the full output
# via read_file. When disk persistence is unavailable, outputs exceeding
# `fallback_max_chars` are head+tail truncated instead.
#
# `exempt_tools` prevents persist→read→persist infinite loops for read tools.
# `tool_overrides` allows per-tool threshold customization.
tool_output:
enabled: true
externalize_min_chars: 12000
preview_head_chars: 2000
preview_tail_chars: 1000
fallback_max_chars: 30000
fallback_head_chars: 8000
fallback_tail_chars: 3000
storage_subdir: ".tool-results"
exempt_tools:
- read_file
- read_file_tool
# tool_overrides:
# web_search: 8000
# bash: 20000
# ============================================================================
# Loop Detection Configuration
# ============================================================================
@@ -642,11 +702,6 @@ sandbox:
# # Optional: Prefix for container names (default: deer-flow-sandbox)
# # container_prefix: deer-flow-sandbox
#
# # Optional: Automatically restart crashed sandbox containers (default: true)
# # When enabled, a dead container is detected on the next tool call and
# # transparently replaced with a fresh one. Set to false to disable.
# # auto_restart: true
#
# # Optional: Additional mount directories from host to container
# # NOTE: Skills directory is automatically mounted from skills.path to skills.container_path
# # mounts:
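The `tool_output` comments in the config above describe a head+tail truncation fallback for when disk persistence is unavailable. A minimal sketch of that idea follows, assuming the parameters map directly onto `fallback_max_chars`, `fallback_head_chars`, and `fallback_tail_chars`. The function name and the exact marker text are hypothetical (the tests only check that the result contains "omitted"), and treating `max_chars <= 0` as "disabled" is an assumption inferred from the tests.

```python
def truncate_head_tail(text: str, max_chars: int, head_chars: int, tail_chars: int) -> str:
    """Keep the head and tail of an oversized string, dropping the middle."""
    # Assumption: a non-positive limit disables the fallback entirely.
    if max_chars <= 0 or len(text) <= max_chars:
        return text
    # Degenerate settings: head and tail together cover the whole text.
    if head_chars + tail_chars >= len(text):
        return text[:max_chars]
    omitted = len(text) - head_chars - tail_chars
    marker = f"\n... [{omitted} chars omitted] ...\n"
    # Slice from an explicit index so tail_chars == 0 yields an empty tail.
    return text[:head_chars] + marker + text[len(text) - tail_chars:]
```

The sketch only stays under `max_chars` when head, tail, and marker together fit within the limit, which holds for the example values (8000 + 3000 against 30000).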
+4 -3
@@ -10,10 +10,11 @@
# should be updated accordingly.
# Backend API URLs (optional)
# Leave these commented out to use the default nginx proxy (recommended for `make dev`)
# Only set these if you need to connect to backend services directly
# Leave these commented out to use the default nginx proxy (recommended for `make dev`).
# Only set these if you need to connect to the Gateway service directly.
# For split-origin browser access, also configure GATEWAY_CORS_ORIGINS.
# NEXT_PUBLIC_BACKEND_BASE_URL="http://localhost:8001"
# NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:2024"
# NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:8001/api"
# Server-only Gateway wiring used by SSR (auth checks, /api/* rewrites).
# Defaults to localhost — only override for non-local deployments.
+5 -1
@@ -88,7 +88,11 @@ Backend API URLs are optional; an nginx proxy is used by default:
```
NEXT_PUBLIC_BACKEND_BASE_URL=http://localhost:8001
NEXT_PUBLIC_LANGGRAPH_BASE_URL=http://localhost:2024
NEXT_PUBLIC_LANGGRAPH_BASE_URL=http://localhost:8001/api
```
Leave these unset for the standard `make dev` / Docker flow, where nginx serves
the public `/api/langgraph/*` prefix and rewrites it to Gateway's native `/api/*`
routes.
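As a rough illustration of that rewrite, the sketch below maps a public path onto a Gateway path. It assumes the proxy simply strips the `langgraph/` segment from the prefix; the actual nginx rule is not shown here, and `to_gateway_path` is a hypothetical helper, not part of the codebase.

```python
PUBLIC_PREFIX = "/api/langgraph/"  # prefix served to the browser
GATEWAY_PREFIX = "/api/"           # Gateway's native route prefix


def to_gateway_path(public_path: str) -> str:
    """Map a public /api/langgraph/* path to the assumed Gateway /api/* path."""
    if public_path.startswith(PUBLIC_PREFIX):
        return GATEWAY_PREFIX + public_path[len(PUBLIC_PREFIX):]
    # Paths outside the public prefix are passed through untouched.
    return public_path
```

Under this assumption, setting `NEXT_PUBLIC_LANGGRAPH_BASE_URL` to `http://localhost:8001/api` points the client at the rewritten form directly, bypassing the proxy.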
Requires Node.js 22+ and pnpm 10.26.2+.
+17 -17
@@ -299,16 +299,16 @@ packages:
'@antfu/install-pkg@1.1.0':
resolution: {integrity: sha512-MGQsmw10ZyI+EJo45CdSER4zEb+p31LpDAFp2Z3gkSd1yqVZGi0Ebx++YTEMonJy4oChEMLsxZ64j8FH6sSqtQ==}
'@babel/helper-string-parser@7.27.1':
resolution: {integrity: sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==}
'@babel/helper-string-parser@7.29.7':
resolution: {integrity: sha512-Pb5ijPrZ89GDH8223L4UP8i6QApWxs04RbPQJTeWDV0/keR2E36MeKnyr6LYmUUvqRRI+Iv87SuF1W6ErINzYw==}
engines: {node: '>=6.9.0'}
'@babel/helper-validator-identifier@7.28.5':
resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==}
'@babel/helper-validator-identifier@7.29.7':
resolution: {integrity: sha512-qehxGkRj55h/ff8EMaJ+cYhyaKlHIxqYDn682wQD7RNp9UujOQsHog2uS0r2vzr4pW+sXf90NeeayjcNaX3fFg==}
engines: {node: '>=6.9.0'}
'@babel/parser@7.29.3':
resolution: {integrity: sha512-b3ctpQwp+PROvU/cttc4OYl4MzfJUWy6FZg+PMXfzmt/+39iHVF0sDfqay8TQM3JA2EUOyKcFZt75jWriQijsA==}
'@babel/parser@7.29.7':
resolution: {integrity: sha512-hnORnjP/1P/zFEndoeX+n+t1RwWRJiJpM/jO7FW32Kn9r5+sJB2JWOdYo4L6k78j15eCwY3Gm/7364B1EMwtNg==}
engines: {node: '>=6.0.0'}
hasBin: true
@@ -316,8 +316,8 @@ packages:
resolution: {integrity: sha512-05WQkdpL9COIMz4LjTxGpPNCdlpyimKppYNoJ5Di5EUObifl8t4tuLuUBBZEpoLYOmfvIWrsp9fCl0HoPRVTdA==}
engines: {node: '>=6.9.0'}
'@babel/types@7.29.0':
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
'@babel/types@7.29.7':
resolution: {integrity: sha512-4zBIxpPzowiZpusoFkyGVwakdRJUyuH5PxQ/PrqghfdFWWasvnCdPfQXHrenDai+gyLARulZjZowCOj6fjT4pA==}
engines: {node: '>=6.9.0'}
'@braintree/sanitize-url@7.1.2':
@@ -5777,20 +5777,20 @@ snapshots:
package-manager-detector: 1.6.0
tinyexec: 1.0.2
'@babel/helper-string-parser@7.27.1': {}
'@babel/helper-string-parser@7.29.7': {}
'@babel/helper-validator-identifier@7.28.5': {}
'@babel/helper-validator-identifier@7.29.7': {}
'@babel/parser@7.29.3':
'@babel/parser@7.29.7':
dependencies:
'@babel/types': 7.29.0
'@babel/types': 7.29.7
'@babel/runtime@7.28.6': {}
'@babel/types@7.29.0':
'@babel/types@7.29.7':
dependencies:
'@babel/helper-string-parser': 7.27.1
'@babel/helper-validator-identifier': 7.28.5
'@babel/helper-string-parser': 7.29.7
'@babel/helper-validator-identifier': 7.29.7
'@braintree/sanitize-url@7.1.2': {}
@@ -8047,7 +8047,7 @@ snapshots:
'@vue/compiler-core@3.5.28':
dependencies:
'@babel/parser': 7.29.3
'@babel/parser': 7.29.7
'@vue/shared': 3.5.28
entities: 7.0.1
estree-walker: 2.0.2
@@ -8060,7 +8060,7 @@ snapshots:
'@vue/compiler-sfc@3.5.28':
dependencies:
'@babel/parser': 7.29.3
'@babel/parser': 7.29.7
'@vue/compiler-core': 3.5.28
'@vue/compiler-dom': 3.5.28
'@vue/compiler-ssr': 3.5.28
@@ -146,6 +146,27 @@ export default function NewAgentPage() {
err.reason === "backend_unreachable"
) {
setNameError(t.agents.nameStepNetworkError);
} else if (
err instanceof AgentNameCheckError &&
err.reason === "request_failed"
) {
// Surface the backend-provided detail (e.g. a validation error) when
// one is present, wrapped in a localised prefix so zh-CN users
// don't see a bare English string next to the surrounding Chinese
// UI. Use the generic localised message when the backend sent no
// detail. `err.message` is unreliable for this branch because
// `checkAgentName` substitutes a generated fallback string
// ("Failed to check agent name: ${statusText}") when `detail` is
// missing, so `err.message` would always be truthy and the
// generated fallback would leak through.
setNameError(
err.detail
? t.agents.nameStepCheckErrorWithDetail.replace(
"{detail}",
err.detail,
)
: t.agents.nameStepCheckError,
);
} else {
setNameError(t.agents.nameStepCheckError);
}
@@ -172,6 +193,7 @@ export default function NewAgentPage() {
t.agents.nameStepNetworkError,
t.agents.nameStepBootstrapMessage,
t.agents.nameStepCheckError,
t.agents.nameStepCheckErrorWithDetail,
t.agents.nameStepInvalidError,
threadId,
]);
@@ -7,7 +7,10 @@ import {
MessageResponse,
type MessageResponseProps,
} from "@/components/ai-elements/message";
import { streamdownPlugins } from "@/core/streamdown";
import {
preprocessStreamdownMarkdown,
streamdownPlugins,
} from "@/core/streamdown";
import { cn } from "@/lib/utils";
import { CitationLink } from "../citations/citation-link";
@@ -33,6 +36,10 @@ export function MarkdownContent({
remarkPlugins = streamdownPlugins.remarkPlugins,
components: componentsFromProps,
}: MarkdownContentProps) {
const normalizedContent = useMemo(
() => preprocessStreamdownMarkdown(content),
[content],
);
const components = useMemo(() => {
return {
a: (props: AnchorHTMLAttributes<HTMLAnchorElement>) => {
@@ -70,7 +77,7 @@ export function MarkdownContent({
rehypePlugins={rehypePlugins}
components={components}
>
{content}
{normalizedContent}
</MessageResponse>
);
}
+12 -1
@@ -9,6 +9,15 @@ export class AgentNameCheckError extends Error {
constructor(
message: string,
public readonly reason: "backend_unreachable" | "request_failed",
/**
* Raw backend `detail` string when the failure came from a backend
* response carrying one. `null` when no detail was provided (e.g.
* network-layer failure, empty response body, unparseable body), in
* which case `message` is a generated fallback like "Failed to check
* agent name: Bad Gateway" and the UI should prefer its own localized
* fallback instead of surfacing the generated string.
*/
public readonly detail: string | null = null,
) {
super(message);
this.name = "AgentNameCheckError";
@@ -104,9 +113,11 @@ export async function checkAgentName(
"backend_unreachable",
);
}
const backendDetail = typeof err.detail === "string" ? err.detail : null;
throw new AgentNameCheckError(
err.detail ?? `Failed to check agent name: ${res.statusText}`,
backendDetail ?? `Failed to check agent name: ${res.statusText}`,
"request_failed",
backendDetail,
);
}
return res.json() as Promise<{ available: boolean; name: string }>;
+56 -6
@@ -38,6 +38,47 @@ function injectCsrfHeader(_url: URL, init: RequestInit): RequestInit {
return { ...init, headers };
}
export function isInactiveRunStreamError(error: unknown): boolean {
const status =
typeof error === "object" && error !== null
? Reflect.get(error, "status")
: undefined;
const message =
typeof error === "string"
? error
: error instanceof Error
? error.message
: typeof error === "object" && error !== null
? String(Reflect.get(error, "message") ?? "")
: "";
// Match the gateway's store-only run response in
// backend/app/gateway/routers/thread_runs.py until the API exposes a
// structured error code for inactive run streams.
return (
(status === 409 || message.includes("HTTP 409")) &&
message.includes("not active on this worker") &&
message.includes("cannot be streamed")
);
}
export function clearReconnectRun(
threadId: string | null | undefined,
runId: string,
): void {
if (typeof window === "undefined" || !threadId) return;
const key = `lg:stream:${threadId}`;
try {
const storage = window.sessionStorage;
if (storage.getItem(key) === runId) {
storage.removeItem(key);
}
} catch {
// Ignore storage access failures so reconnect cleanup never throws.
}
}
function createCompatibleClient(isMock?: boolean): LangGraphClient {
if (isStaticWebsiteOnly() && !isMock) {
return createStaticClient();
@@ -59,12 +100,21 @@ function createCompatibleClient(isMock?: boolean): LangGraphClient {
)) as typeof client.runs.stream;
const originalJoinStream = client.runs.joinStream.bind(client.runs);
-client.runs.joinStream = ((threadId, runId, options) =>
-originalJoinStream(
-threadId,
-runId,
-sanitizeRunStreamOptions(options),
-)) as typeof client.runs.joinStream;
+client.runs.joinStream = async function* (threadId, runId, options) {
+try {
+yield* originalJoinStream(
+threadId,
+runId,
+sanitizeRunStreamOptions(options),
+);
+} catch (error) {
+if (isInactiveRunStreamError(error)) {
+clearReconnectRun(threadId, runId);
+return;
+}
+throw error;
+}
+} as typeof client.runs.joinStream;
return client;
}
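Review note: the new `joinStream` wrapper delegates with `yield*` and ends the stream quietly when one specific error is recognized. A generic sketch of that pattern, assuming hypothetical names (`endQuietlyOn`, `failingStream`, `collect`) and a made-up error message that merely contains the substrings the predicate looks for:

```typescript
// Delegate to `source`; if it throws an error the predicate recognizes,
// run the cleanup callback and finish instead of rethrowing.
async function* endQuietlyOn<T>(
  source: AsyncIterable<T>,
  shouldSwallow: (error: unknown) => boolean,
  onSwallow: () => void,
): AsyncGenerator<T> {
  try {
    yield* source;
  } catch (error) {
    if (shouldSwallow(error)) {
      onSwallow();
      return;
    }
    throw error;
  }
}

// Hypothetical stream: one event, then a 409-style failure.
async function* failingStream(): AsyncGenerator<string> {
  yield "event-1";
  throw Object.assign(
    new Error("HTTP 409: run is not active on this worker and cannot be streamed"),
    { status: 409 },
  );
}

async function collect(): Promise<{ events: string[]; cleaned: boolean }> {
  const events: string[] = [];
  let cleaned = false;
  const wrapped = endQuietlyOn(
    failingStream(),
    (error) =>
      error instanceof Error &&
      error.message.includes("not active on this worker"),
    () => {
      cleaned = true;
    },
  );
  for await (const event of wrapped) {
    events.push(event);
  }
  return { events, cleaned };
}
```

Events yielded before the failure still reach the consumer. Only the recognized error is swallowed, and anything else propagates unchanged.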
+1
@@ -204,6 +204,7 @@ export const enUS: Translations = {
nameStepNetworkError:
"Network request failed — check your network or backend connection",
nameStepCheckError: "Could not verify name availability — please try again",
nameStepCheckErrorWithDetail: "Name check failed: {detail}",
nameStepApiDisabledError:
"Custom agent management is not enabled on this server. Please contact your administrator.",
nameStepBootstrapMessage:
+1
@@ -141,6 +141,7 @@ export interface Translations {
nameStepAlreadyExistsError: string;
nameStepNetworkError: string;
nameStepCheckError: string;
nameStepCheckErrorWithDetail: string;
nameStepApiDisabledError: string;
nameStepBootstrapMessage: string;
save: string;
+1
@@ -192,6 +192,7 @@ export const zhCN: Translations = {
nameStepAlreadyExistsError: "已存在同名智能体",
nameStepNetworkError: "网络请求失败,请检查网络或后端连接",
nameStepCheckError: "无法验证名称可用性,请稍后重试",
nameStepCheckErrorWithDetail: "名称校验失败:{detail}",
nameStepApiDisabledError:
"服务器未开启自定义智能体管理功能,请联系管理员。",
nameStepBootstrapMessage:
+2
@@ -1 +1,3 @@
export * from "./mermaid";
export * from "./preprocess";
export * from "./plugins";
+98
@@ -0,0 +1,98 @@
const MERMAID_OPENING_FENCE_RE =
/^[ \t]{0,3}(`{3,}|~{3,})[ \t]*mermaid(?:[ \t].*)?$/i;
const WINDOWS_LINE_ENDING_RE = /\r\n?/g;
const LABELLED_DOTTED_ARROW_RE =
/^(\s*)(.+?)\s*--\s*("[^"\n]+"|'[^'\n]+')\s*-\.->\s*(.+?)\s*$/;
function normalizeMermaidCode(code: string): string {
return code
.split("\n")
.map((line) =>
line.replace(
LABELLED_DOTTED_ARROW_RE,
(
_match,
indent: string,
source: string,
label: string,
target: string,
) => `${indent}${source} -. ${label} .-> ${target}`,
),
)
.join("\n");
}
function isClosingFence(line: string, fence: string): boolean {
const trimmedLine = line.trimEnd();
const indentationLength = trimmedLine.length - trimmedLine.trimStart().length;
const fenceMarker = trimmedLine.slice(indentationLength);
const fenceChar = fence.charAt(0);
if (indentationLength > 3 || !fenceMarker.startsWith(fenceChar)) {
return false;
}
return (
fenceMarker.length >= fence.length &&
[...fenceMarker].every((char) => char === fenceChar)
);
}
export function normalizeMermaidMarkdown(markdown: string): string {
const lines = markdown.replace(WINDOWS_LINE_ENDING_RE, "\n").split("\n");
const normalizedLines: string[] = [];
for (let index = 0; index < lines.length; index += 1) {
const line = lines[index]!;
const openingFenceMatch = MERMAID_OPENING_FENCE_RE.exec(line);
if (!openingFenceMatch) {
normalizedLines.push(line);
continue;
}
const openingFence = openingFenceMatch[1];
if (openingFence === undefined) {
normalizedLines.push(line);
continue;
}
const codeLines: string[] = [];
let closingLine: string | undefined;
let cursor = index + 1;
for (; cursor < lines.length; cursor += 1) {
const candidateLine = lines[cursor]!;
if (isClosingFence(candidateLine, openingFence)) {
closingLine = candidateLine;
break;
}
codeLines.push(candidateLine);
}
if (closingLine === undefined) {
normalizedLines.push(line, ...codeLines);
index = cursor - 1;
continue;
}
normalizedLines.push(line);
if (codeLines.length > 0) {
normalizedLines.push(
...normalizeMermaidCode(codeLines.join("\n")).split("\n"),
);
}
normalizedLines.push(closingLine);
index = cursor;
}
return normalizedLines.join("\n");
}
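Review note: a small sketch of what the per-line rewrite does to a labelled dotted arrow. The regex is copied from the file above; `rewriteLine` is a hypothetical wrapper around the same `replace` call, added only for illustration:

```typescript
const LABELLED_DOTTED_ARROW_RE =
  /^(\s*)(.+?)\s*--\s*("[^"\n]+"|'[^'\n]+')\s*-\.->\s*(.+?)\s*$/;

// Hypothetical single-line wrapper around the replacement shown above.
function rewriteLine(line: string): string {
  return line.replace(
    LABELLED_DOTTED_ARROW_RE,
    (
      _match,
      indent: string,
      source: string,
      label: string,
      target: string,
    ) => `${indent}${source} -. ${label} .-> ${target}`,
  );
}

// A quoted label followed by a dotted arrow is rewritten in place.
const rewritten = rewriteLine('  A -- "retry" -.-> B');

// Lines without a quoted label before the dotted arrow are left alone.
const untouched = rewriteLine("  A -.-> B");
```

Here `rewritten` is `  A -. "retry" .-> B` and `untouched` equals its input, so indentation and unlabelled edges are preserved.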
@@ -0,0 +1,11 @@
import { normalizeMermaidMarkdown } from "./mermaid";
const MERMAID_BLOCK_HINT_RE = /mermaid/i;
export function preprocessStreamdownMarkdown(markdown: string): string {
if (!MERMAID_BLOCK_HINT_RE.test(markdown) || !markdown.includes("-.->")) {
return markdown;
}
return normalizeMermaidMarkdown(markdown);
}
+128 -13
@@ -1,7 +1,12 @@
import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream } from "@langchain/langgraph-sdk/react";
-import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
+import {
+type QueryClient,
+useMutation,
+useQuery,
+useQueryClient,
+} from "@tanstack/react-query";
import { useCallback, useEffect, useRef, useState } from "react";
import { toast } from "sonner";
@@ -11,6 +16,7 @@ import { getAPIClient } from "../api";
import { fetch } from "../api/fetcher";
import { getBackendBaseURL } from "../config";
import { useI18n } from "../i18n/hooks";
import { isHiddenFromUIMessage } from "../messages/utils";
import type { FileInMessage } from "../messages/utils";
import type { LocalSettings } from "../settings";
import { useUpdateSubtask } from "../tasks/context";
@@ -49,6 +55,11 @@ function isNonEmptyString(value: string | undefined): value is string {
return typeof value === "string" && value.length > 0;
}
const SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS = new Set([
"SummarizationMiddleware.before_model",
"DeerFlowSummarizationMiddleware.before_model",
]);
function messageIdentity(message: Message): string | undefined {
if (
"tool_call_id" in message &&
@@ -65,17 +76,33 @@ function messageIdentity(message: Message): string | undefined {
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
const lastIndexByIdentity = new Map<string, number>();
const lastVisibleIndexByIdentity = new Map<string, number>();
// This is a UI-display dedupe rule, not a general LangChain message-stream
// contract. Hidden messages that share an identity with a visible message are
// treated as control messages for this merged view; hidden messages carrying
// independent tracing/task semantics should use a distinct id or a custom
// stream/state channel instead of relying on message dedupe preservation.
messages.forEach((message, index) => {
const identity = messageIdentity(message);
if (identity) {
lastIndexByIdentity.set(identity, index);
if (!isHiddenFromUIMessage(message)) {
lastVisibleIndexByIdentity.set(identity, index);
}
}
});
return messages.filter((message, index) => {
const identity = messageIdentity(message);
-return !identity || lastIndexByIdentity.get(identity) === index;
+if (!identity) {
+return true;
+}
+const visibleIndex = lastVisibleIndexByIdentity.get(identity);
+if (visibleIndex !== undefined) {
+return visibleIndex === index;
+}
+return lastIndexByIdentity.get(identity) === index;
});
}
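Review note: a reduced sketch of the dedupe rule after this change, assuming a simplified message shape (`Msg` with an `id` and a `hidden` flag stands in for `messageIdentity` and `isHiddenFromUIMessage`, whose full definitions are not shown in this diff):

```typescript
// Simplified stand-in for the SDK `Message` type.
type Msg = { id?: string; hidden?: boolean; text: string };

function dedupePreferVisible(messages: Msg[]): Msg[] {
  const lastIndex = new Map<string, number>();
  const lastVisibleIndex = new Map<string, number>();

  messages.forEach((message, index) => {
    if (!message.id) return;
    lastIndex.set(message.id, index);
    if (!message.hidden) {
      lastVisibleIndex.set(message.id, index);
    }
  });

  return messages.filter((message, index) => {
    // Messages without an identity are never deduplicated.
    if (!message.id) return true;
    // If any visible copy exists, keep only the last visible one.
    const visible = lastVisibleIndex.get(message.id);
    if (visible !== undefined) return visible === index;
    // Otherwise fall back to the last copy overall.
    return lastIndex.get(message.id) === index;
  });
}

const deduped = dedupePreferVisible([
  { id: "a", text: "v1" },
  { id: "a", hidden: true, text: "ctl" },
  { id: "b", hidden: true, text: "h1" },
  { id: "b", hidden: true, text: "h2" },
  { text: "noid" },
]);
```

In this input the hidden copy of `a` no longer displaces the visible one, while `b`, which has only hidden copies, keeps its last copy as before.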
@@ -97,8 +124,15 @@ export function mergeMessages(
threadMessages: Message[],
optimisticMessages: Message[],
): Message[] {
// Only visible live messages should trim overlapping history. Hidden messages
// are UI control messages in this path, not observability records; any hidden
// message that must survive as task/tracing data should use custom events or a
// separate state channel instead of participating in this overlap heuristic.
const threadMessageIds = new Set(
-threadMessages.map(messageIdentity).filter(isNonEmptyString),
+threadMessages
+.filter((message) => !isHiddenFromUIMessage(message))
+.map(messageIdentity)
+.filter(isNonEmptyString),
);
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
@@ -149,6 +183,72 @@ export function getVisibleOptimisticMessages(
return optimisticMessages;
}
export function getSummarizationMiddlewareMessages(
data: unknown,
): Message[] | undefined {
if (typeof data !== "object" || data === null) {
return undefined;
}
for (const [key, update] of Object.entries(data)) {
if (!SUMMARIZATION_MIDDLEWARE_UPDATE_KEYS.has(key)) {
continue;
}
if (typeof update !== "object" || update === null) {
continue;
}
const messages = Reflect.get(update, "messages");
if (Array.isArray(messages)) {
return [...messages] as Message[];
}
}
return undefined;
}
export function upsertThreadInSearchCache(
queryClient: QueryClient,
thread: AgentThread,
) {
queryClient.setQueriesData(
{
queryKey: ["threads", "search"],
exact: false,
},
(oldData: Array<AgentThread> | undefined) => {
if (!oldData) {
return [thread];
}
const existingIndex = oldData.findIndex(
(t) => t.thread_id === thread.thread_id,
);
if (existingIndex === -1) {
return [thread, ...oldData];
}
return oldData.map((t, index) => {
if (index !== existingIndex) {
return t;
}
return {
...thread,
...t,
metadata: {
...(thread.metadata ?? {}),
...(t.metadata ?? {}),
},
values: {
...thread.values,
...t.values,
},
};
});
},
);
}
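Review note: the spread order in `upsertThreadInSearchCache` (`...thread, ...t`) means fields already in the cache take precedence over the incoming thread. A sketch of the same merge on a plain array, assuming a reduced `ThreadLike` shape and a hypothetical `upsertThread` function with no React Query involved:

```typescript
// Reduced stand-in for `AgentThread`.
type ThreadLike = {
  thread_id: string;
  status?: string;
  metadata?: Record<string, unknown>;
  values?: Record<string, unknown>;
};

function upsertThread(
  oldData: ThreadLike[] | undefined,
  thread: ThreadLike,
): ThreadLike[] {
  if (!oldData) return [thread];
  const existingIndex = oldData.findIndex(
    (t) => t.thread_id === thread.thread_id,
  );
  // Unknown thread: prepend it so it shows up first.
  if (existingIndex === -1) return [thread, ...oldData];
  return oldData.map((t, index) =>
    index !== existingIndex
      ? t
      : {
          // Incoming fields first, cached fields last, so the cache wins.
          ...thread,
          ...t,
          metadata: { ...(thread.metadata ?? {}), ...(t.metadata ?? {}) },
          values: { ...thread.values, ...t.values },
        },
  );
}

const cached: ThreadLike[] = [
  { thread_id: "t1", values: { title: "Real title" } },
];
const merged = upsertThread(cached, {
  thread_id: "t1",
  status: "busy",
  values: { title: "New chat", messages: [] },
});
const prepended = upsertThread(cached, { thread_id: "t2", status: "busy" });
```

In this sketch a placeholder title from `onCreated` does not overwrite a title the cache already holds, while fields the cache lacks (such as `status`) are filled in from the incoming thread.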
function getStreamErrorMessage(error: unknown): string {
if (typeof error === "string" && error.trim()) {
return error;
@@ -241,6 +341,20 @@ export function useThreadStream({
fetchStateHistory: { limit: 1 },
onCreated(meta) {
handleStreamStart(meta.thread_id, meta.run_id);
const now = new Date().toISOString();
upsertThreadInSearchCache(queryClient, {
thread_id: meta.thread_id,
created_at: now,
updated_at: now,
metadata: context.agent_name ? { agent_name: context.agent_name } : {},
status: "busy",
values: {
title: t.pages.newChat,
messages: [],
artifacts: [],
},
interrupts: {},
});
if (context.agent_name && !isMock) {
void getAPIClient()
.threads.update(meta.thread_id, {
@@ -258,24 +372,25 @@ export function useThreadStream({
}
},
onUpdateEvent(data) {
-if (data["SummarizationMiddleware.before_model"]) {
-const _messages = [
-...(data["SummarizationMiddleware.before_model"].messages ?? []),
-];
-if (_messages.length < 2) {
-return;
-}
+const _messages = getSummarizationMiddlewareMessages(data);
+if (_messages && _messages.length >= 2) {
for (const m of _messages) {
if (m.name === "summary" && m.type === "human") {
summarizedRef.current?.add(m.id ?? "");
}
}
-const _lastKeepMessage = _messages[2];
+const firstRetainedVisibleIdentity = _messages
+.filter((message) => message.type !== "remove")
+.filter((message) => !isHiddenFromUIMessage(message))
+.map(messageIdentity)
+.find(isNonEmptyString);
const _currentMessages = [...messagesRef.current];
const _movedMessages: Message[] = [];
for (const m of _currentMessages) {
-if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
+if (
+firstRetainedVisibleIdentity &&
+messageIdentity(m) === firstRetainedVisibleIdentity
+) {
break;
}
if (!summarizedRef.current?.has(m.id ?? "")) {

Some files were not shown because too many files have changed in this diff.