mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
Merge branch 'main' into fix-3127
This commit is contained in:
@@ -184,6 +184,18 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
|
||||
|
||||
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
|
||||
|
||||
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
|
||||
|
||||
| Field | Why a restart is required |
|
||||
|---|---|
|
||||
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
|
||||
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
|
||||
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
|
||||
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
|
||||
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
|
||||
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
|
||||
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
|
||||
|
||||
Configuration priority:
|
||||
1. Explicit `config_path` argument
|
||||
2. `DEER_FLOW_CONFIG_PATH` environment variable
|
||||
|
||||
@@ -161,10 +161,16 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan handler."""
|
||||
|
||||
# Load config and check necessary environment variables at startup
|
||||
# Load config and check necessary environment variables at startup.
|
||||
# `startup_config` is a local snapshot used only for one-shot bootstrap
|
||||
# work (logging level, langgraph_runtime engines, channels). Request-time
|
||||
# config resolution always routes through `get_app_config()` in
|
||||
# `app/gateway/deps.py::get_config()` so `config.yaml` edits become
|
||||
# visible without a process restart. We deliberately do NOT cache this
|
||||
# snapshot on `app.state` to keep that contract enforceable.
|
||||
try:
|
||||
app.state.config = get_app_config()
|
||||
apply_logging_level(app.state.config.log_level)
|
||||
startup_config = get_app_config()
|
||||
apply_logging_level(startup_config.log_level)
|
||||
logger.info("Configuration loaded successfully")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||
@@ -174,7 +180,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
||||
|
||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||
async with langgraph_runtime(app):
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
logger.info("LangGraph runtime initialised")
|
||||
|
||||
# Check admin bootstrap state and migrate orphan threads after admin exists.
|
||||
@@ -185,7 +191,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
try:
|
||||
from app.channels.service import start_channel_service
|
||||
|
||||
channel_service = await start_channel_service(app.state.config)
|
||||
channel_service = await start_channel_service(startup_config)
|
||||
logger.info("Channel service started: %s", channel_service.get_status())
|
||||
except Exception:
|
||||
logger.exception("No IM channels configured or channel service failed to start")
|
||||
|
||||
+69
-17
@@ -3,11 +3,21 @@
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
|
||||
``AppConfig`` is intentionally *not* cached on ``app.state``. Routers and the
|
||||
run path resolve it through :func:`deerflow.config.app_config.get_app_config`,
|
||||
which performs mtime-based hot reload, so edits to ``config.yaml`` take
|
||||
effect on the next request without a process restart. The engines created in
|
||||
:func:`langgraph_runtime` (stream bridge, persistence, checkpointer, store,
|
||||
run-event store) accept a ``startup_config`` snapshot — they are
|
||||
restart-required by design and stay bound to that snapshot to keep the live
|
||||
process consistent with itself.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
@@ -15,12 +25,14 @@ from typing import TYPE_CHECKING, TypeVar, cast
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
@@ -30,21 +42,55 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_config(request: Request) -> AppConfig:
|
||||
"""Return the app-scoped ``AppConfig`` stored on ``app.state``."""
|
||||
config = getattr(request.app.state, "config", None)
|
||||
if config is None:
|
||||
raise HTTPException(status_code=503, detail="Configuration not available")
|
||||
return config
|
||||
def get_config() -> AppConfig:
|
||||
"""Return the freshest ``AppConfig`` for the current request.
|
||||
|
||||
Routes through :func:`deerflow.config.app_config.get_app_config`, which
|
||||
honours runtime ``ContextVar`` overrides and reloads ``config.yaml`` from
|
||||
disk when its mtime changes. ``AppConfig`` is not cached on ``app.state``
|
||||
at all — the only startup-time snapshot lives as a local
|
||||
``startup_config`` variable inside ``lifespan()`` and is passed
|
||||
explicitly into :func:`langgraph_runtime` for the engines that are
|
||||
restart-required by design. Routing every request through
|
||||
:func:`get_app_config` closes the bytedance/deer-flow issue #3107 BUG-001
|
||||
split-brain where the worker / lead-agent thread saw a stale startup
|
||||
snapshot.
|
||||
|
||||
Any failure to materialise the config (missing file, permission denied,
|
||||
YAML parse error, validation error) is reported as 503 — semantically
|
||||
"the gateway cannot serve requests without a usable configuration" — and
|
||||
logged with the original exception so operators have something to debug.
|
||||
"""
|
||||
try:
|
||||
return get_app_config()
|
||||
except Exception as exc: # noqa: BLE001 - request boundary: log and degrade gracefully
|
||||
logger.exception("Failed to load AppConfig at request time")
|
||||
raise HTTPException(status_code=503, detail="Configuration not available") from exc
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
|
||||
``startup_config`` is the ``AppConfig`` snapshot taken once during
|
||||
``lifespan()`` for one-shot infrastructure bootstrap. The engines and
|
||||
stores constructed here (stream bridge, persistence engine, checkpointer,
|
||||
store, run-event store) are restart-required by design — they hold live
|
||||
connections, file handles, or singleton providers — so they bind to this
|
||||
snapshot and survive across `config.yaml` edits. Request-time consumers
|
||||
must still go through :func:`get_config` for any field that should be
|
||||
hot-reloadable. See ``backend/CLAUDE.md`` "Config Hot-Reload Boundary".
|
||||
|
||||
The matching ``run_events_config`` is frozen onto ``app.state`` so
|
||||
:func:`get_run_context` pairs a freshly-loaded ``AppConfig`` with the
|
||||
*startup-time* run-events configuration the underlying ``event_store``
|
||||
was built from — otherwise the runtime could end up combining a live
|
||||
new ``run_events_config`` with an event store still bound to the
|
||||
previous backend.
|
||||
|
||||
Usage in ``app.py``::
|
||||
|
||||
async with langgraph_runtime(app):
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
yield
|
||||
"""
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
@@ -53,9 +99,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
config = getattr(app.state, "config", None)
|
||||
if config is None:
|
||||
raise RuntimeError("langgraph_runtime() requires app.state.config to be initialized")
|
||||
config = startup_config
|
||||
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
|
||||
|
||||
@@ -84,8 +128,12 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
app.state.thread_store = make_thread_store(sf, app.state.store)
|
||||
|
||||
# Run event store (has its own factory with config-driven backend selection)
|
||||
# Run event store. The store and the matching ``run_events_config`` are
|
||||
# both frozen at startup so ``get_run_context`` does not combine a
|
||||
# freshly-reloaded ``AppConfig.run_events`` with a store still bound to
|
||||
# the previous backend.
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
app.state.run_events_config = run_events_config
|
||||
app.state.run_event_store = make_run_event_store(run_events_config)
|
||||
|
||||
# RunManager with store backing for persistence
|
||||
@@ -139,16 +187,20 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||
|
||||
Returns a *base* context with infrastructure dependencies.
|
||||
Returns a *base* context with infrastructure dependencies. The
|
||||
``app_config`` field is resolved live so per-run fields (e.g.
|
||||
``models[*].max_tokens``) follow ``config.yaml`` edits; the
|
||||
``event_store`` / ``run_events_config`` pair stays frozen to the snapshot
|
||||
captured in :func:`langgraph_runtime` so callers never see a store bound
|
||||
to one backend paired with a config pointing at another.
|
||||
"""
|
||||
config = get_config(request)
|
||||
return RunContext(
|
||||
checkpointer=get_checkpointer(request),
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(config, "run_events", None),
|
||||
run_events_config=getattr(request.app.state, "run_events_config", None),
|
||||
thread_store=get_thread_store(request),
|
||||
app_config=config,
|
||||
app_config=get_config(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -66,6 +66,14 @@ class RunResponse(BaseModel):
|
||||
multitask_strategy: str = "reject"
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
llm_call_count: int = 0
|
||||
lead_agent_tokens: int = 0
|
||||
subagent_tokens: int = 0
|
||||
middleware_tokens: int = 0
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
class ThreadTokenUsageModelBreakdown(BaseModel):
|
||||
@@ -111,6 +119,14 @@ def _record_to_response(record: RunRecord) -> RunResponse:
|
||||
multitask_strategy=record.multitask_strategy,
|
||||
created_at=record.created_at,
|
||||
updated_at=record.updated_at,
|
||||
total_input_tokens=record.total_input_tokens,
|
||||
total_output_tokens=record.total_output_tokens,
|
||||
total_tokens=record.total_tokens,
|
||||
llm_call_count=record.llm_call_count,
|
||||
lead_agent_tokens=record.lead_agent_tokens,
|
||||
subagent_tokens=record.subagent_tokens,
|
||||
middleware_tokens=record.middleware_tokens,
|
||||
message_count=record.message_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -402,8 +418,15 @@ async def list_run_events(
|
||||
|
||||
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
|
||||
async def thread_token_usage(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
|
||||
) -> ThreadTokenUsageResponse:
|
||||
"""Thread-level token usage aggregation."""
|
||||
run_store = get_run_store(request)
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||
if include_active:
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
|
||||
else:
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
|
||||
|
||||
@@ -15,7 +15,8 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_to_messages
|
||||
|
||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
@@ -76,21 +77,35 @@ def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
|
||||
|
||||
|
||||
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Convert LangGraph Platform input format to LangChain state dict."""
|
||||
"""Convert LangGraph Platform input format to LangChain state dict.
|
||||
|
||||
Delegates dict→message coercion to ``langchain_core.messages.utils.convert_to_messages``
|
||||
so that ``additional_kwargs`` (e.g. uploaded-file metadata — gh #3132), ``id``,
|
||||
``name``, and non-human roles (ai/system/tool) survive unchanged. An earlier
|
||||
hand-rolled version only forwarded ``content`` and collapsed every role to
|
||||
``HumanMessage``, which silently stripped frontend-supplied attachments.
|
||||
|
||||
Malformed message dicts (missing ``role``/``type``/``content``, unsupported
|
||||
role, etc.) raise ``HTTPException(400)`` with the offending index, instead
|
||||
of bubbling up as a 500. The gateway is a system boundary, so per-entry
|
||||
validation errors are the right shape for clients to retry against.
|
||||
"""
|
||||
if raw_input is None:
|
||||
return {}
|
||||
messages = raw_input.get("messages")
|
||||
if messages and isinstance(messages, list):
|
||||
converted = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", msg.get("type", "user"))
|
||||
content = msg.get("content", "")
|
||||
if role in ("user", "human"):
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
# TODO: handle other message types (system, ai, tool)
|
||||
converted.append(HumanMessage(content=content))
|
||||
converted: list[Any] = []
|
||||
for index, msg in enumerate(messages):
|
||||
if isinstance(msg, BaseMessage):
|
||||
converted.append(msg)
|
||||
elif isinstance(msg, dict):
|
||||
try:
|
||||
converted.extend(convert_to_messages([msg]))
|
||||
except (ValueError, TypeError, NotImplementedError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid message at input.messages[{index}]: {exc}",
|
||||
) from exc
|
||||
else:
|
||||
converted.append(msg)
|
||||
return {**raw_input, "messages": converted}
|
||||
|
||||
@@ -241,13 +241,6 @@ GET /api/mcp/config
|
||||
"GITHUB_TOKEN": "***"
|
||||
},
|
||||
"description": "GitHub operations"
|
||||
},
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
"description": "File system access"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,19 @@ DeerFlow supports configurable MCP servers and skills to extend its capabilities
|
||||
3. Configure each server’s command, arguments, and environment variables as needed.
|
||||
4. Restart the application to load and register MCP tools.
|
||||
|
||||
## Filesystem MCP Servers
|
||||
|
||||
DeerFlow already provides built-in file tools for thread-scoped workspace access.
|
||||
Do not add an MCP filesystem server for the same DeerFlow workspace. The
|
||||
overlapping file tools use different path semantics, which can make LLM tool
|
||||
selection and file access behavior unstable.
|
||||
|
||||
DeerFlow does not currently adapt the MCP Roots mode for filesystem servers. In
|
||||
particular, it does not publish per-thread MCP roots or map DeerFlow sandbox
|
||||
paths such as `/mnt/user-data/...` to paths accepted by
|
||||
`@modelcontextprotocol/server-filesystem`. Use DeerFlow's built-in file tools
|
||||
for DeerFlow workspace files.
|
||||
|
||||
## OAuth Support (HTTP/SSE MCP Servers)
|
||||
|
||||
For `http` and `sse` MCP servers, DeerFlow supports OAuth token acquisition and automatic token refresh.
|
||||
@@ -88,7 +101,6 @@ MCP servers expose tools that are automatically discovered and integrated into D
|
||||
|
||||
MCP servers can provide access to:
|
||||
|
||||
- **File systems**
|
||||
- **Databases** (e.g., PostgreSQL)
|
||||
- **External APIs** (e.g., GitHub, Brave Search)
|
||||
- **Browser automation** (e.g., Puppeteer)
|
||||
@@ -97,4 +109,4 @@ MCP servers can provide access to:
|
||||
## Learn More
|
||||
|
||||
For detailed documentation about the Model Context Protocol, visit:
|
||||
https://modelcontextprotocol.io
|
||||
https://modelcontextprotocol.io
|
||||
|
||||
@@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
@@ -338,6 +339,15 @@ def _build_middlewares(
|
||||
if custom_middlewares:
|
||||
middlewares.extend(custom_middlewares)
|
||||
|
||||
# SafetyFinishReasonMiddleware — suppress tool execution when the provider
|
||||
# safety-terminated the response. Registered after custom middlewares so
|
||||
# that LangChain's reverse-order after_model dispatch runs Safety first;
|
||||
# cleared tool_calls then flow through Loop/Subagent accounting without
|
||||
# firing extra alarms. See safety_finish_reason_middleware.py docstring.
|
||||
safety_config = resolved_app_config.safety_finish_reason
|
||||
if safety_config.enabled:
|
||||
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
+6
-7
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
This normalizes model-bound causal order before provider serialization while
|
||||
preserving already-valid transcripts unchanged.
|
||||
"""
|
||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||
tool_messages_by_id[msg.tool_call_id].append(msg)
|
||||
|
||||
tool_call_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
tool_call_ids.add(tc_id)
|
||||
|
||||
patched: list = []
|
||||
consumed_tool_msg_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||
if not tc_id:
|
||||
continue
|
||||
|
||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||
tool_msg_queue = tool_messages_by_id.get(tc_id)
|
||||
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
|
||||
if existing_tool_msg is not None:
|
||||
patched.append(existing_tool_msg)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
else:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
if patched == messages:
|
||||
|
||||
+317
@@ -0,0 +1,317 @@
|
||||
"""Suppress tool execution when the provider safety-terminated the response.
|
||||
|
||||
Background — see issue bytedance/deer-flow#3028.
|
||||
|
||||
Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic
|
||||
``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop
|
||||
generation mid-stream while still returning partially-formed ``tool_calls``.
|
||||
LangChain's tool router treats any AIMessage with a non-empty ``tool_calls``
|
||||
field as "go execute these", so half-truncated arguments — e.g. a markdown
|
||||
``write_file`` that stops in the middle of a sentence — get dispatched as if
|
||||
they were complete. The agent then sees the truncated file, tries to fix it,
|
||||
gets filtered again, and loops.
|
||||
|
||||
This middleware sits at ``after_model`` and gates that behaviour: when a
|
||||
configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries
|
||||
tool calls, we strip the tool calls (both structured and raw provider
|
||||
payloads), append a user-facing explanation, and stash observability fields
|
||||
in ``additional_kwargs.safety_termination`` so logs, traces, and SSE
|
||||
consumers can see what happened.
|
||||
|
||||
Hook choice: ``after_model`` (not ``wrap_model_call``) because the response
|
||||
is a *normal* return — not an exception — and we want to participate in the
|
||||
same after-model chain as ``LoopDetectionMiddleware``, with which we share
|
||||
the same tool-call-suppression mechanic but a different trigger.
|
||||
|
||||
Placement: register *after* ``LoopDetectionMiddleware`` in the middleware
|
||||
list. LangChain factory wires ``after_model`` edges in reverse list order
|
||||
(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``,
|
||||
then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is
|
||||
the *first* to observe the model output. Registering Safety after Loop
|
||||
means Safety sees the raw response first, clears tool calls if it fires,
|
||||
and Loop then accounts against the cleaned message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_USER_FACING_MESSAGE = (
|
||||
"The model provider stopped this response with a safety-related signal "
|
||||
"({reason_field}={reason_value!r}, detector={detector!r}). Any tool "
|
||||
"calls produced in this turn were suppressed because their arguments "
|
||||
"may be truncated and unsafe to execute. Please rephrase the request "
|
||||
"or ask for a narrower output."
|
||||
)
|
||||
|
||||
|
||||
class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector."""
|
||||
|
||||
def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None:
|
||||
super().__init__()
|
||||
# Copy so caller mutations after construction don't leak into us.
|
||||
self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware:
|
||||
"""Construct from validated Pydantic config, honouring the
|
||||
reflection-loaded detector list when provided.
|
||||
|
||||
An explicit empty list is intentionally rejected — it would silently
|
||||
disable detection while leaving the middleware in the chain, which
|
||||
is the worst of both worlds. Use ``enabled: false`` instead.
|
||||
"""
|
||||
if config.detectors is None:
|
||||
return cls()
|
||||
|
||||
if not config.detectors:
|
||||
raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.")
|
||||
|
||||
from deerflow.reflection import resolve_variable
|
||||
|
||||
detectors: list[SafetyTerminationDetector] = []
|
||||
for entry in config.detectors:
|
||||
detector_cls = resolve_variable(entry.use)
|
||||
kwargs = dict(entry.config) if entry.config else {}
|
||||
detector = detector_cls(**kwargs)
|
||||
if not isinstance(detector, SafetyTerminationDetector):
|
||||
raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method")
|
||||
detectors.append(detector)
|
||||
return cls(detectors=detectors)
|
||||
|
||||
# ----- detection -------------------------------------------------------
|
||||
|
||||
def _detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
for detector in self._detectors:
|
||||
try:
|
||||
hit = detector.detect(message)
|
||||
except Exception: # noqa: BLE001 - never let a buggy detector break the agent run
|
||||
logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__))
|
||||
continue
|
||||
if hit is not None:
|
||||
return hit
|
||||
return None
|
||||
|
||||
# ----- message rewriting ----------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _append_user_message(content: object, text: str) -> str | list:
|
||||
"""Append a plain-text explanation to AIMessage content.
|
||||
|
||||
Mirrors ``LoopDetectionMiddleware._append_text`` so list-content
|
||||
responses (Anthropic thinking blocks, vLLM reasoning splits) keep
|
||||
their structure instead of being string-coerced into a TypeError.
|
||||
"""
|
||||
if content is None or content == "":
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _build_suppressed_message(
|
||||
self,
|
||||
message: AIMessage,
|
||||
termination: SafetyTermination,
|
||||
) -> AIMessage:
|
||||
suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])]
|
||||
explanation = _USER_FACING_MESSAGE.format(
|
||||
reason_field=termination.reason_field,
|
||||
reason_value=termination.reason_value,
|
||||
detector=termination.detector,
|
||||
)
|
||||
new_content = self._append_user_message(message.content, explanation)
|
||||
|
||||
# clone_ai_message_with_tool_calls handles structured tool_calls,
|
||||
# raw additional_kwargs.tool_calls, and function_call in one shot.
|
||||
# It only rewrites finish_reason when the old value was "tool_calls",
|
||||
# which is not our case — content_filter / refusal / SAFETY stay put
|
||||
# so downstream SSE / converters keep seeing the real provider reason.
|
||||
cleared = clone_ai_message_with_tool_calls(message, [], content=new_content)
|
||||
|
||||
# Re-clone additional_kwargs so we don't accidentally mutate the
|
||||
# dict returned by clone_ai_message_with_tool_calls (which already
|
||||
# made a shallow copy, but downstream model_copy still references
|
||||
# it). Then stamp the observability record.
|
||||
kwargs = dict(getattr(cleared, "additional_kwargs", None) or {})
|
||||
kwargs["safety_termination"] = {
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(suppressed_names),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"extras": dict(termination.extras) if termination.extras else {},
|
||||
}
|
||||
return cleared.model_copy(update={"additional_kwargs": kwargs})
|
||||
|
||||
# ----- observability ---------------------------------------------------
|
||||
|
||||
def _emit_event(
|
||||
self,
|
||||
termination: SafetyTermination,
|
||||
suppressed_names: list[str],
|
||||
runtime: Runtime,
|
||||
) -> None:
|
||||
"""Notify SSE consumers (e.g. the web UI) that a tool turn was
|
||||
suppressed so they can reconcile any "tool starting..." placeholders
|
||||
already streamed to the user. Failures are logged at debug and
|
||||
ignored — this is a best-effort signal."""
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True)
|
||||
return
|
||||
|
||||
thread_id = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
|
||||
|
||||
try:
|
||||
writer(
|
||||
{
|
||||
"type": "safety_termination",
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(suppressed_names),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("Failed to emit safety_termination stream event", exc_info=True)
|
||||
|
||||
def _record_audit_event(
|
||||
self,
|
||||
termination: SafetyTermination,
|
||||
message,
|
||||
tool_calls: list[dict],
|
||||
runtime: Runtime,
|
||||
) -> None:
|
||||
"""Write a ``middleware:safety_termination`` record to RunEventStore
|
||||
for post-run auditability.
|
||||
|
||||
The custom stream event in ``_emit_event`` is consumed by live SSE
|
||||
clients and disappears after the run; this event is persisted so an
|
||||
operator can answer "which runs were safety-suppressed today?" from
|
||||
a single SQL query without joining the message body. Worker exposes
|
||||
the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``;
|
||||
absent in unit-test / subagent / no-event-store paths, in which case
|
||||
we silently skip.
|
||||
|
||||
Tool **arguments** are deliberately **not** recorded — those are the
|
||||
very content the provider filtered; persisting them would defeat the
|
||||
purpose of the safety filter. Names / count / ids are sufficient for
|
||||
audit and debugging (issue #3028 review).
|
||||
"""
|
||||
journal = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
context = runtime.context
|
||||
if isinstance(context, dict):
|
||||
journal = context.get("__run_journal")
|
||||
if journal is None:
|
||||
return
|
||||
|
||||
suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls]
|
||||
suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||
|
||||
changes = {
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(tool_calls),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"suppressed_tool_call_ids": suppressed_ids,
|
||||
"message_id": getattr(message, "id", None),
|
||||
"extras": dict(termination.extras) if termination.extras else {},
|
||||
}
|
||||
|
||||
try:
|
||||
journal.record_middleware(
|
||||
tag="safety_termination",
|
||||
name=type(self).__name__,
|
||||
hook="after_model",
|
||||
action="suppress_tool_calls",
|
||||
changes=changes,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
# Audit-event persistence must never break agent execution.
|
||||
logger.debug("Failed to record middleware:safety_termination event", exc_info=True)
|
||||
|
||||
# ----- main apply ------------------------------------------------------
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
# Issue scope: only intervene when there's something to suppress.
|
||||
# ``content_filter`` without tool_calls is allowed through unchanged
|
||||
# so the partial text response (if any) reaches the user naturally.
|
||||
tool_calls = last.tool_calls
|
||||
if not tool_calls:
|
||||
return None
|
||||
|
||||
termination = self._detect(last)
|
||||
if termination is None:
|
||||
return None
|
||||
|
||||
patched = self._build_suppressed_message(last, termination)
|
||||
|
||||
thread_id = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
|
||||
|
||||
logger.warning(
|
||||
"Provider safety termination detected — suppressed %d tool call(s)",
|
||||
len(tool_calls),
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_names": [tc.get("name") for tc in tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime)
|
||||
self._record_audit_event(termination, last, list(tool_calls), runtime)
|
||||
|
||||
return {"messages": [patched]}
|
||||
|
||||
# ----- hooks -----------------------------------------------------------
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
@@ -0,0 +1,237 @@
|
||||
"""Detectors for provider-side safety termination signals.
|
||||
|
||||
Different LLM providers signal "I stopped this response for safety reasons"
|
||||
through different fields with different values. This module defines a small
|
||||
strategy interface and three built-in detectors that cover the major
|
||||
providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock
|
||||
adapters, in-house gateways, ...) can be added by implementing
|
||||
``SafetyTerminationDetector`` and wiring it through
|
||||
``config.yaml: safety_finish_reason.detectors``.
|
||||
|
||||
The middleware that consumes these detectors lives in
|
||||
``safety_finish_reason_middleware.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SafetyTermination:
|
||||
"""A detected safety-related termination signal.
|
||||
|
||||
Attributes:
|
||||
detector: Name of the detector that produced this result. Used for
|
||||
observability so operators can see which provider rule fired.
|
||||
reason_field: The message metadata field that carried the signal
|
||||
(e.g. ``finish_reason``, ``stop_reason``).
|
||||
reason_value: The actual value of that field
|
||||
(e.g. ``content_filter``, ``refusal``, ``SAFETY``).
|
||||
extras: Provider-specific metadata that may help downstream
|
||||
consumers (e.g. Azure OpenAI content_filter_results, Gemini
|
||||
safety_ratings). Detectors are free to populate or skip this.
|
||||
"""
|
||||
|
||||
detector: str
|
||||
reason_field: str
|
||||
reason_value: str
|
||||
extras: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SafetyTerminationDetector(Protocol):
|
||||
"""Strategy interface for provider safety termination detection."""
|
||||
|
||||
name: str
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
"""Return a SafetyTermination if *message* indicates provider safety
|
||||
termination, otherwise return ``None``.
|
||||
|
||||
Implementations must be side-effect free and tolerant of missing or
|
||||
oddly-typed metadata — detectors run on every model response.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def _get_metadata_value(message: AIMessage, field_name: str) -> str | None:
|
||||
"""Read a string-typed value from either ``response_metadata`` or
|
||||
``additional_kwargs``.
|
||||
|
||||
LangChain provider adapters are inconsistent about where they stash
|
||||
provider stop signals. Most modern adapters use ``response_metadata``,
|
||||
but some legacy / passthrough paths still surface them via
|
||||
``additional_kwargs``. We check both, in that order, and only accept
|
||||
string values — Pydantic enums or dicts are ignored so we never raise
|
||||
on malformed inputs.
|
||||
"""
|
||||
for container_name in ("response_metadata", "additional_kwargs"):
|
||||
container = getattr(message, container_name, None) or {}
|
||||
if not isinstance(container, dict):
|
||||
continue
|
||||
value = container.get(field_name)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class OpenAICompatibleContentFilterDetector:
|
||||
"""OpenAI-compatible content_filter signal.
|
||||
|
||||
Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM,
|
||||
Qwen (OpenAI-compatible mode), and any other adapter that follows the
|
||||
OpenAI ``finish_reason`` convention.
|
||||
|
||||
Some Chinese providers ship custom OpenAI-compatible gateways that use
|
||||
alternative tokens like ``sensitive`` or ``violation``. Extend the set
|
||||
via the ``finish_reasons`` kwarg in config.
|
||||
"""
|
||||
|
||||
name = "openai_compatible_content_filter"
|
||||
|
||||
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = finish_reasons if finish_reasons is not None else ("content_filter",)
|
||||
self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "finish_reason")
|
||||
if value is None or value.lower() not in self._finish_reasons:
|
||||
return None
|
||||
|
||||
extras: dict[str, Any] = {}
|
||||
# Azure OpenAI ships a structured content_filter_results block; carry it
|
||||
# through so operators can see *what* was filtered without re-tracing.
|
||||
response_metadata = getattr(message, "response_metadata", None) or {}
|
||||
if isinstance(response_metadata, dict):
|
||||
filter_results = response_metadata.get("content_filter_results")
|
||||
if filter_results:
|
||||
extras["content_filter_results"] = filter_results
|
||||
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="finish_reason",
|
||||
reason_value=value,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicRefusalDetector:
|
||||
"""Anthropic ``stop_reason == "refusal"`` signal.
|
||||
|
||||
Anthropic models surface safety refusals via a dedicated ``stop_reason``
|
||||
rather than ``finish_reason``. See:
|
||||
https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals
|
||||
"""
|
||||
|
||||
name = "anthropic_refusal"
|
||||
|
||||
def __init__(self, stop_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = stop_reasons if stop_reasons is not None else ("refusal",)
|
||||
self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "stop_reason")
|
||||
if value is None or value.lower() not in self._stop_reasons:
|
||||
return None
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="stop_reason",
|
||||
reason_value=value,
|
||||
)
|
||||
|
||||
|
||||
class GeminiSafetyDetector:
|
||||
"""Gemini / Vertex AI safety-related finish reasons.
|
||||
|
||||
Gemini uses the same ``finish_reason`` field as OpenAI but with an
|
||||
enumerated upper-case taxonomy. The default set covers every Gemini
|
||||
finish_reason that means "the model stopped because the content/image
|
||||
tripped a safety, blocklist, recitation, or PII filter" — i.e. cases
|
||||
where any tool_calls returned alongside are likely truncated/
|
||||
unreliable. Full enum:
|
||||
https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason
|
||||
|
||||
Intentionally **excluded** from the default set:
|
||||
- ``STOP`` — normal termination.
|
||||
- ``MAX_TOKENS`` — output length truncation, not safety
|
||||
(same root failure mode as
|
||||
content_filter, but issue #3028
|
||||
scopes it out; expose separately if
|
||||
desired).
|
||||
- ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to
|
||||
safety; tool_calls would be absent
|
||||
anyway.
|
||||
- ``MALFORMED_FUNCTION_CALL`` /
|
||||
``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The
|
||||
tool_calls are *also* unreliable
|
||||
here, but the failure category is
|
||||
distinct from safety filtering;
|
||||
handle in a dedicated detector to
|
||||
keep observability records honest.
|
||||
- ``OTHER`` / ``IMAGE_OTHER`` /
|
||||
``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default;
|
||||
opt in via ``finish_reasons=`` if
|
||||
your provider abuses these.
|
||||
"""
|
||||
|
||||
name = "gemini_safety"
|
||||
|
||||
_DEFAULT_FINISH_REASONS = (
|
||||
# Text safety
|
||||
"SAFETY",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"RECITATION",
|
||||
# Image safety (multimodal generation)
|
||||
"IMAGE_SAFETY",
|
||||
"IMAGE_PROHIBITED_CONTENT",
|
||||
"IMAGE_RECITATION",
|
||||
)
|
||||
|
||||
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS
|
||||
self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "finish_reason")
|
||||
if value is None or value.upper() not in self._finish_reasons:
|
||||
return None
|
||||
|
||||
extras: dict[str, Any] = {}
|
||||
response_metadata = getattr(message, "response_metadata", None) or {}
|
||||
if isinstance(response_metadata, dict):
|
||||
# Gemini surfaces per-category scoring under safety_ratings.
|
||||
ratings = response_metadata.get("safety_ratings")
|
||||
if ratings:
|
||||
extras["safety_ratings"] = ratings
|
||||
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="finish_reason",
|
||||
reason_value=value,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
|
||||
def default_detectors() -> list[SafetyTerminationDetector]:
|
||||
"""Built-in detector set used when no custom detectors are configured."""
|
||||
return [
|
||||
OpenAICompatibleContentFilterDetector(),
|
||||
AnthropicRefusalDetector(),
|
||||
GeminiSafetyDetector(),
|
||||
]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AnthropicRefusalDetector",
|
||||
"GeminiSafetyDetector",
|
||||
"OpenAICompatibleContentFilterDetector",
|
||||
"SafetyTermination",
|
||||
"SafetyTerminationDetector",
|
||||
"default_detectors",
|
||||
]
|
||||
+10
@@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares(
|
||||
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
|
||||
# Same provider safety-termination guard the lead agent uses — subagents
|
||||
# are equally exposed to truncated tool_calls returned with
|
||||
# finish_reason=content_filter (and friends), and the bad call would then
|
||||
# propagate back to the lead agent via the task tool result.
|
||||
safety_config = app_config.safety_finish_reason
|
||||
if safety_config.enabled:
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
|
||||
|
||||
return middlewares
|
||||
|
||||
@@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
from deerflow.config.runtime_paths import existing_project_file
|
||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
|
||||
from deerflow.config.skills_config import SkillsConfig
|
||||
@@ -102,6 +103,7 @@ class AppConfig(BaseModel):
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Configuration for SafetyFinishReasonMiddleware.
|
||||
|
||||
Mirrors the shape of GuardrailsConfig: detectors are loaded by class path
|
||||
through ``deerflow.reflection.resolve_variable`` (same loader the
|
||||
``guardrails.provider`` config uses) so users can drop in custom provider
|
||||
detectors without modifying core code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SafetyDetectorConfig(BaseModel):
|
||||
"""One detector entry under ``safety_finish_reason.detectors``."""
|
||||
|
||||
use: str = Field(
|
||||
description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."),
|
||||
)
|
||||
config: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Constructor kwargs passed to the detector class.",
|
||||
)
|
||||
|
||||
|
||||
class SafetyFinishReasonConfig(BaseModel):
|
||||
"""Configuration for the SafetyFinishReasonMiddleware.
|
||||
|
||||
The middleware intercepts AIMessages where the provider signaled a
|
||||
safety-related termination (e.g. OpenAI ``finish_reason='content_filter'``)
|
||||
while still returning tool calls, and suppresses those tool calls so the
|
||||
half-truncated arguments never execute.
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Master switch for the SafetyFinishReasonMiddleware.",
|
||||
)
|
||||
detectors: list[SafetyDetectorConfig] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Custom detector list. Leave unset (None) to use the built-in "
|
||||
"set covering OpenAI-compatible content_filter, Anthropic "
|
||||
"refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/"
|
||||
"RECITATION. Provide a non-null list to fully override."
|
||||
),
|
||||
)
|
||||
@@ -134,9 +134,25 @@ def reset_mcp_tools_cache() -> None:
|
||||
"""Reset the MCP tools cache.
|
||||
|
||||
This is useful for testing or when you want to reload MCP tools.
|
||||
Also closes all persistent MCP sessions so they are recreated on
|
||||
the next tool load.
|
||||
"""
|
||||
global _mcp_tools_cache, _cache_initialized, _config_mtime
|
||||
_mcp_tools_cache = None
|
||||
_cache_initialized = False
|
||||
_config_mtime = None
|
||||
|
||||
# Close persistent sessions – they will be recreated by the next
|
||||
# get_mcp_tools() call with the (possibly updated) connection config.
|
||||
try:
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
|
||||
pool = get_session_pool()
|
||||
pool.close_all_sync()
|
||||
except Exception:
|
||||
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
|
||||
|
||||
from deerflow.mcp.session_pool import reset_session_pool
|
||||
|
||||
reset_session_pool()
|
||||
logger.info("MCP tools cache reset")
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
"""Persistent MCP session pool for stateful tool calls.
|
||||
|
||||
When MCP tools are loaded via langchain-mcp-adapters with ``session=None``,
|
||||
each tool call creates a new MCP session. For stateful servers like Playwright,
|
||||
this means browser state (opened pages, filled forms) is lost between calls.
|
||||
|
||||
This module provides a session pool that maintains persistent MCP sessions,
|
||||
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
|
||||
so that consecutive tool calls share the same session and server-side state.
|
||||
Sessions are evicted in LRU order when the pool reaches capacity.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPSessionPool:
|
||||
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
|
||||
|
||||
MAX_SESSIONS = 256
|
||||
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: OrderedDict[
|
||||
tuple[str, str],
|
||||
tuple[ClientSession, asyncio.AbstractEventLoop],
|
||||
] = OrderedDict()
|
||||
self._context_managers: dict[tuple[str, str], Any] = {}
|
||||
# threading.Lock is not bound to any event loop, so it is safe to
|
||||
# acquire from both async paths and sync/worker-thread paths.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
async def get_session(
|
||||
self,
|
||||
server_name: str,
|
||||
scope_key: str,
|
||||
connection: dict[str, Any],
|
||||
) -> ClientSession:
|
||||
"""Get or create a persistent MCP session.
|
||||
|
||||
If an existing session was created in a different event loop (e.g.
|
||||
the sync-wrapper path), it is closed and replaced with a fresh one
|
||||
in the current loop.
|
||||
|
||||
Args:
|
||||
server_name: MCP server name.
|
||||
scope_key: Isolation key (typically thread_id).
|
||||
connection: Connection configuration for ``create_session``.
|
||||
|
||||
Returns:
|
||||
An initialized ``ClientSession``.
|
||||
"""
|
||||
key = (server_name, scope_key)
|
||||
current_loop = asyncio.get_running_loop()
|
||||
|
||||
# Phase 1: inspect/mutate the registry under the thread lock (no awaits).
|
||||
cms_to_close: list[tuple[tuple[str, str], Any]] = []
|
||||
with self._lock:
|
||||
if key in self._entries:
|
||||
session, loop = self._entries[key]
|
||||
if loop is current_loop:
|
||||
self._entries.move_to_end(key)
|
||||
return session
|
||||
# Session belongs to a different event loop – evict it.
|
||||
cm = self._context_managers.pop(key, None)
|
||||
self._entries.pop(key)
|
||||
if cm is not None:
|
||||
cms_to_close.append((key, cm))
|
||||
|
||||
# Evict LRU entries when at capacity.
|
||||
while len(self._entries) >= self.MAX_SESSIONS:
|
||||
oldest_key = next(iter(self._entries))
|
||||
cm = self._context_managers.pop(oldest_key, None)
|
||||
self._entries.pop(oldest_key)
|
||||
if cm is not None:
|
||||
cms_to_close.append((oldest_key, cm))
|
||||
|
||||
# Phase 2: async cleanup outside the lock so we never await while holding it.
|
||||
for close_key, cm in cms_to_close:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session %s", close_key, exc_info=True)
|
||||
|
||||
from langchain_mcp_adapters.sessions import create_session
|
||||
|
||||
cm = create_session(connection)
|
||||
session = await cm.__aenter__()
|
||||
await session.initialize()
|
||||
|
||||
# Phase 3: register the new session under the lock.
|
||||
with self._lock:
|
||||
self._entries[key] = (session, current_loop)
|
||||
self._context_managers[key] = cm
|
||||
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
|
||||
return session
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
|
||||
"""Close a single context manager (must be called WITHOUT the lock)."""
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session %s", key, exc_info=True)
|
||||
|
||||
async def close_scope(self, scope_key: str) -> None:
|
||||
"""Close all sessions for a given scope (e.g. thread_id)."""
|
||||
with self._lock:
|
||||
keys = [k for k in self._entries if k[1] == scope_key]
|
||||
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
|
||||
for k in keys:
|
||||
self._entries.pop(k, None)
|
||||
for key, cm in cms:
|
||||
if cm is not None:
|
||||
await self._close_cm(key, cm)
|
||||
|
||||
async def close_server(self, server_name: str) -> None:
|
||||
"""Close all sessions for a given server."""
|
||||
with self._lock:
|
||||
keys = [k for k in self._entries if k[0] == server_name]
|
||||
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
|
||||
for k in keys:
|
||||
self._entries.pop(k, None)
|
||||
for key, cm in cms:
|
||||
if cm is not None:
|
||||
await self._close_cm(key, cm)
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close every managed session."""
|
||||
with self._lock:
|
||||
cms = list(self._context_managers.items())
|
||||
self._context_managers.clear()
|
||||
self._entries.clear()
|
||||
for key, cm in cms:
|
||||
await self._close_cm(key, cm)
|
||||
|
||||
def close_all_sync(self) -> None:
|
||||
"""Close all sessions using their owning event loops (synchronous).
|
||||
|
||||
Each session is closed on the loop it was created in, avoiding
|
||||
cross-loop resource leaks. Safe to call from any thread without an
|
||||
active event loop.
|
||||
"""
|
||||
with self._lock:
|
||||
entries = list(self._entries.items())
|
||||
cms = dict(self._context_managers)
|
||||
self._entries.clear()
|
||||
self._context_managers.clear()
|
||||
|
||||
for key, (_, loop) in entries:
|
||||
cm = cms.get(key)
|
||||
if cm is None or loop.is_closed():
|
||||
continue
|
||||
try:
|
||||
if loop.is_running():
|
||||
# Schedule on the owning loop from this (different) thread.
|
||||
future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
|
||||
future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
|
||||
else:
|
||||
loop.run_until_complete(cm.__aexit__(None, None, None))
|
||||
except Exception:
|
||||
logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_pool: MCPSessionPool | None = None
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_session_pool() -> MCPSessionPool:
|
||||
"""Return the global session-pool singleton."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
with _pool_lock:
|
||||
if _pool is None:
|
||||
_pool = MCPSessionPool()
|
||||
return _pool
|
||||
|
||||
|
||||
def reset_session_pool() -> None:
|
||||
"""Reset the singleton (for tests)."""
|
||||
global _pool
|
||||
_pool = None
|
||||
@@ -1,21 +1,181 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters."""
|
||||
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langgraph.config import get_config
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.mcp.client import build_servers_config
|
||||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_thread_id(runtime: Runtime | None) -> str:
|
||||
"""Extract thread_id from the injected tool runtime or LangGraph config."""
|
||||
if runtime is not None:
|
||||
tid = runtime.context.get("thread_id") if runtime.context else None
|
||||
if tid is not None:
|
||||
return str(tid)
|
||||
config = runtime.config or {}
|
||||
tid = config.get("configurable", {}).get("thread_id")
|
||||
if tid is not None:
|
||||
return str(tid)
|
||||
|
||||
try:
|
||||
tid = get_config().get("configurable", {}).get("thread_id")
|
||||
return str(tid) if tid is not None else "default"
|
||||
except RuntimeError:
|
||||
return "default"
|
||||
|
||||
|
||||
def _convert_call_tool_result(call_tool_result: Any) -> Any:
|
||||
"""Convert an MCP CallToolResult to the LangChain ``content_and_artifact`` format.
|
||||
|
||||
Implements the same conversion logic as the adapter without relying on
|
||||
the private ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol.
|
||||
"""
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.messages.content import create_file_block, create_image_block, create_text_block
|
||||
from langchain_core.tools import ToolException
|
||||
from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents
|
||||
|
||||
# Pass ToolMessage through directly (interceptor short-circuit).
|
||||
if isinstance(call_tool_result, ToolMessage):
|
||||
return call_tool_result, None
|
||||
|
||||
# Pass LangGraph Command through directly when langgraph is installed.
|
||||
try:
|
||||
from langgraph.types import Command
|
||||
|
||||
if isinstance(call_tool_result, Command):
|
||||
return call_tool_result, None
|
||||
except ImportError:
|
||||
# langgraph is optional; if unavailable, continue with standard MCP content conversion.
|
||||
pass
|
||||
|
||||
# Convert MCP content blocks to LangChain content blocks.
|
||||
lc_content = []
|
||||
for item in call_tool_result.content:
|
||||
if isinstance(item, TextContent):
|
||||
lc_content.append(create_text_block(text=item.text))
|
||||
elif isinstance(item, ImageContent):
|
||||
lc_content.append(create_image_block(base64=item.data, mime_type=item.mimeType))
|
||||
elif isinstance(item, ResourceLink):
|
||||
mime = item.mimeType or None
|
||||
if mime and mime.startswith("image/"):
|
||||
lc_content.append(create_image_block(url=str(item.uri), mime_type=mime))
|
||||
else:
|
||||
lc_content.append(create_file_block(url=str(item.uri), mime_type=mime))
|
||||
elif isinstance(item, EmbeddedResource):
|
||||
from mcp.types import BlobResourceContents
|
||||
|
||||
res = item.resource
|
||||
if isinstance(res, TextResourceContents):
|
||||
lc_content.append(create_text_block(text=res.text))
|
||||
elif isinstance(res, BlobResourceContents):
|
||||
mime = res.mimeType or None
|
||||
if mime and mime.startswith("image/"):
|
||||
lc_content.append(create_image_block(base64=res.blob, mime_type=mime))
|
||||
else:
|
||||
lc_content.append(create_file_block(base64=res.blob, mime_type=mime))
|
||||
else:
|
||||
lc_content.append(create_text_block(text=str(res)))
|
||||
else:
|
||||
lc_content.append(create_text_block(text=str(item)))
|
||||
|
||||
if call_tool_result.isError:
|
||||
error_parts = [item["text"] for item in lc_content if isinstance(item, dict) and item.get("type") == "text"]
|
||||
raise ToolException("\n".join(error_parts) if error_parts else str(lc_content))
|
||||
|
||||
artifact = None
|
||||
if call_tool_result.structuredContent is not None:
|
||||
artifact = {"structured_content": call_tool_result.structuredContent}
|
||||
|
||||
return lc_content, artifact
|
||||
|
||||
|
||||
def _make_session_pool_tool(
|
||||
tool: BaseTool,
|
||||
server_name: str,
|
||||
connection: dict[str, Any],
|
||||
tool_interceptors: list[Any] | None = None,
|
||||
) -> BaseTool:
|
||||
"""Wrap an MCP tool so it reuses a persistent session from the pool.
|
||||
|
||||
Replaces the per-call session creation with pool-managed sessions scoped
|
||||
by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
|
||||
Playwright) keep their state across tool calls within the same thread.
|
||||
|
||||
The configured ``tool_interceptors`` (OAuth, custom) are preserved and
|
||||
applied on every call before invoking the pooled session.
|
||||
"""
|
||||
# Strip the server-name prefix to recover the original MCP tool name.
|
||||
original_name = tool.name
|
||||
prefix = f"{server_name}_"
|
||||
if original_name.startswith(prefix):
|
||||
original_name = original_name[len(prefix) :]
|
||||
|
||||
pool = get_session_pool()
|
||||
|
||||
async def call_with_persistent_session(
|
||||
runtime: Runtime | None = None,
|
||||
**arguments: Any,
|
||||
) -> Any:
|
||||
thread_id = _extract_thread_id(runtime)
|
||||
session = await pool.get_session(server_name, thread_id, connection)
|
||||
|
||||
if tool_interceptors:
|
||||
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
|
||||
|
||||
async def base_handler(request: MCPToolCallRequest) -> Any:
|
||||
return await session.call_tool(request.name, request.args)
|
||||
|
||||
handler = base_handler
|
||||
for interceptor in reversed(tool_interceptors):
|
||||
outer = handler
|
||||
|
||||
async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any:
|
||||
return await _i(req, _h)
|
||||
|
||||
handler = wrapped
|
||||
|
||||
request = MCPToolCallRequest(
|
||||
name=original_name,
|
||||
args=arguments,
|
||||
server_name=server_name,
|
||||
runtime=runtime,
|
||||
)
|
||||
call_tool_result = await handler(request)
|
||||
else:
|
||||
call_tool_result = await session.call_tool(original_name, arguments)
|
||||
|
||||
return _convert_call_tool_result(call_tool_result)
|
||||
|
||||
return StructuredTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
args_schema=tool.args_schema,
|
||||
coroutine=call_with_persistent_session,
|
||||
response_format="content_and_artifact",
|
||||
metadata=tool.metadata,
|
||||
)
|
||||
|
||||
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
|
||||
Tools are wrapped with persistent-session logic so that consecutive
|
||||
calls within the same thread reuse the same MCP session.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
"""
|
||||
@@ -50,7 +210,7 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
existing_headers["Authorization"] = auth_header
|
||||
servers_config[server_name]["headers"] = existing_headers
|
||||
|
||||
tool_interceptors = []
|
||||
tool_interceptors: list[Any] = []
|
||||
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
|
||||
if oauth_interceptor is not None:
|
||||
tool_interceptors.append(oauth_interceptor)
|
||||
@@ -74,20 +234,42 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
elif interceptor is not None:
|
||||
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
|
||||
logger.warning(
|
||||
f"Failed to load MCP interceptor {interceptor_path}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
|
||||
client = MultiServerMCPClient(
|
||||
servers_config,
|
||||
tool_interceptors=tool_interceptors,
|
||||
tool_name_prefix=True,
|
||||
)
|
||||
|
||||
# Get all tools from all servers
|
||||
# Get all tools from all servers (discovers tool definitions via
|
||||
# temporary sessions – the persistent-session wrapping is applied below).
|
||||
tools = await client.get_tools()
|
||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
# Wrap each tool with persistent-session logic.
|
||||
wrapped_tools: list[BaseTool] = []
|
||||
for tool in tools:
|
||||
tool_server: str | None = None
|
||||
for name in servers_config:
|
||||
if tool.name.startswith(f"{name}_"):
|
||||
tool_server = name
|
||||
break
|
||||
|
||||
if tool_server is not None:
|
||||
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
|
||||
else:
|
||||
wrapped_tools.append(tool)
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
for tool in wrapped_tools:
|
||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
|
||||
return tools
|
||||
return wrapped_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
|
||||
|
||||
@@ -227,9 +227,48 @@ class RunRepository(RunStore):
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
async def update_run_progress(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
total_input_tokens: int | None = None,
|
||||
total_output_tokens: int | None = None,
|
||||
total_tokens: int | None = None,
|
||||
llm_call_count: int | None = None,
|
||||
lead_agent_tokens: int | None = None,
|
||||
subagent_tokens: int | None = None,
|
||||
middleware_tokens: int | None = None,
|
||||
message_count: int | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
) -> None:
|
||||
"""Update token usage + convenience fields while a run is still active."""
|
||||
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
|
||||
optional_counters = {
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"llm_call_count": llm_call_count,
|
||||
"lead_agent_tokens": lead_agent_tokens,
|
||||
"subagent_tokens": subagent_tokens,
|
||||
"middleware_tokens": middleware_tokens,
|
||||
"message_count": message_count,
|
||||
}
|
||||
for key, value in optional_counters.items():
|
||||
if value is not None:
|
||||
values[key] = value
|
||||
if last_ai_message is not None:
|
||||
values["last_ai_message"] = last_ai_message[:2000]
|
||||
if first_human_message is not None:
|
||||
values["first_human_message"] = first_human_message[:2000]
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||
_completed = RunRow.status.in_(("success", "error"))
|
||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
||||
_completed = RunRow.status.in_(statuses)
|
||||
_thread = RunRow.thread_id == thread_id
|
||||
model_name = func.coalesce(RunRow.model_name, "unknown")
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
*,
|
||||
track_token_usage: bool = True,
|
||||
flush_threshold: int = 20,
|
||||
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
|
||||
progress_flush_interval: float = 5.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.run_id = run_id
|
||||
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._store = event_store
|
||||
self._track_tokens = track_token_usage
|
||||
self._flush_threshold = flush_threshold
|
||||
self._progress_reporter = progress_reporter
|
||||
self._progress_flush_interval = progress_flush_interval
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||
self._pending_progress_task: asyncio.Task[None] | None = None
|
||||
self._pending_progress_delayed = False
|
||||
self._progress_dirty = False
|
||||
self._last_progress_flush = 0.0
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
self._schedule_progress_flush()
|
||||
|
||||
if messages:
|
||||
self._counted_message_llm_run_ids.add(str(run_id))
|
||||
|
||||
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
self._schedule_progress_flush()
|
||||
|
||||
def set_first_human_message(self, content: str) -> None:
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._pending_flush_tasks:
|
||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||
while self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
if self._pending_progress_delayed:
|
||||
self._pending_progress_task.cancel()
|
||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_delayed = False
|
||||
break
|
||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
||||
|
||||
while self._buffer:
|
||||
batch = self._buffer[: self._flush_threshold]
|
||||
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._buffer = batch + self._buffer
|
||||
raise
|
||||
|
||||
def _schedule_progress_flush(self) -> None:
|
||||
"""Best-effort throttled progress snapshot for active run visibility."""
|
||||
if self._progress_reporter is None:
|
||||
return
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_progress_flush
|
||||
if elapsed < self._progress_flush_interval:
|
||||
self._progress_dirty = True
|
||||
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
|
||||
return
|
||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
self._progress_dirty = True
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
|
||||
|
||||
def _schedule_delayed_progress_flush(self, delay: float) -> None:
|
||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
delay = max(0.0, delay)
|
||||
self._pending_progress_delayed = delay > 0
|
||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
|
||||
|
||||
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
|
||||
if self._progress_reporter is None:
|
||||
return
|
||||
if delay > 0:
|
||||
self._pending_progress_delayed = True
|
||||
await asyncio.sleep(delay)
|
||||
self._pending_progress_delayed = False
|
||||
dirty_before_write = self._progress_dirty
|
||||
self._progress_dirty = False
|
||||
snapshot_to_write = snapshot or self.get_completion_data()
|
||||
try:
|
||||
await self._progress_reporter(snapshot_to_write)
|
||||
self._last_progress_flush = time.monotonic()
|
||||
except Exception:
|
||||
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
|
||||
if dirty_before_write or self._progress_dirty:
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_task = None
|
||||
self._schedule_delayed_progress_flush(self._progress_flush_interval)
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
return {
|
||||
|
||||
@@ -38,6 +38,16 @@ class RunRecord:
|
||||
error: str | None = None
|
||||
model_name: str | None = None
|
||||
store_only: bool = False
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
llm_call_count: int = 0
|
||||
lead_agent_tokens: int = 0
|
||||
subagent_tokens: int = 0
|
||||
middleware_tokens: int = 0
|
||||
message_count: int = 0
|
||||
last_ai_message: str | None = None
|
||||
first_human_message: str | None = None
|
||||
|
||||
|
||||
class RunManager:
|
||||
@@ -102,16 +112,53 @@ class RunManager:
|
||||
error=row.get("error"),
|
||||
model_name=row.get("model_name"),
|
||||
store_only=True,
|
||||
total_input_tokens=row.get("total_input_tokens") or 0,
|
||||
total_output_tokens=row.get("total_output_tokens") or 0,
|
||||
total_tokens=row.get("total_tokens") or 0,
|
||||
llm_call_count=row.get("llm_call_count") or 0,
|
||||
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
|
||||
subagent_tokens=row.get("subagent_tokens") or 0,
|
||||
middleware_tokens=row.get("middleware_tokens") or 0,
|
||||
message_count=row.get("message_count") or 0,
|
||||
last_ai_message=row.get("last_ai_message"),
|
||||
first_human_message=row.get("first_human_message"),
|
||||
)
|
||||
|
||||
async def update_run_completion(self, run_id: str, **kwargs) -> None:
|
||||
"""Persist token usage and completion data to the backing store."""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is not None:
|
||||
for key, value in kwargs.items():
|
||||
if key == "status":
|
||||
continue
|
||||
if hasattr(record, key) and value is not None:
|
||||
setattr(record, key, value)
|
||||
record.updated_at = _now_iso()
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.update_run_completion(run_id, **kwargs)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||
|
||||
async def update_run_progress(self, run_id: str, **kwargs) -> None:
|
||||
"""Persist a running token/message snapshot without changing status."""
|
||||
should_persist = True
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is not None:
|
||||
should_persist = record.status == RunStatus.running
|
||||
if record is not None and should_persist:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(record, key) and value is not None:
|
||||
setattr(record, key, value)
|
||||
record.updated_at = _now_iso()
|
||||
if should_persist and self._store is not None:
|
||||
try:
|
||||
await self._store.update_run_progress(run_id, **kwargs)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
||||
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def update_run_progress(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
total_input_tokens: int | None = None,
|
||||
total_output_tokens: int | None = None,
|
||||
total_tokens: int | None = None,
|
||||
llm_call_count: int | None = None,
|
||||
lead_agent_tokens: int | None = None,
|
||||
subagent_tokens: int | None = None,
|
||||
middleware_tokens: int | None = None,
|
||||
message_count: int | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
) -> None:
|
||||
"""Persist a best-effort running snapshot without changing run status."""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
"""Aggregate token usage for completed runs in a thread.
|
||||
|
||||
Returns a dict with keys: total_tokens, total_input_tokens,
|
||||
|
||||
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
|
||||
self._runs[run_id][key] = value
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
async def update_run_progress(self, run_id, **kwargs):
|
||||
if run_id in self._runs and self._runs[run_id].get("status") == "running":
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
self._runs[run_id][key] = value
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
async def list_pending(self, *, before=None):
|
||||
now = before or datetime.now(UTC).isoformat()
|
||||
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
|
||||
results.sort(key=lambda r: r["created_at"])
|
||||
return results
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
|
||||
by_model: dict[str, dict] = {}
|
||||
for r in completed:
|
||||
model = r.get("model_name") or "unknown"
|
||||
|
||||
@@ -153,8 +153,6 @@ async def run_agent(
|
||||
|
||||
journal = None
|
||||
|
||||
journal = None
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
logger.info(
|
||||
@@ -177,6 +175,7 @@ async def run_agent(
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
|
||||
)
|
||||
|
||||
# 1. Mark running
|
||||
@@ -219,6 +218,12 @@ async def run_agent(
|
||||
# manually here because we drive the graph through ``agent.astream(config=...)``
|
||||
# without passing the official ``context=`` parameter.
|
||||
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
|
||||
# Expose the run-scoped journal under a sentinel key so middleware can
|
||||
# write audit events (e.g. SafetyFinishReasonMiddleware recording
|
||||
# suppressed tool calls). Double-underscore prefix marks it as a
|
||||
# runtime-internal channel; user code must not depend on the key name.
|
||||
if journal is not None:
|
||||
runtime_ctx["__run_journal"] = journal
|
||||
_install_runtime_context(config, runtime_ctx)
|
||||
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
@@ -42,6 +42,7 @@ _DEFAULT_GLOB_MAX_RESULTS = 200
|
||||
_MAX_GLOB_MAX_RESULTS = 1000
|
||||
_DEFAULT_GREP_MAX_RESULTS = 100
|
||||
_MAX_GREP_MAX_RESULTS = 500
|
||||
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS = 2000
|
||||
_LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"}
|
||||
_LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"}
|
||||
_LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"}
|
||||
@@ -435,6 +436,42 @@ def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
|
||||
return msg
|
||||
|
||||
|
||||
def _truncate_write_file_error_detail(detail: str, max_chars: int) -> str:
|
||||
"""Middle-truncate write_file error details, preserving the head and tail."""
|
||||
if max_chars == 0:
|
||||
return detail
|
||||
if len(detail) <= max_chars:
|
||||
return detail
|
||||
total = len(detail)
|
||||
marker_max_len = len(f"\n... [write_file error truncated: {total} chars skipped] ...\n")
|
||||
kept = max(0, max_chars - marker_max_len)
|
||||
if kept == 0:
|
||||
return detail[:max_chars]
|
||||
head_len = kept // 2
|
||||
tail_len = kept - head_len
|
||||
skipped = total - kept
|
||||
marker = f"\n... [write_file error truncated: {skipped} chars skipped] ...\n"
|
||||
return f"{detail[:head_len]}{marker}{detail[-tail_len:] if tail_len > 0 else ''}"
|
||||
|
||||
|
||||
def _format_write_file_error(
|
||||
requested_path: str,
|
||||
error: Exception,
|
||||
runtime: Runtime | None = None,
|
||||
*,
|
||||
max_chars: int = _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
|
||||
) -> str:
|
||||
"""Return a bounded, sanitized error string for write_file failures."""
|
||||
header = f"Error: Failed to write file '{requested_path}'"
|
||||
detail = _sanitize_error(error, runtime)
|
||||
if max_chars == 0:
|
||||
return f"{header}: {detail}"
|
||||
detail_budget = max_chars - len(header) - 2
|
||||
if detail_budget <= 0:
|
||||
return _truncate_write_file_error_detail(f"{header}: {detail}", max_chars)
|
||||
return f"{header}: {_truncate_write_file_error_detail(detail, detail_budget)}"
|
||||
|
||||
|
||||
def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
|
||||
"""Replace virtual /mnt/user-data paths with actual thread data paths.
|
||||
|
||||
@@ -1651,9 +1688,9 @@ def write_file_tool(
|
||||
append: Whether to append content to the end of the file instead of overwriting it. Defaults to false.
|
||||
"""
|
||||
try:
|
||||
requested_path = path
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
@@ -1664,15 +1701,21 @@ def write_file_tool(
|
||||
sandbox.write_file(path, content, append)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
return _format_write_file_error(requested_path, e, runtime)
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied writing to file: {requested_path}"
|
||||
return _truncate_write_file_error_detail(
|
||||
f"Error: Permission denied writing to file: {requested_path}",
|
||||
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
|
||||
)
|
||||
except IsADirectoryError:
|
||||
return f"Error: Path is a directory, not a file: {requested_path}"
|
||||
return _truncate_write_file_error_detail(
|
||||
f"Error: Path is a directory, not a file: {requested_path}",
|
||||
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
|
||||
)
|
||||
except OSError as e:
|
||||
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}"
|
||||
return _format_write_file_error(requested_path, e, runtime)
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
|
||||
return _format_write_file_error(requested_path, e, runtime)
|
||||
|
||||
|
||||
async def _write_file_tool_async(
|
||||
|
||||
@@ -7,6 +7,7 @@ from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Annotated, Any, cast
|
||||
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
@@ -99,15 +100,31 @@ def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls:
|
||||
|
||||
|
||||
def _find_usage_recorder(runtime: Any) -> Any | None:
|
||||
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
|
||||
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config.
|
||||
|
||||
LangChain may pass ``config["callbacks"]`` in three different shapes:
|
||||
|
||||
- ``None`` (no callbacks registered): no recorder.
|
||||
- A plain ``list[BaseCallbackHandler]``: iterate it directly.
|
||||
- A ``BaseCallbackManager`` instance (e.g. ``AsyncCallbackManager`` on async
|
||||
tool runs): managers are not iterable, so we unwrap ``.handlers`` first.
|
||||
|
||||
Any other shape (e.g. a single handler object accidentally passed without a
|
||||
list wrapper) cannot be iterated safely; treat it as "no recorder" rather
|
||||
than raise.
|
||||
"""
|
||||
if runtime is None:
|
||||
return None
|
||||
config = getattr(runtime, "config", None)
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
callbacks = config.get("callbacks", [])
|
||||
callbacks = config.get("callbacks")
|
||||
if isinstance(callbacks, BaseCallbackManager):
|
||||
callbacks = callbacks.handlers
|
||||
if not callbacks:
|
||||
return None
|
||||
if not isinstance(callbacks, list):
|
||||
return None
|
||||
for cb in callbacks:
|
||||
if hasattr(cb, "record_external_llm_usage_records"):
|
||||
return cb
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent.
|
||||
|
||||
What it proves
|
||||
--------------
|
||||
- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full
|
||||
18-middleware chain, sandbox, tools, etc.).
|
||||
- A model that returns ``finish_reason='content_filter'`` + ``tool_calls``
|
||||
triggers SafetyFinishReasonMiddleware.
|
||||
- LangChain's tool router never invokes ``write_file`` — the truncated
|
||||
arguments do **not** reach the sandbox.
|
||||
- A ``safety_termination`` custom event is emitted on the stream and the
|
||||
final AIMessage carries the observability stamp.
|
||||
|
||||
Run from backend/ directory:
|
||||
PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake provider that mimics Moonshot's content_filter behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ContentFilteredFakeModel(BaseChatModel):
|
||||
"""First call returns finish_reason=content_filter + truncated write_file
|
||||
tool_call. Subsequent calls return a normal stop response so the agent
|
||||
can terminate (the middleware should make a second call unnecessary by
|
||||
clearing tool_calls, but we keep this safety net in case loop-detection
|
||||
or anything else triggers another model invocation)."""
|
||||
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-content-filtered"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
msg = AIMessage(
|
||||
content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_truncated_write",
|
||||
"name": "write_file",
|
||||
"args": {
|
||||
"path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
|
||||
"content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
|
||||
},
|
||||
}
|
||||
],
|
||||
response_metadata={
|
||||
"finish_reason": "content_filter",
|
||||
"model_name": "kimi-k2.6",
|
||||
"model_provider": "openai",
|
||||
},
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(
|
||||
content="(secondary call, should not be needed)",
|
||||
response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Driver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> int:
|
||||
# Inject the fake model BEFORE constructing the client. Both the
|
||||
# client module and the lead-agent module bind ``create_chat_model``
|
||||
# at import time via ``from deerflow.models import create_chat_model``,
|
||||
# so we patch both attribute slots — the source-of-truth patch on
|
||||
# ``factory.create_chat_model`` doesn't propagate back into already-
|
||||
# imported names.
|
||||
import deerflow.agents.lead_agent.agent as lead_agent_module
|
||||
import deerflow.client as client_module
|
||||
|
||||
fake = _ContentFilteredFakeModel()
|
||||
originals = {
|
||||
"lead": lead_agent_module.create_chat_model,
|
||||
"client": client_module.create_chat_model,
|
||||
}
|
||||
|
||||
def fake_create_chat_model(*args, **kwargs):
|
||||
return fake
|
||||
|
||||
lead_agent_module.create_chat_model = fake_create_chat_model
|
||||
client_module.create_chat_model = fake_create_chat_model
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
try:
|
||||
client = DeerFlowClient()
|
||||
|
||||
print("\n=== Streaming a turn through the real lead-agent ===")
|
||||
events: list[dict[str, Any]] = []
|
||||
for event in client.stream(
|
||||
"帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
|
||||
thread_id="e2e-safety-1",
|
||||
):
|
||||
events.append({"type": event.type, "data": event.data})
|
||||
|
||||
# ---- Assertions ----
|
||||
safety_event = next(
|
||||
(e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"),
|
||||
None,
|
||||
)
|
||||
final_values = next(
|
||||
(e for e in reversed(events) if e["type"] == "values"),
|
||||
None,
|
||||
)
|
||||
tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"]
|
||||
ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")]
|
||||
|
||||
print(f"\n[stats] total stream events: {len(events)}")
|
||||
print(f"[stats] model call count: {fake.call_count}")
|
||||
print(f"[stats] tool messages on stream: {len(tool_messages)}")
|
||||
print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}")
|
||||
|
||||
print("\n[event] safety_termination custom event:")
|
||||
if safety_event is None:
|
||||
print(" *** NOT FOUND ***")
|
||||
return 1
|
||||
for k, v in safety_event["data"].items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
print("\n[state] final AIMessage from last values snapshot:")
|
||||
if final_values is None:
|
||||
print(" *** no values snapshot ***")
|
||||
return 1
|
||||
# `values` event carries `_serialize_message` dicts, not Message objects.
|
||||
final_messages = final_values["data"].get("messages") or []
|
||||
last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None)
|
||||
if last_ai is None:
|
||||
print(" *** no AIMessage in final state ***")
|
||||
print(f" message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}")
|
||||
return 1
|
||||
|
||||
tool_calls = last_ai.get("tool_calls") or []
|
||||
additional_kwargs = last_ai.get("additional_kwargs") or {}
|
||||
response_metadata = last_ai.get("response_metadata") or {}
|
||||
content = last_ai.get("content")
|
||||
|
||||
print(f" tool_calls (must be empty): {tool_calls}")
|
||||
print(f" additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}")
|
||||
content_preview = (content if isinstance(content, str) else str(content))[:200]
|
||||
print(f" content[:200]: {content_preview!r}")
|
||||
print(f" response_metadata.finish_reason: {response_metadata.get('finish_reason')}")
|
||||
|
||||
# NOTE: `client._serialize_message` does not include `response_metadata`
|
||||
# in the values-event payload (client-layer behaviour, unrelated to the
|
||||
# middleware). The middleware *does* preserve finish_reason on the
|
||||
# AIMessage object — see test_safety_finish_reason_middleware.py::
|
||||
# TestMessageRewrite::test_preserves_response_metadata_finish_reason.
|
||||
# Here we assert on the observability stamp, which carries the same
|
||||
# evidence and is in the serialized payload.
|
||||
stamp = additional_kwargs.get("safety_termination") or {}
|
||||
failures = []
|
||||
if tool_calls:
|
||||
failures.append("final AIMessage still has tool_calls — middleware did NOT clear them")
|
||||
if not stamp:
|
||||
failures.append("final AIMessage missing safety_termination observability stamp")
|
||||
if tool_messages:
|
||||
failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream")
|
||||
if stamp.get("reason_value") != "content_filter":
|
||||
failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'")
|
||||
if safety_event is None:
|
||||
failures.append("safety_termination custom event was not emitted on the stream")
|
||||
|
||||
if failures:
|
||||
print("\n=== FAIL ===")
|
||||
for f in failures:
|
||||
print(f" - {f}")
|
||||
return 1
|
||||
|
||||
print("\n=== PASS ===")
|
||||
print(" - tool_calls cleared on final AIMessage")
|
||||
print(" - tool node never invoked (no ToolMessage on stream)")
|
||||
print(" - safety_termination custom event emitted")
|
||||
print(" - observability stamp written to additional_kwargs")
|
||||
print(" - response_metadata.finish_reason preserved for downstream SSE")
|
||||
return 0
|
||||
finally:
|
||||
lead_agent_module.create_chat_model = originals["lead"]
|
||||
client_module.create_chat_model = originals["client"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
|
||||
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
|
||||
_ai_with_tool_calls(
|
||||
[
|
||||
_tc("web_search", "web_search:11"),
|
||||
_tc("web_search", "web_search:12"),
|
||||
_tc("web_search", "web_search:13"),
|
||||
]
|
||||
),
|
||||
_tool_msg("web_search:11", "web_search"),
|
||||
_tool_msg("web_search:12", "web_search"),
|
||||
_tool_msg("web_search:13", "web_search"),
|
||||
_ai_with_tool_calls(
|
||||
[
|
||||
_tc("web_search", "web_search:9"),
|
||||
_tc("web_search", "web_search:10"),
|
||||
_tc("web_search", "web_search:11"),
|
||||
]
|
||||
),
|
||||
_tool_msg("web_search:9", "web_search"),
|
||||
_tool_msg("web_search:10", "web_search"),
|
||||
_tool_msg("web_search:11", "web_search"),
|
||||
]
|
||||
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
_tool_msg("web_search:11", "web_search"),
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "web_search:11"
|
||||
assert patched[1].status == "success"
|
||||
assert isinstance(patched[3], ToolMessage)
|
||||
assert patched[3].tool_call_id == "web_search:11"
|
||||
assert patched[3].status == "error"
|
||||
|
||||
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
result = _tool_msg("web_search:11", "web_search")
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
result,
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert patched[1] is result
|
||||
assert patched[1].status == "success"
|
||||
assert isinstance(patched[3], ToolMessage)
|
||||
assert patched[3].tool_call_id == "web_search:11"
|
||||
assert patched[3].status == "error"
|
||||
|
||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Regression tests for gateway config freshness on the request hot path.
|
||||
|
||||
Bytedance/deer-flow issue #3107 BUG-001: the worker and lead-agent path
|
||||
captured ``app.state.config`` at gateway startup. ``config.yaml`` edits during
|
||||
runtime were therefore ignored — ``get_app_config()``'s mtime-based reload
|
||||
existed but was bypassed because the snapshot object was passed through
|
||||
explicitly.
|
||||
|
||||
These tests pin the desired behaviour: a request-time ``get_config`` call must
|
||||
observe the most recent on-disk ``config.yaml`` (mtime reload), and the
|
||||
runtime ``ContextVar`` override must keep working for per-request injection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway import deps as gateway_deps
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import (
|
||||
AppConfig,
|
||||
pop_current_app_config,
|
||||
push_current_app_config,
|
||||
reset_app_config,
|
||||
set_app_config,
|
||||
)
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_app_config_singleton():
|
||||
"""Ensure each test starts with a clean module-level cache."""
|
||||
reset_app_config()
|
||||
yield
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _write_config_yaml(path: Path, *, log_level: str) -> None:
|
||||
path.write_text(
|
||||
f"""
|
||||
sandbox:
|
||||
use: deerflow.sandbox.local.provider:LocalSandboxProvider
|
||||
log_level: {log_level}
|
||||
""".strip()
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/probe")
|
||||
def probe(cfg: AppConfig = Depends(get_config)):
|
||||
return {"log_level": cfg.log_level}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_get_config_reflects_file_mtime_reload(tmp_path, monkeypatch):
|
||||
"""Editing config.yaml at runtime must be visible to /probe without restart.
|
||||
|
||||
This is the literal repro for the issue: the gateway must not freeze the
|
||||
config to whatever was on disk when the process started.
|
||||
"""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
_write_config_yaml(config_file, log_level="info")
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
|
||||
|
||||
app = _build_app()
|
||||
client = TestClient(app)
|
||||
assert client.get("/probe").json() == {"log_level": "info"}
|
||||
|
||||
# Edit the file and bump its mtime — simulating a maintainer changing
|
||||
# max_tokens / model settings in production while the gateway is live.
|
||||
_write_config_yaml(config_file, log_level="debug")
|
||||
future_mtime = config_file.stat().st_mtime + 5
|
||||
os.utime(config_file, (future_mtime, future_mtime))
|
||||
|
||||
assert client.get("/probe").json() == {"log_level": "debug"}
|
||||
|
||||
|
||||
def test_get_config_respects_runtime_context_override(tmp_path, monkeypatch):
|
||||
"""Per-request ``push_current_app_config`` injection must still win."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
_write_config_yaml(config_file, log_level="info")
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
|
||||
|
||||
override = AppConfig(sandbox=SandboxConfig(use="test"), log_level="trace")
|
||||
push_current_app_config(override)
|
||||
try:
|
||||
app = _build_app()
|
||||
client = TestClient(app)
|
||||
assert client.get("/probe").json() == {"log_level": "trace"}
|
||||
finally:
|
||||
pop_current_app_config()
|
||||
|
||||
|
||||
def test_get_config_respects_test_set_app_config():
|
||||
"""``set_app_config`` (used by upload/skills router tests) keeps working."""
|
||||
injected = AppConfig(sandbox=SandboxConfig(use="test"), log_level="warning")
|
||||
set_app_config(injected)
|
||||
|
||||
app = _build_app()
|
||||
client = TestClient(app)
|
||||
assert client.get("/probe").json() == {"log_level": "warning"}
|
||||
|
||||
|
||||
def test_run_context_app_config_reflects_yaml_edit(tmp_path, monkeypatch):
|
||||
"""``RunContext.app_config`` must follow live `config.yaml` edits.
|
||||
|
||||
BUG-001 review feedback: the run-context that feeds worker / lead-agent
|
||||
factories must observe the same mtime reload that `get_config()` does;
|
||||
otherwise stale config slips back in through the run path even after the
|
||||
request dependency is fixed.
|
||||
"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.gateway.deps import get_run_context
|
||||
|
||||
config_file = tmp_path / "config.yaml"
|
||||
_write_config_yaml(config_file, log_level="info")
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
|
||||
|
||||
app = FastAPI()
|
||||
# Sentinel values for the rest of the RunContext wiring — we only care
|
||||
# about ``ctx.app_config`` for this assertion.
|
||||
app.state.checkpointer = MagicMock()
|
||||
app.state.store = MagicMock()
|
||||
app.state.run_event_store = MagicMock()
|
||||
app.state.run_events_config = {"frozen": "startup"}
|
||||
app.state.thread_store = MagicMock()
|
||||
|
||||
@app.get("/run-ctx-log-level")
|
||||
def probe(ctx=Depends(get_run_context)):
|
||||
return {
|
||||
"log_level": ctx.app_config.log_level,
|
||||
"run_events_config": ctx.run_events_config,
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
first = client.get("/run-ctx-log-level").json()
|
||||
assert first == {"log_level": "info", "run_events_config": {"frozen": "startup"}}
|
||||
|
||||
_write_config_yaml(config_file, log_level="debug")
|
||||
future_mtime = config_file.stat().st_mtime + 5
|
||||
os.utime(config_file, (future_mtime, future_mtime))
|
||||
|
||||
second = client.get("/run-ctx-log-level").json()
|
||||
# app_config follows the edit; run_events_config stays frozen to the
|
||||
# startup snapshot we wrote onto app.state above.
|
||||
assert second == {"log_level": "debug", "run_events_config": {"frozen": "startup"}}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception",
|
||||
[
|
||||
FileNotFoundError("config.yaml not found"),
|
||||
PermissionError("config.yaml not readable"),
|
||||
ValueError("invalid config"),
|
||||
RuntimeError("yaml parse error"),
|
||||
],
|
||||
)
|
||||
def test_get_config_returns_503_on_any_load_failure(monkeypatch, exception):
|
||||
"""Any failure to materialise the config must surface as 503, not 500.
|
||||
|
||||
Bytedance/deer-flow issue #3107 BUG-001 review: the original snapshot
|
||||
contract returned 503 when ``app.state.config is None``. The first cut of
|
||||
this fix only mapped ``FileNotFoundError`` to 503, which left
|
||||
``PermissionError`` / ``yaml.YAMLError`` / ``ValidationError`` etc. bubbling
|
||||
up as 500. Catch every load failure at the request boundary.
|
||||
"""
|
||||
|
||||
def _broken_get_app_config():
|
||||
raise exception
|
||||
|
||||
monkeypatch.setattr(gateway_deps, "get_app_config", _broken_get_app_config)
|
||||
|
||||
app = _build_app()
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/probe")
|
||||
|
||||
assert response.status_code == 503
|
||||
assert response.json() == {"detail": "Configuration not available"}
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def test_get_config_returns_app_state_config():
|
||||
"""get_config should return the exact AppConfig stored on app.state."""
|
||||
app = FastAPI()
|
||||
config = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
app.state.config = config
|
||||
|
||||
@app.get("/probe")
|
||||
def probe(cfg: AppConfig = Depends(get_config)):
|
||||
return {"same_identity": cfg is config, "log_level": cfg.log_level}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/probe")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"same_identity": True, "log_level": "info"}
|
||||
|
||||
|
||||
def test_get_config_reads_updated_app_state():
|
||||
"""Swapping app.state.config should be visible to the dependency."""
|
||||
app = FastAPI()
|
||||
app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
|
||||
|
||||
@app.get("/log-level")
|
||||
def log_level(cfg: AppConfig = Depends(get_config)):
|
||||
return {"level": cfg.log_level}
|
||||
|
||||
client = TestClient(app)
|
||||
assert client.get("/log-level").json() == {"level": "info"}
|
||||
|
||||
app.state.config = app.state.config.model_copy(update={"log_level": "debug"})
|
||||
assert client.get("/log-level").json() == {"level": "debug"}
|
||||
@@ -17,7 +17,7 @@ from fastapi import FastAPI
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_langgraph_runtime(_app):
|
||||
async def _noop_langgraph_runtime(_app, _startup_config):
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -81,6 +81,94 @@ def test_normalize_input_passthrough():
|
||||
assert result == {"custom_key": "value"}
|
||||
|
||||
|
||||
def test_normalize_input_preserves_additional_kwargs_and_id():
|
||||
"""Regression: gh #3132 — frontend ships uploaded-file metadata in
|
||||
additional_kwargs.files (and a client-side message id). The gateway must
|
||||
not strip them before the graph runs, otherwise UploadsMiddleware reports
|
||||
"(empty)" for new uploads and the frontend message loses its file chip.
|
||||
"""
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
files = [{"filename": "a.csv", "size": 100, "path": "/mnt/user-data/uploads/a.csv", "status": "uploaded"}]
|
||||
result = normalize_input(
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"type": "human",
|
||||
"id": "client-msg-1",
|
||||
"name": "user-input",
|
||||
"content": [{"type": "text", "text": "clean it"}],
|
||||
"additional_kwargs": {"files": files, "custom": "keep-me"},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
assert len(result["messages"]) == 1
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, HumanMessage)
|
||||
assert msg.id == "client-msg-1"
|
||||
assert msg.name == "user-input"
|
||||
assert msg.content == [{"type": "text", "text": "clean it"}]
|
||||
assert msg.additional_kwargs == {"files": files, "custom": "keep-me"}
|
||||
|
||||
|
||||
def test_normalize_input_passes_through_basemessage_instances():
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
msg = HumanMessage(content="hello", id="m-1", additional_kwargs={"files": [{"filename": "x"}]})
|
||||
result = normalize_input({"messages": [msg]})
|
||||
assert result["messages"][0] is msg
|
||||
|
||||
|
||||
def test_normalize_input_rejects_malformed_message_with_400():
|
||||
"""Boundary validation: ``convert_to_messages`` raises ``ValueError`` when a
|
||||
message dict is missing ``role``/``type``/``content``. ``normalize_input``
|
||||
runs inside the gateway HTTP boundary, so a malformed payload should surface
|
||||
as a 400 referencing the offending entry — not bubble up as a 500.
|
||||
|
||||
Raised after the Copilot review on PR #3136.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
normalize_input({"messages": [{"role": "human", "content": "ok"}, {"oops": "no role here"}]})
|
||||
assert excinfo.value.status_code == 400
|
||||
assert "input.messages[1]" in excinfo.value.detail
|
||||
|
||||
|
||||
def test_normalize_input_handles_non_human_roles():
|
||||
"""The previous implementation collapsed every role to HumanMessage with a
|
||||
`# TODO: handle other message types` comment. Resuming a thread with prior
|
||||
AI/tool messages would silently rewrite them as human turns — corrupting
|
||||
the conversation. Use langchain's standard conversion so ai/system/tool
|
||||
roles round-trip correctly.
|
||||
"""
|
||||
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
result = normalize_input(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "ai", "content": "hi", "id": "ai-1"},
|
||||
{"role": "tool", "content": "result", "tool_call_id": "call-1"},
|
||||
]
|
||||
}
|
||||
)
|
||||
types = [type(m) for m in result["messages"]]
|
||||
assert types == [SystemMessage, AIMessage, ToolMessage]
|
||||
assert result["messages"][1].id == "ai-1"
|
||||
assert result["messages"][2].tool_call_id == "call-1"
|
||||
|
||||
|
||||
def test_build_run_config_basic():
|
||||
from app.gateway.services import build_run_config
|
||||
|
||||
|
||||
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
)
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||
# verify the custom middleware is injected correctly.
|
||||
# Chain tail order after the custom middleware is:
|
||||
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
|
||||
# so the custom mock sits at index [-3].
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
|
||||
|
||||
|
||||
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
"""Tests for the MCP persistent-session pool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_pool():
|
||||
reset_session_pool()
|
||||
yield
|
||||
reset_session_pool()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPSessionPool unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_creates_new():
|
||||
"""First call for a key creates a new session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert session is mock_session
|
||||
mock_session.initialize.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_reuses_existing():
|
||||
"""Second call for the same key returns the cached session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert s1 is s2
|
||||
# Only one session should have been created.
|
||||
assert mock_cm.__aenter__.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_scope_creates_different_session():
|
||||
"""Different scope keys get different sessions."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
sessions = [AsyncMock(), AsyncMock()]
|
||||
idx = 0
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.enter_count = 0
|
||||
|
||||
async def __aenter__(self):
|
||||
nonlocal idx
|
||||
s = sessions[idx]
|
||||
idx += 1
|
||||
self.enter_count += 1
|
||||
return s
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return False
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()):
|
||||
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert s1 is not s2
|
||||
assert s1 is sessions[0]
|
||||
assert s2 is sessions[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_eviction():
|
||||
"""Oldest entries are evicted when the pool is full."""
|
||||
pool = MCPSessionPool()
|
||||
pool.MAX_SESSIONS = 2
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
# Pool is full (2). Adding t3 should evict t1.
|
||||
await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
assert cms[2].closed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_scope():
|
||||
"""close_scope shuts down sessions for a specific scope key."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
await pool.close_scope("t1")
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
|
||||
# t2 session still exists.
|
||||
assert ("s", "t2") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all():
|
||||
"""close_all shuts down every session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
await pool.close_all()
|
||||
|
||||
assert all(cm.closed for cm in cms)
|
||||
assert len(pool._entries) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_session_pool_singleton():
|
||||
"""get_session_pool returns the same instance."""
|
||||
p1 = get_session_pool()
|
||||
p2 = get_session_pool()
|
||||
assert p1 is p2
|
||||
|
||||
|
||||
def test_reset_session_pool():
|
||||
"""reset_session_pool clears the singleton."""
|
||||
p1 = get_session_pool()
|
||||
reset_session_pool()
|
||||
p2 = get_session_pool()
|
||||
assert p1 is not p2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: _make_session_pool_tool uses the pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_wrapping():
|
||||
"""The wrapper tool delegates to a pool-managed session."""
|
||||
# Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
url: str = Field(..., description="url")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="playwright_navigate",
|
||||
description="Navigate browser",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
connection = {"transport": "stdio", "command": "pw", "args": []}
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
|
||||
|
||||
# Simulate a tool call with a runtime context containing thread_id.
|
||||
mock_runtime = MagicMock()
|
||||
mock_runtime.context = {"thread_id": "thread-42"}
|
||||
mock_runtime.config = {}
|
||||
|
||||
await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")
|
||||
|
||||
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_extracts_thread_id():
|
||||
"""Thread ID is extracted from runtime.config when not in context."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
mock_runtime = MagicMock()
|
||||
mock_runtime.context = {}
|
||||
mock_runtime.config = {"configurable": {"thread_id": "from-config"}}
|
||||
|
||||
await wrapped.coroutine(runtime=mock_runtime, x=1)
|
||||
|
||||
# Verify the session was created with the correct scope key.
|
||||
pool = get_session_pool()
|
||||
assert ("server", "from-config") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_default_scope():
|
||||
"""When no thread_id is available, 'default' is used as scope key."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
# No thread_id in runtime at all.
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
pool = get_session_pool()
|
||||
assert ("server", "default") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_get_config_fallback():
|
||||
"""When runtime is None, get_config() provides thread_id as fallback."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
fake_config = {"configurable": {"thread_id": "from-langgraph-config"}}
|
||||
|
||||
with (
|
||||
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
|
||||
patch("deerflow.mcp.tools.get_config", return_value=fake_config),
|
||||
):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
# runtime=None — get_config() fallback should provide thread_id
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
pool = get_session_pool()
|
||||
assert ("server", "from-langgraph-config") in pool._entries
|
||||
|
||||
|
||||
def test_session_pool_tool_sync_wrapper_path_is_safe():
|
||||
"""Sync wrapper (tool.func) invocation doesn't crash on cross-loop access."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
class Args(BaseModel):
|
||||
url: str = Field(..., description="url")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="playwright_navigate",
|
||||
description="Navigate browser",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
connection = {"transport": "stdio", "command": "pw", "args": []}
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
|
||||
# Attach the sync wrapper exactly as get_mcp_tools() does.
|
||||
wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name)
|
||||
|
||||
# Call via the sync path (asyncio.run in a worker thread).
|
||||
# runtime is not supplied so _extract_thread_id falls back to "default".
|
||||
wrapped.func(url="https://example.com")
|
||||
|
||||
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
|
||||
@@ -714,6 +714,110 @@ class TestExternalUsageRecords:
|
||||
assert j._subagent_tokens == 0
|
||||
|
||||
|
||||
class TestProgressSnapshots:
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_reports_progress_snapshot(self):
|
||||
snapshots: list[dict] = []
|
||||
|
||||
async def reporter(snapshot: dict) -> None:
|
||||
snapshots.append(snapshot)
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal(
|
||||
"r1",
|
||||
"t1",
|
||||
store,
|
||||
flush_threshold=100,
|
||||
progress_reporter=reporter,
|
||||
progress_flush_interval=0,
|
||||
)
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
|
||||
assert snapshots
|
||||
assert snapshots[-1]["total_tokens"] == 15
|
||||
assert snapshots[-1]["llm_call_count"] == 1
|
||||
assert snapshots[-1]["message_count"] == 1
|
||||
assert snapshots[-1]["last_ai_message"] == "Answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
|
||||
snapshots: list[dict] = []
|
||||
trailing_seen = asyncio.Event()
|
||||
|
||||
async def reporter(snapshot: dict) -> None:
|
||||
snapshots.append(snapshot)
|
||||
if snapshot["total_tokens"] == 45:
|
||||
trailing_seen.set()
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal(
|
||||
"r1",
|
||||
"t1",
|
||||
store,
|
||||
flush_threshold=100,
|
||||
progress_reporter=reporter,
|
||||
progress_flush_interval=0.01,
|
||||
)
|
||||
j.on_llm_end(
|
||||
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
j.on_llm_end(
|
||||
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
|
||||
await j.flush()
|
||||
|
||||
assert len(snapshots) >= 2
|
||||
assert snapshots[-1]["total_tokens"] == 45
|
||||
assert snapshots[-1]["llm_call_count"] == 2
|
||||
assert snapshots[-1]["last_ai_message"] == "Second"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
|
||||
snapshots: list[dict] = []
|
||||
|
||||
async def reporter(snapshot: dict) -> None:
|
||||
snapshots.append(snapshot)
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal(
|
||||
"r1",
|
||||
"t1",
|
||||
store,
|
||||
flush_threshold=100,
|
||||
progress_reporter=reporter,
|
||||
progress_flush_interval=10.0,
|
||||
)
|
||||
j.on_llm_end(
|
||||
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
assert snapshots[-1]["total_tokens"] == 15
|
||||
j.on_llm_end(
|
||||
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
|
||||
await asyncio.wait_for(j.flush(), timeout=0.2)
|
||||
|
||||
assert snapshots[-1]["total_tokens"] == 15
|
||||
assert snapshots[-1]["llm_call_count"] == 1
|
||||
assert snapshots[-1]["last_ai_message"] == "First"
|
||||
|
||||
|
||||
class TestChatModelStartHumanMessage:
|
||||
"""Tests for on_chat_model_start extracting the first human message."""
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
@@ -26,6 +27,42 @@ async def _cleanup():
|
||||
await close_engine()
|
||||
|
||||
|
||||
class _CustomRunStoreWithoutProgress(RunStore):
|
||||
async def put(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def get(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def list_by_thread(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
async def update_status(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def delete(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def update_model_name(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def update_run_completion(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def list_pending(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
async def aggregate_tokens_by_thread(self, *args, **kwargs):
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_defaults_to_noop_for_custom_store():
|
||||
store = _CustomRunStoreWithoutProgress()
|
||||
|
||||
await store.update_run_progress("r1", total_tokens=1)
|
||||
|
||||
|
||||
class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_put_and_get(self, tmp_path):
|
||||
@@ -170,6 +207,69 @@ class TestRunRepository:
|
||||
assert row["total_tokens"] == 100
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_keeps_status_running(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_progress(
|
||||
"r1",
|
||||
total_input_tokens=40,
|
||||
total_output_tokens=10,
|
||||
total_tokens=50,
|
||||
llm_call_count=1,
|
||||
message_count=2,
|
||||
last_ai_message="partial answer",
|
||||
)
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "running"
|
||||
assert row["total_tokens"] == 50
|
||||
assert row["llm_call_count"] == 1
|
||||
assert row["message_count"] == 2
|
||||
assert row["last_ai_message"] == "partial answer"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_progress(
|
||||
"r1",
|
||||
total_input_tokens=40,
|
||||
total_output_tokens=10,
|
||||
total_tokens=50,
|
||||
llm_call_count=1,
|
||||
lead_agent_tokens=30,
|
||||
subagent_tokens=20,
|
||||
message_count=2,
|
||||
)
|
||||
|
||||
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
|
||||
|
||||
row = await repo.get("r1")
|
||||
assert row["total_input_tokens"] == 40
|
||||
assert row["total_output_tokens"] == 10
|
||||
assert row["total_tokens"] == 60
|
||||
assert row["llm_call_count"] == 1
|
||||
assert row["lead_agent_tokens"] == 30
|
||||
assert row["subagent_tokens"] == 20
|
||||
assert row["message_count"] == 2
|
||||
assert row["last_ai_message"] == "updated"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
|
||||
|
||||
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
|
||||
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "success"
|
||||
assert row["total_tokens"] == 100
|
||||
assert row["llm_call_count"] == 1
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
@@ -225,6 +325,28 @@ class TestRunRepository:
|
||||
}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("success-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
|
||||
await repo.put("running-run", thread_id="t1", status="running")
|
||||
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
|
||||
|
||||
without_active = await repo.aggregate_tokens_by_thread("t1")
|
||||
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
|
||||
|
||||
assert without_active["total_tokens"] == 100
|
||||
assert without_active["total_runs"] == 1
|
||||
assert with_active["total_tokens"] == 125
|
||||
assert with_active["total_runs"] == 2
|
||||
assert with_active["by_caller"] == {
|
||||
"lead_agent": 120,
|
||||
"subagent": 5,
|
||||
"middleware": 0,
|
||||
}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
||||
"""list_by_thread returns newest first."""
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
|
||||
|
||||
Unit tests prove ``_apply`` does the right thing on a synthetic state.
|
||||
This test does one level up: builds a real ``langchain.agents.create_agent``
|
||||
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
|
||||
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
|
||||
|
||||
1. The tool node is **not** invoked (the dangerous truncated tool call
|
||||
is suppressed).
|
||||
2. The final AIMessage in graph state has ``tool_calls == []``.
|
||||
3. The observability ``safety_termination`` record is attached.
|
||||
4. The user-facing explanation is appended to the message content.
|
||||
|
||||
This is the closest we can get to the issue's failure mode without a live
|
||||
Moonshot key, and it proves the middleware actually gates LangChain's
|
||||
tool router — not just rewrites state in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@tool
|
||||
def write_file(path: str, content: str) -> str:
|
||||
"""Pretend to write *content* to *path*. Records the call for assertion."""
|
||||
_TOOL_INVOCATIONS.append({"path": path, "content": content})
|
||||
return f"wrote {len(content)} bytes to {path}"
|
||||
|
||||
|
||||
class _ContentFilteredModel(BaseChatModel):
|
||||
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
|
||||
|
||||
First call returns finish_reason='content_filter' + a tool_call whose
|
||||
arguments are visibly truncated. Second call (if reached) returns a
|
||||
normal text completion so the agent can terminate cleanly.
|
||||
"""
|
||||
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-content-filtered"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
# create_agent binds tools onto the model; we don't actually need
|
||||
# to bind anything since responses are hard-coded, but the method
|
||||
# must not raise.
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
message = AIMessage(
|
||||
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_truncated_1",
|
||||
"name": "write_file",
|
||||
"args": {
|
||||
"path": "/mnt/user-data/outputs/report.md",
|
||||
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
|
||||
)
|
||||
else:
|
||||
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
class _InspectMiddleware(AgentMiddleware):
|
||||
"""Captures the messages list at every model entry so we can assert
|
||||
no synthetic tool result was injected back into the conversation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.observed: list[list[Any]] = []
|
||||
|
||||
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
|
||||
self.observed.append(list(request.messages))
|
||||
return handler(request)
|
||||
|
||||
|
||||
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
inspector = _InspectMiddleware()
|
||||
|
||||
agent = create_agent(
|
||||
model=_ContentFilteredModel(),
|
||||
tools=[write_file],
|
||||
# Inspector first so its after_model is registered; Safety last in
|
||||
# the list so it executes first under LIFO (matches production wiring).
|
||||
middleware=[inspector, SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
|
||||
|
||||
# Critical assertion: the dangerous truncated tool call must NOT have
|
||||
# been executed. This is the entire point of the middleware.
|
||||
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
|
||||
|
||||
# Final AIMessage has no tool calls left.
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
assert final_ai.tool_calls == []
|
||||
|
||||
# Observability stamp is present.
|
||||
record = final_ai.additional_kwargs.get("safety_termination")
|
||||
assert record is not None
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 1
|
||||
assert record["suppressed_tool_call_names"] == ["write_file"]
|
||||
|
||||
# User-facing explanation is appended.
|
||||
assert "safety-related signal" in final_ai.content
|
||||
# Original partial text preserved (we don't throw away what the user
|
||||
# already saw in the stream — see middleware docstring).
|
||||
assert "Weekly Politics" in final_ai.content
|
||||
|
||||
# finish_reason on response_metadata is preserved (so SSE / converters
|
||||
# downstream still see the real provider reason).
|
||||
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
|
||||
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through_unchanged():
|
||||
"""No tool calls => issue scope says don't intervene; the partial
|
||||
response should be delivered as-is so the user sees what they got."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _NoToolModel(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-no-tool"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
msg = AIMessage(
|
||||
content="Partial answer truncated by safety filter",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_NoToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
|
||||
# Content untouched.
|
||||
assert final_ai.content == "Partial answer truncated by safety filter"
|
||||
# No safety_termination stamp because we didn't intervene.
|
||||
assert "safety_termination" not in final_ai.additional_kwargs
|
||||
# tool node never ran (there were no tool calls in the first place).
|
||||
assert _TOOL_INVOCATIONS == []
|
||||
|
||||
|
||||
def test_normal_tool_call_round_trip_is_not_affected():
|
||||
"""Regression: a healthy finish_reason='tool_calls' response must still
|
||||
execute the tool. The middleware must not over-fire."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _HealthyToolModel(BaseChatModel):
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-healthy"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_ok",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/tmp/ok", "content": "complete content"},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_HealthyToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
agent.invoke({"messages": [HumanMessage(content="write")]})
|
||||
|
||||
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Unit tests for SafetyFinishReasonMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
)
|
||||
from deerflow.config.safety_finish_reason_config import (
|
||||
SafetyDetectorConfig,
|
||||
SafetyFinishReasonConfig,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id="t-1"):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
return runtime
|
||||
|
||||
|
||||
def _ai(
|
||||
*,
|
||||
content="",
|
||||
tool_calls=None,
|
||||
response_metadata=None,
|
||||
additional_kwargs=None,
|
||||
):
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
def _write_call(idx=1, content_text="半截"):
|
||||
return {
|
||||
"id": f"call_write_{idx}",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
|
||||
}
|
||||
|
||||
|
||||
class AlwaysHitDetector:
|
||||
"""Test fixture: always reports the given termination."""
|
||||
|
||||
name = "always_hit"
|
||||
|
||||
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
|
||||
self.reason_field = reason_field
|
||||
self.reason_value = reason_value
|
||||
self.extras = extras or {}
|
||||
|
||||
def detect(self, message):
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field=self.reason_field,
|
||||
reason_value=self.reason_value,
|
||||
extras=self.extras,
|
||||
)
|
||||
|
||||
|
||||
class NeverHitDetector:
|
||||
name = "never_hit"
|
||||
|
||||
def detect(self, message):
|
||||
return None
|
||||
|
||||
|
||||
class RaisingDetector:
|
||||
name = "raising"
|
||||
|
||||
def detect(self, message):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core trigger behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriggerCriteria:
|
||||
def test_content_filter_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through(self):
|
||||
"""issue scope: when there are no tool calls the partial text is a
|
||||
legitimate final response and should not be rewritten."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial response",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_tool_calls_pass_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_stop_with_tool_calls_pass_through(self):
|
||||
# Some providers report finish_reason='stop' for tool-call messages.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "stop"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_empty_message_list_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
assert mw._apply({"messages": []}, _runtime()) is None
|
||||
|
||||
def test_non_ai_last_message_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_anthropic_refusal_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"stop_reason": "refusal"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_gemini_safety_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "SAFETY"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message rewriting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageRewrite:
|
||||
def test_clears_structured_tool_calls(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_clears_raw_additional_kwargs_tool_calls(self):
|
||||
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
|
||||
tool calls from additional_kwargs.tool_calls if we forget them, which
|
||||
would re-emit a synthetic ToolMessage downstream and confuse the
|
||||
model. We must wipe both."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
raw_tool_calls = [
|
||||
{
|
||||
"id": "call_write_1",
|
||||
"type": "function",
|
||||
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
|
||||
}
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"tool_calls": raw_tool_calls,
|
||||
"function_call": {"name": "write_file", "arguments": "{}"},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert "tool_calls" not in patched.additional_kwargs
|
||||
assert "function_call" not in patched.additional_kwargs
|
||||
|
||||
def test_preserves_other_additional_kwargs(self):
|
||||
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
|
||||
# may carry other provider-specific keys. They must not be wiped.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"reasoning": "thinking text",
|
||||
"custom_provider_field": {"x": 1},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["reasoning"] == "thinking text"
|
||||
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
|
||||
|
||||
def test_writes_observability_field(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
record = patched.additional_kwargs["safety_termination"]
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 2
|
||||
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
|
||||
|
||||
def test_preserves_response_metadata_finish_reason(self):
|
||||
"""Downstream SSE converters read response_metadata.finish_reason —
|
||||
we want them to see the *real* provider reason, not 'stop'."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.response_metadata["finish_reason"] == "content_filter"
|
||||
assert patched.response_metadata["model_name"] == "kimi-k2"
|
||||
|
||||
def test_appends_user_facing_explanation_to_str_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="some partial text",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert patched.content.startswith("some partial text")
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_empty_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_list_content_thinking_blocks(self):
|
||||
"""Anthropic thinking / vLLM reasoning models emit content blocks.
|
||||
Naively concatenating a string would raise TypeError."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
thinking_blocks = [
|
||||
{"type": "thinking", "text": "let me consider..."},
|
||||
{"type": "text", "text": "partial answer"},
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content=thinking_blocks,
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, list)
|
||||
assert patched.content[:2] == thinking_blocks
|
||||
assert patched.content[-1]["type"] == "text"
|
||||
assert "safety-related signal" in patched.content[-1]["text"]
|
||||
|
||||
def test_idempotent_on_already_cleared_message(self):
|
||||
# Re-running the middleware on a message we already cleared must not
|
||||
# re-trigger (tool_calls is now empty → fast passthrough).
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
first = mw._apply(state, _runtime())
|
||||
state2 = {"messages": [first["messages"][0]]}
|
||||
second = mw._apply(state2, _runtime())
|
||||
assert second is None
|
||||
|
||||
def test_preserves_message_id_for_add_messages_replacement(self):
|
||||
"""LangGraph's add_messages reducer treats same-id messages as
|
||||
replacements. model_copy keeps id by default."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
original = _ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
# AIMessage auto-generates id; capture it
|
||||
original_id = original.id
|
||||
state = {"messages": [original]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.id == original_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detector wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectorWiring:
|
||||
def test_iterates_detectors_in_order(self):
|
||||
first = AlwaysHitDetector(reason_value="first")
|
||||
second = AlwaysHitDetector(reason_value="second")
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
|
||||
|
||||
def test_returns_none_when_no_detector_matches(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_buggy_detector_does_not_break_run(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
|
||||
|
||||
def test_constructor_copies_detectors(self):
|
||||
"""Caller mutation after construction must not leak into us."""
|
||||
detectors = [AlwaysHitDetector()]
|
||||
mw = SafetyFinishReasonMiddleware(detectors=detectors)
|
||||
detectors.clear()
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# from_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFromConfig:
|
||||
def test_default_config_uses_builtin_detectors(self):
|
||||
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
|
||||
assert len(mw._detectors) == 3
|
||||
names = {d.name for d in mw._detectors}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_custom_detectors_loaded_via_reflection(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[
|
||||
SafetyDetectorConfig(
|
||||
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
|
||||
config={"finish_reasons": ["custom_filter"]},
|
||||
),
|
||||
]
|
||||
)
|
||||
mw = SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
assert len(mw._detectors) == 1
|
||||
# Confirm the kwargs propagated.
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "custom_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
# Default token no longer matches.
|
||||
state2 = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state2, _runtime()) is None
|
||||
|
||||
def test_empty_detector_list_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(detectors=[])
|
||||
with pytest.raises(ValueError, match="enabled=false"):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
def test_non_detector_class_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[SafetyDetectorConfig(use="builtins:dict")],
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuditEvent:
|
||||
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
|
||||
audit event via RunJournal.record_middleware when the run-scoped journal is
|
||||
exposed under runtime.context["__run_journal"].
|
||||
|
||||
Background: review on PR #3035 — SSE custom event handles live consumers,
|
||||
but post-run audit needs a row in run_events that can be queried with one
|
||||
SQL statement (no JOIN against message body).
|
||||
"""
|
||||
|
||||
def _runtime_with_journal(self, journal):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
|
||||
return runtime
|
||||
|
||||
def test_records_audit_event_when_journal_present(self):
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
tc = _write_call(1)
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
|
||||
journal.record_middleware.assert_called_once()
|
||||
call = journal.record_middleware.call_args
|
||||
# tag is positional or kwarg depending on call style; we use kwargs.
|
||||
assert call.kwargs["tag"] == "safety_termination"
|
||||
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
|
||||
assert call.kwargs["hook"] == "after_model"
|
||||
assert call.kwargs["action"] == "suppress_tool_calls"
|
||||
|
||||
changes = call.kwargs["changes"]
|
||||
assert changes["detector"] == "openai_compatible_content_filter"
|
||||
assert changes["reason_field"] == "finish_reason"
|
||||
assert changes["reason_value"] == "content_filter"
|
||||
assert changes["suppressed_tool_call_count"] == 1
|
||||
assert changes["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
|
||||
assert "message_id" in changes
|
||||
assert isinstance(changes["extras"], dict)
|
||||
|
||||
def test_audit_event_never_carries_tool_arguments(self):
|
||||
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
|
||||
and must NOT be persisted to run_events under any circumstance."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
sensitive_tc = {
|
||||
"id": "call_x",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
|
||||
}
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[sensitive_tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, self._runtime_with_journal(journal))
|
||||
flat = repr(journal.record_middleware.call_args)
|
||||
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
|
||||
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
|
||||
|
||||
def test_no_journal_in_runtime_context_is_silently_skipped(self):
|
||||
"""Subagent runtime / unit tests / no-event-store paths have no journal.
|
||||
Middleware must still intervene and clear tool_calls — only the audit
|
||||
event is skipped."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-noj"} # no __run_journal
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise; should still clear tool_calls.
|
||||
result = mw._apply(state, runtime)
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_journal_record_exception_does_not_break_run(self):
|
||||
"""Buggy journal must never propagate an exception into the agent loop."""
|
||||
journal = MagicMock()
|
||||
journal.record_middleware.side_effect = RuntimeError("db down")
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Must not raise.
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_no_record_when_passthrough(self):
|
||||
"""When the middleware does NOT intervene, no audit event is written."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"}, # healthy
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, self._runtime_with_journal(journal)) is None
|
||||
journal.record_middleware.assert_not_called()
|
||||
|
||||
|
||||
class TestStreamEvent:
|
||||
def test_emits_event_when_writer_available(self, monkeypatch):
|
||||
captured: list = []
|
||||
|
||||
def fake_writer(payload):
|
||||
captured.append(payload)
|
||||
|
||||
# Patch get_stream_writer at the symbol-resolution site.
|
||||
import langgraph.config
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, _runtime("t-stream"))
|
||||
|
||||
assert len(captured) == 1
|
||||
payload = captured[0]
|
||||
assert payload["type"] == "safety_termination"
|
||||
assert payload["detector"] == "openai_compatible_content_filter"
|
||||
assert payload["reason_field"] == "finish_reason"
|
||||
assert payload["reason_value"] == "content_filter"
|
||||
assert payload["suppressed_tool_call_count"] == 1
|
||||
assert payload["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert payload["thread_id"] == "t-stream"
|
||||
|
||||
def test_writer_unavailable_does_not_break(self, monkeypatch):
|
||||
import langgraph.config
|
||||
|
||||
def boom():
|
||||
raise LookupError("not in a stream context")
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise.
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Unit tests for SafetyTerminationDetector built-ins."""
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
AnthropicRefusalDetector,
|
||||
GeminiSafetyDetector,
|
||||
OpenAICompatibleContentFilterDetector,
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
|
||||
|
||||
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAICompatibleContentFilterDetector:
|
||||
def test_default_matches_content_filter(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
assert hit.detector == "openai_compatible_content_filter"
|
||||
assert hit.reason_field == "finish_reason"
|
||||
assert hit.reason_value == "content_filter"
|
||||
|
||||
def test_case_insensitive_match(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
|
||||
|
||||
def test_other_finish_reasons_pass_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
|
||||
|
||||
def test_missing_metadata_passes_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai()) is None
|
||||
|
||||
def test_non_string_finish_reason_passes_through(self):
|
||||
# Some adapters may stash an enum or dict — must not raise.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
|
||||
|
||||
def test_falls_back_to_additional_kwargs(self):
|
||||
# Legacy adapters surface finish_reason via additional_kwargs.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
|
||||
def test_configurable_extra_values(self):
|
||||
# Chinese providers sometimes use bespoke tokens.
|
||||
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
|
||||
# Original token still matches.
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
|
||||
|
||||
def test_carries_azure_content_filter_results(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
filter_results = {"hate": {"filtered": True, "severity": "high"}}
|
||||
hit = d.detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "content_filter",
|
||||
"content_filter_results": filter_results,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["content_filter_results"] == filter_results
|
||||
|
||||
|
||||
class TestAnthropicRefusalDetector:
|
||||
def test_default_matches_refusal(self):
|
||||
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
|
||||
assert hit is not None
|
||||
assert hit.reason_field == "stop_reason"
|
||||
assert hit.reason_value == "refusal"
|
||||
|
||||
def test_other_stop_reasons_pass_through(self):
|
||||
d = AnthropicRefusalDetector()
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
|
||||
|
||||
def test_anthropic_does_not_steal_finish_reason(self):
|
||||
# An OpenAI message must not accidentally trip the Anthropic detector.
|
||||
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
|
||||
|
||||
|
||||
class TestGeminiSafetyDetector:
|
||||
def test_default_set_covers_documented_reasons(self):
|
||||
d = GeminiSafetyDetector()
|
||||
for reason in (
|
||||
# text safety
|
||||
"SAFETY",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"RECITATION",
|
||||
# image safety
|
||||
"IMAGE_SAFETY",
|
||||
"IMAGE_PROHIBITED_CONTENT",
|
||||
"IMAGE_RECITATION",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
|
||||
|
||||
def test_normal_termination_passes_through(self):
|
||||
d = GeminiSafetyDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
|
||||
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
|
||||
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
|
||||
# excluded from the default set — they are either normal termination,
|
||||
# capability mismatches, too broad (OTHER), or tool-call protocol
|
||||
# errors. See GeminiSafetyDetector docstring.
|
||||
for reason in (
|
||||
"MAX_TOKENS",
|
||||
"LANGUAGE",
|
||||
"NO_IMAGE",
|
||||
"OTHER",
|
||||
"IMAGE_OTHER",
|
||||
"MALFORMED_FUNCTION_CALL",
|
||||
"UNEXPECTED_TOOL_CALL",
|
||||
"FINISH_REASON_UNSPECIFIED",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
|
||||
|
||||
def test_carries_safety_ratings(self):
|
||||
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
|
||||
hit = GeminiSafetyDetector().detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "SAFETY",
|
||||
"safety_ratings": ratings,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["safety_ratings"] == ratings
|
||||
|
||||
|
||||
class TestDefaultDetectorSet:
|
||||
def test_default_set_returns_three_detectors(self):
|
||||
dets = default_detectors()
|
||||
names = {d.name for d in dets}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_default_set_returns_fresh_list(self):
|
||||
# Caller mutation must not affect later calls.
|
||||
first = default_detectors()
|
||||
first.clear()
|
||||
second = default_detectors()
|
||||
assert len(second) == 3
|
||||
|
||||
|
||||
class TestProtocolConformance:
|
||||
def test_builtins_satisfy_protocol(self):
|
||||
for d in default_detectors():
|
||||
assert isinstance(d, SafetyTerminationDetector)
|
||||
|
||||
def test_safety_termination_is_frozen(self):
|
||||
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
|
||||
try:
|
||||
t.detector = "y" # type: ignore[misc]
|
||||
except Exception:
|
||||
return
|
||||
raise AssertionError("SafetyTermination should be frozen")
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.sandbox.exceptions import SandboxError
|
||||
from deerflow.sandbox.tools import (
|
||||
VIRTUAL_PATH_PREFIX,
|
||||
_apply_cwd_prefix,
|
||||
@@ -1140,6 +1141,170 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
|
||||
assert sandbox.content == "ALPHA\ntail\n"
|
||||
|
||||
|
||||
def test_write_file_tool_bounds_large_oserror_and_masks_local_paths(monkeypatch) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-large-oserror"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
host_path = f"{_THREAD_DATA['workspace_path']}/nested/output.txt"
|
||||
raise OSError(f"write failed at {host_path}\n{'A' * 12000}\nremote tail marker")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: True)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.get_thread_data", lambda runtime: _THREAD_DATA)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.validate_local_tool_path", lambda path, thread_data: None)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools._resolve_and_validate_user_data_path",
|
||||
lambda path, thread_data: f"{_THREAD_DATA['workspace_path']}/output.txt",
|
||||
)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="写入大文件失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="report body",
|
||||
)
|
||||
|
||||
assert len(result) <= 2000
|
||||
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/workspace" not in result
|
||||
assert "/mnt/user-data/workspace/nested/output.txt" in result
|
||||
assert "remote tail marker" in result
|
||||
assert "[write_file error truncated:" in result
|
||||
|
||||
|
||||
def test_write_file_tool_preserves_short_oserror_without_truncation(monkeypatch) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-short-oserror"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
raise OSError("disk quota exceeded")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="写入失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert result == "Error: Failed to write file '/mnt/user-data/workspace/output.txt': OSError: disk quota exceeded"
|
||||
assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_write_file_tool_bounds_large_sandbox_error(monkeypatch) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-large-sandbox-error"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
raise SandboxError(f"remote write rejected {'B' * 12000} final detail")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="远端写入失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert len(result) <= 2000
|
||||
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
|
||||
assert "SandboxError: remote write rejected" in result
|
||||
assert "final detail" in result
|
||||
assert "[write_file error truncated:" in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raised_error", "expected_fragment"),
|
||||
[
|
||||
pytest.param(
|
||||
PermissionError("permission denied"),
|
||||
"Error: Permission denied writing to file: /mnt/user-data/workspace/output.txt",
|
||||
id="permission",
|
||||
),
|
||||
pytest.param(
|
||||
IsADirectoryError("target is a directory"),
|
||||
"Error: Path is a directory, not a file: /mnt/user-data/workspace/output.txt",
|
||||
id="directory",
|
||||
),
|
||||
pytest.param(
|
||||
Exception("remote sandbox timeout"),
|
||||
"Exception: remote sandbox timeout",
|
||||
id="generic",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_write_file_tool_formats_all_other_failure_branches(
|
||||
monkeypatch,
|
||||
raised_error: Exception,
|
||||
expected_fragment: str,
|
||||
) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-other-failure"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
raise raised_error
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="验证错误分支格式化",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert "/mnt/user-data/workspace/output.txt" in result
|
||||
assert expected_fragment in result
|
||||
assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_write_file_tool_handles_sandbox_init_failure(monkeypatch) -> None:
|
||||
"""Regression for #3133 review: SandboxError raised during sandbox
|
||||
initialization (before the local `requested_path` assignment) must still
|
||||
surface as a bounded tool error rather than an UnboundLocalError.
|
||||
"""
|
||||
|
||||
def raise_sandbox_error(runtime):
|
||||
raise SandboxError("sandbox missing")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", raise_sandbox_error)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="sandbox 初始化失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
|
||||
assert "SandboxError: sandbox missing" in result
|
||||
assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_file_operation_lock_memory_cleanup() -> None:
|
||||
"""Verify that released locks are eventually cleaned up by WeakValueDictionary.
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from types import SimpleNamespace
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.routers import skills as skills_router
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.types import Skill
|
||||
@@ -38,7 +39,8 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
|
||||
|
||||
def _make_test_app(config) -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.state.config = config # kept for any startup-style reads
|
||||
app.dependency_overrides[get_config] = lambda: config
|
||||
app.include_router(skills_router.router)
|
||||
return app
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Regression tests for _find_usage_recorder callback shape handling.
|
||||
|
||||
Bytedance issue #3107 BUG-002: When LangChain passes ``config["callbacks"]`` as
|
||||
an ``AsyncCallbackManager`` (instead of a plain list), the previous
|
||||
``for cb in callbacks`` loop raised ``TypeError: 'AsyncCallbackManager' object
|
||||
is not iterable``. ToolErrorHandlingMiddleware then converted the entire ``task``
|
||||
tool call into an error ToolMessage, losing the subagent result.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackManager, CallbackManager
|
||||
|
||||
from deerflow.tools.builtins.task_tool import _find_usage_recorder
|
||||
|
||||
|
||||
class _RecorderHandler:
|
||||
def record_external_llm_usage_records(self, records):
|
||||
self.records = records
|
||||
|
||||
|
||||
class _OtherHandler:
|
||||
pass
|
||||
|
||||
|
||||
def _make_runtime(callbacks):
|
||||
return SimpleNamespace(config={"callbacks": callbacks})
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_plain_list():
|
||||
recorder = _RecorderHandler()
|
||||
runtime = _make_runtime([_OtherHandler(), recorder])
|
||||
assert _find_usage_recorder(runtime) is recorder
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_async_callback_manager():
|
||||
"""LangChain wraps callbacks in AsyncCallbackManager for async tool runs.
|
||||
|
||||
The old implementation raised TypeError here. The recorder lives on
|
||||
``manager.handlers``; we must look there too.
|
||||
"""
|
||||
recorder = _RecorderHandler()
|
||||
manager = AsyncCallbackManager(handlers=[_OtherHandler(), recorder])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is recorder
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_sync_callback_manager():
|
||||
"""Sync flavor of the same wrapper used by some langchain code paths."""
|
||||
recorder = _RecorderHandler()
|
||||
manager = CallbackManager(handlers=[recorder])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is recorder
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_no_recorder():
|
||||
manager = AsyncCallbackManager(handlers=[_OtherHandler()])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_handles_empty_manager():
|
||||
manager = AsyncCallbackManager(handlers=[])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_for_none_runtime():
|
||||
assert _find_usage_recorder(None) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_callbacks_is_none():
|
||||
runtime = _make_runtime(None)
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_for_single_handler_object():
|
||||
"""A single handler instance (not wrapped in a list or manager) should not crash.
|
||||
|
||||
LangChain's contract is that ``config["callbacks"]`` is a list-or-manager,
|
||||
but we treat any other shape defensively rather than letting a ``for`` loop
|
||||
blow up at runtime.
|
||||
"""
|
||||
runtime = _make_runtime(_RecorderHandler())
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_config_not_dict():
|
||||
"""Defensive: a runtime without a dict-shaped config should not raise."""
|
||||
runtime = SimpleNamespace(config="not-a-dict")
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
|
||||
},
|
||||
}
|
||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
|
||||
|
||||
|
||||
def test_thread_token_usage_can_include_active_runs():
|
||||
run_store = MagicMock()
|
||||
run_store.aggregate_tokens_by_thread = AsyncMock(
|
||||
return_value={
|
||||
"total_tokens": 175,
|
||||
"total_input_tokens": 120,
|
||||
"total_output_tokens": 55,
|
||||
"total_runs": 3,
|
||||
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
|
||||
"by_caller": {
|
||||
"lead_agent": 145,
|
||||
"subagent": 25,
|
||||
"middleware": 5,
|
||||
},
|
||||
},
|
||||
)
|
||||
app = _make_app(run_store)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["total_tokens"] == 175
|
||||
assert response.json()["total_runs"] == 3
|
||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
|
||||
|
||||
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(middlewares) == 6
|
||||
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
|
||||
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
|
||||
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
|
||||
# (enabled by default — see SafetyFinishReasonConfig).
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
assert len(middlewares) == 7
|
||||
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
|
||||
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_on_success():
|
||||
|
||||
@@ -11,6 +11,7 @@ from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.routers import uploads
|
||||
|
||||
|
||||
@@ -687,6 +688,7 @@ def test_upload_limits_endpoint_requires_thread_access():
|
||||
cfg.uploads = {}
|
||||
app = make_authed_test_app(owner_check_passes=False)
|
||||
app.state.config = cfg
|
||||
app.dependency_overrides[get_config] = lambda: cfg
|
||||
app.include_router(uploads.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
|
||||
+48
-7
@@ -15,7 +15,7 @@
|
||||
# ============================================================================
|
||||
# Bump this number when the config schema changes.
|
||||
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
||||
config_version: 9
|
||||
config_version: 10
|
||||
|
||||
# ============================================================================
|
||||
# Logging
|
||||
@@ -118,19 +118,25 @@ models:
|
||||
# For Docker deployments, use host.docker.internal instead of localhost:
|
||||
# base_url: http://host.docker.internal:11434
|
||||
|
||||
# Example: Anthropic Claude model
|
||||
# - name: claude-3-5-sonnet
|
||||
# display_name: Claude 3.5 Sonnet
|
||||
# Example: Anthropic Claude model (with extended thinking)
|
||||
# supports_thinking: true is required — without it, DeerFlow silently falls
|
||||
# back to non-thinking mode even when the UI thinking toggle is on.
|
||||
# budget_tokens is required by the Anthropic API when thinking.type=enabled
|
||||
# (no server default; min 1024; must be less than max_tokens).
|
||||
# - name: claude-sonnet-4
|
||||
# display_name: Claude Sonnet 4
|
||||
# use: langchain_anthropic:ChatAnthropic
|
||||
# model: claude-3-5-sonnet-20241022
|
||||
# model: claude-sonnet-4-20250514
|
||||
# api_key: $ANTHROPIC_API_KEY
|
||||
# default_request_timeout: 600.0
|
||||
# max_retries: 2
|
||||
# max_tokens: 8192
|
||||
# supports_vision: true # Enable vision support for view_image tool
|
||||
# max_tokens: 16000
|
||||
# supports_vision: true
|
||||
# supports_thinking: true
|
||||
# when_thinking_enabled:
|
||||
# thinking:
|
||||
# type: enabled
|
||||
# budget_tokens: 4096 # required; min 1024; must be < max_tokens
|
||||
# when_thinking_disabled:
|
||||
# thinking:
|
||||
# type: disabled
|
||||
@@ -529,6 +535,41 @@ loop_detection:
|
||||
# warn: 150
|
||||
# hard_limit: 300
|
||||
|
||||
# ============================================================================
|
||||
# Provider Safety Termination Configuration
|
||||
# ============================================================================
|
||||
# Intercept AIMessages where the provider stopped generation for safety reasons
|
||||
# (e.g. OpenAI finish_reason='content_filter', Anthropic stop_reason='refusal',
|
||||
# Gemini finish_reason='SAFETY') while still returning tool_calls. The
|
||||
# tool_calls in such responses are typically truncated/unreliable and must
|
||||
# not be executed. See issue #3028 for the full failure mode.
|
||||
#
|
||||
# Detectors are loaded by class path via reflection (same pattern as
|
||||
# guardrails / models / tools). The built-in set covers OpenAI-compatible
|
||||
# content_filter, Anthropic refusal, and Gemini SAFETY/BLOCKLIST/
|
||||
# PROHIBITED_CONTENT/SPII/RECITATION.
|
||||
|
||||
safety_finish_reason:
|
||||
enabled: true
|
||||
# Leave `detectors` unset to use the built-in detector set. Set to a
|
||||
# non-empty list to fully override (use `enabled: false` to disable instead
|
||||
# of providing an empty list).
|
||||
#
|
||||
# Example — extend the OpenAI-compatible detector for a Chinese provider
|
||||
# whose gateway uses a non-standard finish_reason token:
|
||||
# detectors:
|
||||
# - use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector
|
||||
# config:
|
||||
# finish_reasons: ["content_filter", "sensitive", "risk_control"]
|
||||
# - use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector
|
||||
# - use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector
|
||||
#
|
||||
# Example — add a custom detector for an in-house provider:
|
||||
# detectors:
|
||||
# - use: my_company.deerflow_ext:WenxinSafetyDetector
|
||||
# config:
|
||||
# error_codes: [336003, 17, 18]
|
||||
|
||||
# ============================================================================
|
||||
# Sandbox Configuration
|
||||
# ============================================================================
|
||||
|
||||
@@ -3,18 +3,6 @@
|
||||
"my_package.mcp.auth:build_auth_interceptor"
|
||||
],
|
||||
"mcpServers": {
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/path/to/allowed/files"
|
||||
],
|
||||
"env": {},
|
||||
"description": "Provides filesystem access within allowed directories"
|
||||
},
|
||||
"github": {
|
||||
"enabled": false,
|
||||
"type": "stdio",
|
||||
@@ -42,4 +30,4 @@
|
||||
}
|
||||
},
|
||||
"skills": {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import {
|
||||
import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
|
||||
import type { Subtask } from "@/core/tasks";
|
||||
import { useUpdateSubtask } from "@/core/tasks/context";
|
||||
import { parseSubtaskResult } from "@/core/tasks/subtask-result";
|
||||
import type { AgentThreadState } from "@/core/threads";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
@@ -359,33 +360,10 @@ export function MessageList({
|
||||
} else if (message.type === "tool") {
|
||||
const taskId = message.tool_call_id;
|
||||
if (taskId) {
|
||||
const result = extractTextFromMessage(message);
|
||||
if (result.startsWith("Task Succeeded. Result:")) {
|
||||
updateSubtask({
|
||||
id: taskId,
|
||||
status: "completed",
|
||||
result: result
|
||||
.split("Task Succeeded. Result:")[1]
|
||||
?.trim(),
|
||||
});
|
||||
} else if (result.startsWith("Task failed.")) {
|
||||
updateSubtask({
|
||||
id: taskId,
|
||||
status: "failed",
|
||||
error: result.split("Task failed.")[1]?.trim(),
|
||||
});
|
||||
} else if (result.startsWith("Task timed out")) {
|
||||
updateSubtask({
|
||||
id: taskId,
|
||||
status: "failed",
|
||||
error: result,
|
||||
});
|
||||
} else {
|
||||
updateSubtask({
|
||||
id: taskId,
|
||||
status: "in_progress",
|
||||
});
|
||||
}
|
||||
const parsed = parseSubtaskResult(
|
||||
extractTextFromMessage(message),
|
||||
);
|
||||
updateSubtask({ id: taskId, ...parsed });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,11 +29,6 @@ The default location is the project root (same directory as `config.yaml`). The
|
||||
"args": ["-y", "@my-org/my-mcp-server"],
|
||||
"enabled": true
|
||||
},
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"],
|
||||
"enabled": true
|
||||
},
|
||||
"sqlite": {
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-sqlite", "--db-path", "/path/to/db.sqlite"],
|
||||
@@ -43,6 +38,16 @@ The default location is the project root (same directory as `config.yaml`). The
|
||||
}
|
||||
```
|
||||
|
||||
<Callout type="warning">
|
||||
Do not add an MCP filesystem server for DeerFlow workspace files. DeerFlow
|
||||
already provides built-in file tools for thread-scoped workspace access, and
|
||||
overlapping file tools with different path semantics can make LLM tool
|
||||
selection and file access behavior unstable. DeerFlow does not currently
|
||||
adapt MCP Roots mode for filesystem servers: it does not publish per-thread
|
||||
MCP roots or map sandbox paths such as <code>/mnt/user-data/...</code> to
|
||||
paths accepted by <code>@modelcontextprotocol/server-filesystem</code>.
|
||||
</Callout>
|
||||
|
||||
Each server entry supports:
|
||||
|
||||
- `command`: the executable to run (e.g., `npx`, `uvx`, `python`)
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
---
|
||||
title: Tool-Using Agents Must Handle Provider Safety Termination Signals Correctly
|
||||
description: Why tool calls left in a safety-terminated model response must not be executed, and how to configure provider detectors in DeerFlow.
|
||||
date: 2026-05-22
|
||||
tags:
|
||||
- Safety
|
||||
- Agents
|
||||
- Model Providers
|
||||
---
|
||||
|
||||
## Tool-Using Agents Must Handle Provider Safety Termination Signals Correctly
|
||||
|
||||
When a large model provider decides that an input or output has triggered a safety policy, the important outcome is not merely that the model says less. The application needs to know that the current generation turn has been terminated. In a normal chat interface, this may appear as a refusal, filtered text, or an error response. For an Agent that can call tools, the risk is higher: if the provider has already stopped generation while the response still contains `tool_calls`, those tool arguments may only be partially generated.
|
||||
|
||||
These partial tool calls must not be executed as normal intent. A truncated `write_file` call may write an incomplete report. A truncated `bash` call may enter the sandbox with incomplete arguments. After seeing the failed result, the Agent may retry and trigger the same safety rule repeatedly.
|
||||
|
||||
[PR #3035](https://github.com/bytedance/deer-flow/pull/3035) addresses this boundary: when a provider stops generation with a safety signal while the response still contains tool calls, DeerFlow should suppress those tool calls first and record the turn as a safety termination event.
|
||||
|
||||
## Why Safety Termination Needs Dedicated Handling
|
||||
|
||||
A safety termination is not a normal tool-call finish reason.
|
||||
|
||||
In a healthy tool turn, the provider explicitly tells the application that it should call tools. A safety termination says something different: the output has been blocked by provider policy, or streaming generation has been cut off early. Even if tool-call fragments remain in the response object, the application cannot assume that their JSON arguments, file contents, or command text are complete.
|
||||
|
||||
In a real Agent run, this creates two kinds of risk:
|
||||
|
||||
| Risk | Impact |
|
||||
| --- | --- |
|
||||
| Runtime risk | Executing truncated tool arguments can create corrupted files, malformed commands, repeated retries, or tool loops |
|
||||
| Provider risk | Repeatedly sending similar violating inputs or outputs to a provider increases safety review and abuse-control pressure |
|
||||
|
||||
The second risk matters. Providers enforce their policies differently, but their official materials already make clear that safety policy can affect more than a single completion. It can also affect end users, API access, or account status.
|
||||
|
||||
## What Providers Expose and How They Respond
|
||||
|
||||
Providers do not use one common field name, and they do not share one enforcement process. Deployments need to distinguish at least two layers:
|
||||
|
||||
1. Which signal in this response says that generation was stopped by a safety policy.
|
||||
2. Which follow-up actions the provider has publicly described when safety problems keep recurring.
|
||||
|
||||
| Provider | Runtime signal | Publicly documented response or recommendation |
|
||||
| --- | --- | --- |
|
||||
| GLM | Synchronous calls may return a safety audit error; streaming output may end with `finish_reason="sensitive"` | Pass `user_id` to distinguish end users; the platform may block violating end-user requests so enterprise accounts are not affected by end-user abuse |
|
||||
| OpenAI | Chat Completions may return `finish_reason="content_filter"` | Use Moderation and `safety_identifier`; repeated usage policy violations may lead to warnings, restrictions, or account deactivation |
|
||||
| Anthropic | Streaming refusals may be exposed through `stop_reason="refusal"` | Reset, rewrite, or narrow context after a refusal; the AUP describes request limiting, output modification, suspension, or termination |
|
||||
| Gemini | A safety-filtered candidate may return `finishReason=SAFETY`, and blocked content is not returned | Abuse monitoring covers prompts and outputs; follow-up actions can escalate from contacting the developer to temporary restrictions, suspension, or account closure |
|
||||
| DeepSeek | Chat completion `finish_reason` includes `content_filter` | The `user` field can help content safety review; potential usage guideline violations may trigger a temporary suspension protocol |
|
||||
|
||||
GLM is the most direct example. Its safety audit documentation describes the streaming safety finish signal, the recommendation to identify end users, and the possibility of blocking requests from violating end users. [GLM safety audit documentation](https://docs.bigmodel.cn/cn/guide/platform/securityaudit)
|
||||
|
||||
OpenAI defines `content_filter` as a Chat Completions finish reason. Its safety best practices recommend using `safety_identifier` for end users so policy violations can be attributed more precisely than a shared API key alone. OpenAI help documentation also says repeated usage policy violations may lead to account deactivation. [Safety best practices](https://developers.openai.com/api/docs/guides/safety-best-practices/) [Why Was My OpenAI Account Deactivated?](https://help.openai.com/en/articles/10562188)
|
||||
|
||||
Anthropic distinguishes ordinary stops from safety refusals in its Claude streaming refusal guidance: when the streaming classifier intervenes, the response can carry `stop_reason="refusal"`. It also recommends that applications do not keep feeding refused content back into later context, and instead reset the conversation, rewrite the prompt, or narrow the task. The Anthropic AUP says it may limit requests, block or modify outputs, and suspend or terminate access when necessary. [Handle streaming refusals](https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals) [Acceptable Use Policy](https://www.anthropic.com/legal/aup)
|
||||
|
||||
Gemini safety documentation emphasizes another shape of intervention. A prompt may be blocked before generation, and a candidate may be filtered after generation. When a response candidate is stopped by safety policy, the response can expose `finishReason=SAFETY` without returning the blocked content itself. Gemini API terms also say abuse monitoring covers prompts and outputs and list progressively stronger follow-up actions. [Gemini safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) [Gemini API Additional Terms of Service](https://ai.google.dev/gemini-api/terms)
|
||||
|
||||
DeepSeek lists `content_filter` as a chat completion finish reason and describes the request `user` field as helpful for content safety review. Its FAQ also says potential usage guideline violations may trigger a temporary suspension process. [Create Chat Completion](https://api-docs.deepseek.com/api/create-chat-completion)
|
||||
|
||||
Some providers intervene earlier or at a layer outside the model message. For example, Azure OpenAI tells applications to inspect `finish_reason` because `content_filter` may leave a completion incomplete. Amazon Bedrock Guardrails can return `stopReason="guardrail_intervened"` in a response. In Alibaba Cloud Model Studio guardrail examples, output-side blocking may also appear directly as a `DataInspectionFailed` error. Together, these examples show that a safety intervention may be a stop signal in a model message or an API-level error. Applications need more than one handling path. [Azure OpenAI content filtering](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter) [Amazon Bedrock Guardrails](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-converse-api.html)
|
||||
|
||||
## What DeerFlow Does at This Boundary
|
||||
|
||||
`SafetyFinishReasonMiddleware` has a narrow responsibility. It does not replace provider content review, and it does not rewrite every refusal into the same error. It only intervenes when both conditions below are true:
|
||||
|
||||
1. The provider response carries a configured safety termination signal.
|
||||
2. The current `AIMessage` still contains non-empty `tool_calls`.
|
||||
|
||||
When it intervenes, it:
|
||||
|
||||
1. Clears structured tool calls and residual tool-call fields in raw provider metadata.
|
||||
2. Prevents those tool arguments from reaching the tool node for execution.
|
||||
3. Preserves already generated partial text and appends a user-facing explanation.
|
||||
4. Records the detector, reason field, reason value, and suppressed tool names and counts.
|
||||
5. Avoids writing tool arguments that may themselves contain filtered content into audit events again.
|
||||
|
||||
This makes the safety termination signal take priority over the fact that tool calls are present in the response. For the Agent runtime, that is the more conservative and more correct control flow.
|
||||
|
||||
## Default Configuration
|
||||
|
||||
The default configuration only needs `safety_finish_reason` enabled:
|
||||
|
||||
```yaml
|
||||
safety_finish_reason:
|
||||
enabled: true
|
||||
```
|
||||
|
||||
When `detectors` is not configured explicitly, DeerFlow uses the built-in detector set:
|
||||
|
||||
| Detector | Default match |
|
||||
| --- | --- |
|
||||
| `OpenAICompatibleContentFilterDetector` | `finish_reason="content_filter"` |
|
||||
| `AnthropicRefusalDetector` | `stop_reason="refusal"` |
|
||||
| `GeminiSafetyDetector` | Gemini safety-related `finish_reason` values such as `SAFETY`, `BLOCKLIST`, `PROHIBITED_CONTENT`, `SPII`, and `RECITATION` |
|
||||
|
||||
This default set covers common DeerFlow paths for OpenAI-compatible providers, Anthropic, and Gemini. It does not treat a normal `finish_reason="tool_calls"` as a safety termination, and it does not fold length truncation such as `length` or `max_tokens` into the safety category.
|
||||
|
||||
## Example: Extend the Streaming Safety Finish Signal for GLM
|
||||
|
||||
GLM streaming responses use `sensitive` as the safety finish value. If the current adapter preserves that value in `AIMessage.response_metadata.finish_reason` or `additional_kwargs.finish_reason`, it can be handled through the configurable finish reason set on the OpenAI-compatible detector:
|
||||
|
||||
```yaml
|
||||
safety_finish_reason:
|
||||
enabled: true
|
||||
detectors:
|
||||
- use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector
|
||||
config:
|
||||
finish_reasons: ["content_filter", "sensitive"]
|
||||
|
||||
- use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector
|
||||
|
||||
- use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector
|
||||
```
|
||||
|
||||
Two configuration details matter here.
|
||||
|
||||
First, `detectors` replaces the default list. It does not append one item to it. The example therefore keeps the Anthropic and Gemini detectors while adding GLM's `sensitive` value.
|
||||
|
||||
Second, this middleware handles safety finish signals that have already reached a model message. If the provider returns a safety audit error at the API layer, such as a synchronous GLM safety audit error code, the caller still needs to handle it in the LLM or API error path.
|
||||
|
||||
## Boundary
|
||||
|
||||
`SafetyFinishReasonMiddleware` solves a specific Agent control-flow problem. It is not a complete content safety solution. It does not replace moderation, permission isolation, user governance, or provider-side review, and it does not cover every plain-text refusal.
|
||||
|
||||
This boundary is still worth protecting explicitly: when a provider has already stopped output for safety reasons, a tool-using Agent should treat that turn as interrupted output, not executable tool intent.
|
||||
@@ -193,15 +193,23 @@ BETTER_AUTH_SECRET=local-dev-secret-at-least-32-chars
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"filesystem": {
|
||||
"my-server": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"],
|
||||
"args": ["-y", "@my-org/my-mcp-server"],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<Callout type="warning">
|
||||
不要为 DeerFlow 工作区文件引入 MCP filesystem server。它会与 DeerFlow
|
||||
内置文件工具形成路径语义不同的重复能力,让 LLM 行为不稳定。DeerFlow
|
||||
当前没有为 filesystem server 适配 MCP Roots 模式,也不会把{" "}
|
||||
<code>/mnt/user-data/...</code> 这类沙箱路径映射成{" "}
|
||||
<code>@modelcontextprotocol/server-filesystem</code> 可接受的路径。
|
||||
</Callout>
|
||||
|
||||
### 技能启用状态
|
||||
|
||||
技能启用状态会反映在 `extensions_config.json` 中。你可以直接编辑它,或通过 DeerFlow 应用界面进行管理。
|
||||
|
||||
@@ -28,11 +28,6 @@ MCP 服务器在 `extensions_config.json` 中配置,这个文件独立于 `con
|
||||
"args": ["-y", "@my-org/my-mcp-server"],
|
||||
"enabled": true
|
||||
},
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"],
|
||||
"enabled": true
|
||||
},
|
||||
"sqlite": {
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-sqlite", "--db-path", "/path/to/db.sqlite"],
|
||||
@@ -42,6 +37,15 @@ MCP 服务器在 `extensions_config.json` 中配置,这个文件独立于 `con
|
||||
}
|
||||
```
|
||||
|
||||
<Callout type="warning">
|
||||
不要为 DeerFlow 工作区文件引入 MCP filesystem server。DeerFlow 已提供按
|
||||
thread 隔离的内置文件工具;重复引入路径语义不同的文件工具,会让 LLM
|
||||
的工具选择和文件访问行为不稳定。DeerFlow 当前没有为 filesystem server
|
||||
适配 MCP Roots 模式:不会发布按 thread 收窄的 MCP roots,也不会把{" "}
|
||||
<code>/mnt/user-data/...</code> 这类沙箱路径映射成{" "}
|
||||
<code>@modelcontextprotocol/server-filesystem</code> 可接受的路径。
|
||||
</Callout>
|
||||
|
||||
每个服务器条目支持:
|
||||
|
||||
- `command`:要运行的可执行文件(如 `npx`、`uvx`、`python`)
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
---
|
||||
title: 工具型 Agent 需要正确处理模型提供商的安全中止信号
|
||||
description: 当模型输出因安全策略被中止时,为什么不能继续执行残留的工具调用,以及如何在 DeerFlow 中配置 provider detector。
|
||||
date: 2026-05-22
|
||||
tags:
|
||||
- Safety
|
||||
- Agents
|
||||
- Model Providers
|
||||
---
|
||||
|
||||
## 工具型 Agent 需要正确处理模型提供商的安全中止信号
|
||||
|
||||
当大模型提供商认为输入或输出触发了安全策略时,最理想的结果不是“模型少说了几句话”,而是应用已经明确知道这一轮生成被中止了。对于普通聊天界面,这通常表现为拒答、过滤后的文本,或者一个错误响应。对于能调用工具的 Agent,风险会更高:如果 provider 已经中止输出,但响应里仍残留了 `tool_calls`,这些工具参数很可能只生成了一半。
|
||||
|
||||
这类半截工具调用不应被当成正常意图执行。一个被截断的 `write_file` 可能写出不完整的报告;一个被截断的 `bash` 调用可能带着残缺参数进入沙箱;Agent 看到失败结果后还可能继续重试,反复触发同一条安全规则。
|
||||
|
||||
[PR #3035](https://github.com/bytedance/deer-flow/pull/3035) 处理的就是这个边界:当 provider 用安全信号中止生成,同时响应仍带有工具调用时,DeerFlow 应先压制这些工具调用,再把这一轮作为安全中止事件记录下来。
|
||||
|
||||
## 为什么需要单独处理安全中止
|
||||
|
||||
安全中止不是普通的工具调用结束原因。
|
||||
|
||||
一次健康的工具轮次通常由 provider 明确告诉应用“现在应该调用工具”。但安全中止表达的是另一件事:输出已经被 provider 的策略拦住,或者流式生成已经被提前切断。此时即使响应对象里还能看到工具调用片段,也不能假设它的 JSON 参数、文件内容或命令文本已经完整。
|
||||
|
||||
在真实 Agent 运行中,这会同时产生两类风险:
|
||||
|
||||
| 风险 | 影响 |
|
||||
| --- | --- |
|
||||
| 运行时风险 | 执行被截断的工具参数,产生损坏文件、异常命令、重复重试或工具循环 |
|
||||
| provider 风险 | 应用反复把同类违规输入或输出送到 provider,累积安全审核和风控压力 |
|
||||
|
||||
第二类风险不能被忽略。不同 provider 的处置力度不同,但官方材料已经表明,安全策略不仅影响单次 completion,也可能影响终端用户、API 访问能力或账号状态。
|
||||
|
||||
## 各家 provider 公开了什么信号和处置方式
|
||||
|
||||
provider 并没有统一的字段名,也没有统一的处罚流程。部署方至少要区分两层信息:
|
||||
|
||||
1. 这一轮响应里,什么信号说明生成被安全策略中止。
|
||||
2. 如果安全问题反复出现,provider 公开说明过哪些后续动作。
|
||||
|
||||
| Provider | 运行时信号 | 公开的后续处置或建议 |
|
||||
| --- | --- | --- |
|
||||
| GLM | 同步调用可能返回安全审核错误;流式输出可能以 `finish_reason="sensitive"` 结束 | 建议传入 `user_id` 区分终端用户;平台可封禁违规终端用户请求,避免企业账号受终端用户滥用影响 |
|
||||
| OpenAI | Chat Completions 的 `finish_reason` 可为 `content_filter` | 建议使用 Moderation 和 `safety_identifier`;重复违反使用政策可能带来警告、限制或账号停用 |
|
||||
| Anthropic | 流式拒绝可通过 `stop_reason="refusal"` 暴露 | 收到拒绝后应重置、改写或缩小上下文;AUP 说明可限制请求、修改输出、暂停或终止访问 |
|
||||
| Gemini | 被安全过滤的 candidate 可返回 `finishReason=SAFETY`,且被拦截内容不会返回 | abuse monitoring 会检查 prompts 和 outputs;后续动作可从联系开发者升级到临时限制、暂停或账号关闭 |
|
||||
| DeepSeek | Chat completion 的 `finish_reason` 枚举包含 `content_filter` | `user` 字段可帮助内容安全审核;潜在使用规范违规可能触发临时 suspension protocol |
|
||||
|
||||
GLM 的说明最直接。它的安全审核文档同时给出了流式安全结束信号、终端用户标识建议,以及对违规终端用户请求做封禁处理的说明。[GLM 安全审核文档](https://docs.bigmodel.cn/cn/guide/platform/securityaudit)
|
||||
|
||||
OpenAI 把 `content_filter` 定义为 Chat Completions 的一种 finish reason,并在安全最佳实践中推荐对终端用户使用 `safety_identifier`,以便在违反策略时定位到具体用户而不是只看到一个共享的 API key。OpenAI 的帮助文档还说明,重复违反使用政策可能导致账号被停用。 [Safety best practices](https://developers.openai.com/api/docs/guides/safety-best-practices/) [Why Was My OpenAI Account Deactivated?](https://help.openai.com/en/articles/10562188)
|
||||
|
||||
Anthropic 在 Claude 流式拒绝说明中明确区分了普通停止和安全拒绝:当 streaming classifier 介入时,响应可以带有 `stop_reason="refusal"`。它同时建议应用不要把被拒绝内容继续塞回下一轮上下文,而应重置对话、改写提示或缩小任务范围。Anthropic AUP 也说明,它可以限制请求、拦截或修改输出,并在必要时暂停或终止访问。[Handle streaming refusals](https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals) [Acceptable Use Policy](https://www.anthropic.com/legal/aup)
|
||||
|
||||
Gemini 的安全文档则强调另一种形态:prompt 可能在生成前被拦截,candidate 也可能在生成后被过滤;当 response candidate 被安全策略拦下时,可以看到 `finishReason=SAFETY`,但不会拿到被拦截内容本身。Gemini API 的使用政策还说明,abuse monitoring 会覆盖 prompts 和 outputs,并列出了逐步升级的处置动作。[Gemini safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) [Gemini API Additional Terms of Service](https://ai.google.dev/gemini-api/terms)
|
||||
|
||||
DeepSeek 的 API 文档把 `content_filter` 列为 chat completion finish reason,并把请求里的 `user` 字段说明为有助于内容安全审核。它的 FAQ 也说明,潜在违反使用规范的场景可能触发临时暂停流程。[Create Chat Completion](https://api-docs.deepseek.com/api/create-chat-completion) [DeepSeek FAQ](https://api-docs.deepseek.com/faq)
|
||||
|
||||
还有一些 provider 会在更早或更外层的位置拦截请求。例如 Azure OpenAI 提醒应用检查 `finish_reason`,因为 `content_filter` 可能让 completion 不完整;Amazon Bedrock Guardrails 可在响应中返回 `stopReason="guardrail_intervened"`;阿里云百炼的安全护栏示例里,输出侧拦截也可能直接表现为 `DataInspectionFailed` 错误。它们共同说明了一点:安全拦截既可能是模型消息里的停止信号,也可能是 API 层错误,应用不能只准备一种处理路径。[Azure OpenAI content filtering](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter) [Amazon Bedrock Guardrails](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-converse-api.html)
|
||||
|
||||
## DeerFlow 在这条边界上做什么
|
||||
|
||||
`SafetyFinishReasonMiddleware` 的职责很窄:它不替代 provider 的内容审核,也不把所有拒答都改写成同一种错误。它只在下面两个条件同时成立时介入:
|
||||
|
||||
1. provider 响应携带了已配置的安全中止信号。
|
||||
2. 当前 `AIMessage` 仍包含非空的 `tool_calls`。
|
||||
|
||||
介入后,它会:
|
||||
|
||||
1. 清空结构化工具调用以及 raw provider metadata 中残留的工具调用字段。
|
||||
2. 阻止这些工具参数进入工具节点执行。
|
||||
3. 保留已经生成的部分文本,并追加面向用户的说明。
|
||||
4. 记录 detector、reason 字段、reason 值、被压制的工具名和数量。
|
||||
5. 避免把可能正是被过滤内容的工具参数再次写入审计事件。
|
||||
|
||||
这意味着安全中止信号的优先级高于“响应里看到了工具调用”。对于 Agent 运行时,这是更保守也更正确的控制流。
|
||||
|
||||
## 默认配置
|
||||
|
||||
默认情况下只需要启用 `safety_finish_reason`:
|
||||
|
||||
```yaml
|
||||
safety_finish_reason:
|
||||
enabled: true
|
||||
```
|
||||
|
||||
不显式配置 `detectors` 时,DeerFlow 使用内置 detector 集合:
|
||||
|
||||
| Detector | 默认匹配 |
|
||||
| --- | --- |
|
||||
| `OpenAICompatibleContentFilterDetector` | `finish_reason="content_filter"` |
|
||||
| `AnthropicRefusalDetector` | `stop_reason="refusal"` |
|
||||
| `GeminiSafetyDetector` | Gemini 安全相关 `finish_reason`,例如 `SAFETY`、`BLOCKLIST`、`PROHIBITED_CONTENT`、`SPII`、`RECITATION` |
|
||||
|
||||
这个默认集合覆盖了 DeerFlow 常见的 OpenAI-compatible provider、Anthropic 和 Gemini 路径。它不会把普通 `finish_reason="tool_calls"` 当成安全中止,也不会把 `length`、`max_tokens` 之类的长度截断混入安全分类。
|
||||
|
||||
## 例子:为 GLM 扩展流式安全结束信号
|
||||
|
||||
GLM 流式响应使用的安全结束值是 `sensitive`。如果当前适配层把这个值保留在 `AIMessage.response_metadata.finish_reason` 或 `additional_kwargs.finish_reason` 中,可以通过 OpenAI-compatible detector 的可配置 finish reason 集合接入:
|
||||
|
||||
```yaml
|
||||
safety_finish_reason:
|
||||
enabled: true
|
||||
detectors:
|
||||
- use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector
|
||||
config:
|
||||
finish_reasons: ["content_filter", "sensitive"]
|
||||
|
||||
- use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector
|
||||
|
||||
- use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector
|
||||
```
|
||||
|
||||
这里有两个配置细节需要注意。
|
||||
|
||||
第一,`detectors` 是覆盖默认列表,不是向默认列表追加一项。因此为了给 GLM 增加 `sensitive`,示例里也保留了 Anthropic 和 Gemini detector。
|
||||
|
||||
第二,这个 middleware 处理的是已经进入模型消息的安全结束信号。如果 provider 在 API 层直接返回安全审核错误,例如 GLM 同步调用的安全审核错误码,调用方还需要在 LLM/API 错误处理路径里单独处理它。
|
||||
|
||||
|
||||
## 边界
|
||||
|
||||
`SafetyFinishReasonMiddleware` 解决的是一个明确的 Agent 控制流问题,不是完整的内容安全方案。它不替代 moderation、权限隔离、用户治理或 provider 自身的审核策略,也不覆盖每一种普通文本拒答。
|
||||
|
||||
但这一条边界值得单独守住:当 provider 已经因为安全原因停下输出时,工具型 Agent 应把这一轮视为被中断的输出,而不是可执行的工具意图。
|
||||
@@ -397,6 +397,50 @@ export function stripUploadedFilesTag(content: string): string {
|
||||
.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Tag names that backend middlewares wrap around internal payloads before
|
||||
* letting them ride along inside LangGraph message ``content``.
|
||||
*
|
||||
* These markers are *not* user copy — they come from:
|
||||
*
|
||||
* - ``UploadsMiddleware`` → ``<uploaded_files>``
|
||||
* - ``DynamicContextMiddleware`` → ``<system-reminder>`` (carrying
|
||||
* ``<memory>`` / ``<current_date>`` inside)
|
||||
* - ``TodoListMiddleware`` / ``LoopDetectionMiddleware`` style reminders
|
||||
* live in ``hide_from_ui`` HumanMessages, but their inner payload uses
|
||||
* the same tag vocabulary.
|
||||
*
|
||||
* The primary export filter is {@link isHiddenFromUIMessage}. This list is
|
||||
* the defence-in-depth strip for any message that — by middleware bug,
|
||||
* provider quirk, or merge-conflict regression — slips through without
|
||||
* its ``hide_from_ui`` flag set.
|
||||
*/
|
||||
export const INTERNAL_MARKER_TAGS = [
|
||||
"uploaded_files",
|
||||
"system-reminder",
|
||||
"memory",
|
||||
"current_date",
|
||||
] as const;
|
||||
|
||||
const INTERNAL_MARKER_RE = new RegExp(
|
||||
`<(${INTERNAL_MARKER_TAGS.join("|")})>[\\s\\S]*?</\\1>`,
|
||||
"g",
|
||||
);
|
||||
|
||||
/**
|
||||
* Strip every known backend-injected marker from message content.
|
||||
*
|
||||
* Intended for the chat export path where a marker leaking through is a
|
||||
* privacy regression. UI render paths should keep using
|
||||
* {@link stripUploadedFilesTag} — they receive ``hide_from_ui`` messages
|
||||
* via a separate filter and the narrower function avoids stripping content
|
||||
* a user might legitimately type into a meta-discussion (e.g. asking the
|
||||
* model about its own ``<memory>`` system).
|
||||
*/
|
||||
export function stripInternalMarkers(content: string): string {
|
||||
return content.replace(INTERNAL_MARKER_RE, "").trim();
|
||||
}
|
||||
|
||||
export function parseUploadedFiles(content: string): FileInMessage[] {
|
||||
// Match <uploaded_files>...</uploaded_files> tag
|
||||
const uploadedFilesRegex = /<uploaded_files>([\s\S]*?)<\/uploaded_files>/;
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import type { Subtask } from "./types";
|
||||
|
||||
export type SubtaskStatus = Subtask["status"];
|
||||
|
||||
export interface SubtaskResultUpdate {
|
||||
status: SubtaskStatus;
|
||||
result?: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Prefix strings the backend `task` tool writes into its result `content`.
|
||||
*
|
||||
* These values are not user-facing copy — they are part of the
|
||||
* backend↔frontend contract defined in
|
||||
* `backend/packages/harness/deerflow/tools/builtins/task_tool.py` (returned
|
||||
* from the tool body) and in
|
||||
* `backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py`
|
||||
* (wrapper for tool exceptions). Any change here must be paired with the
|
||||
* matching backend change. Exported so a future structured-status migration
|
||||
* can reference the same values from one place.
|
||||
*
|
||||
* `task_tool.py` also emits three `Error:` strings for pre-execution failures
|
||||
* — unknown subagent type, host-bash disabled, and "task disappeared from
|
||||
* background tasks". They are handled by {@link ERROR_WRAPPER_PATTERN}
|
||||
* rather than dedicated prefixes because the wrapper already produces
|
||||
* exactly the right `terminal failed` shape.
|
||||
*/
|
||||
export const SUCCESS_PREFIX = "Task Succeeded. Result:";
|
||||
export const FAILURE_PREFIX = "Task failed.";
|
||||
export const TIMEOUT_PREFIX = "Task timed out";
|
||||
export const CANCELLED_PREFIX = "Task cancelled by user.";
|
||||
export const POLLING_TIMEOUT_PREFIX = "Task polling timed out";
|
||||
export const ERROR_WRAPPER_PATTERN = /^Error\b/i;
|
||||
|
||||
/**
|
||||
* Map a `task` tool result string to a {@link SubtaskStatus}.
|
||||
*
|
||||
* Bytedance/deer-flow issue #3107 BUG-007: parent-visible task tool errors do
|
||||
* not always start with one of the three legacy prefixes (e.g. when
|
||||
* `ToolErrorHandlingMiddleware` wraps an exception as
|
||||
* `Error: Tool 'task' failed ...`). Treat any leading `Error:` token as a
|
||||
* terminal failure so subtask cards stop being stuck on "in_progress".
|
||||
*
|
||||
* Returning `in_progress` is the **deliberate** fallback for content that
|
||||
* matches none of the known prefixes. LangChain only ever emits a
|
||||
* `ToolMessage` once the tool itself has returned (success or wrapped
|
||||
* exception), so an unknown shape means "the contract changed underneath us"
|
||||
* — surfacing it as still-running prompts the operator to investigate, where
|
||||
* eagerly marking it terminal-failed would mask the drift.
|
||||
*/
|
||||
export function parseSubtaskResult(text: string): SubtaskResultUpdate {
|
||||
const trimmed = text.trim();
|
||||
|
||||
if (trimmed.startsWith(SUCCESS_PREFIX)) {
|
||||
return {
|
||||
status: "completed",
|
||||
result: trimmed.slice(SUCCESS_PREFIX.length).trim(),
|
||||
};
|
||||
}
|
||||
|
||||
if (trimmed.startsWith(FAILURE_PREFIX)) {
|
||||
return {
|
||||
status: "failed",
|
||||
error: trimmed.slice(FAILURE_PREFIX.length).trim(),
|
||||
};
|
||||
}
|
||||
|
||||
if (trimmed.startsWith(TIMEOUT_PREFIX)) {
|
||||
return { status: "failed", error: trimmed };
|
||||
}
|
||||
|
||||
if (trimmed.startsWith(CANCELLED_PREFIX)) {
|
||||
return { status: "failed", error: trimmed };
|
||||
}
|
||||
|
||||
if (trimmed.startsWith(POLLING_TIMEOUT_PREFIX)) {
|
||||
return { status: "failed", error: trimmed };
|
||||
}
|
||||
|
||||
// ToolErrorHandlingMiddleware-style wrapper, or any other terminal error
|
||||
// signal the backend forwards to the lead agent.
|
||||
if (ERROR_WRAPPER_PATTERN.test(trimmed)) {
|
||||
return { status: "failed", error: trimmed };
|
||||
}
|
||||
|
||||
return { status: "in_progress" };
|
||||
}
|
||||
@@ -5,16 +5,53 @@ import {
|
||||
extractReasoningContentFromMessage,
|
||||
hasContent,
|
||||
hasToolCalls,
|
||||
stripUploadedFilesTag,
|
||||
isHiddenFromUIMessage,
|
||||
stripInternalMarkers,
|
||||
} from "../messages/utils";
|
||||
|
||||
import type { AgentThread } from "./types";
|
||||
import { titleOfThread } from "./utils";
|
||||
|
||||
/**
|
||||
* Optional debug switches for advanced exports.
|
||||
*
|
||||
* Bytedance/deer-flow issue #3107 BUG-006 explicitly prescribes that the
|
||||
* default export includes only the user-visible transcript and excludes
|
||||
* thinking/reasoning content, tool calls, tool results, hidden messages,
|
||||
* memory injection, and `<system-reminder>` payloads. These options let a
|
||||
* future "debug export" surface re-include any of those categories without
|
||||
* forking the formatter. They are not currently wired to any UI control —
|
||||
* callers that want them must construct the options object explicitly.
|
||||
*/
|
||||
export interface ExportOptions {
|
||||
includeReasoning?: boolean;
|
||||
includeToolCalls?: boolean;
|
||||
includeToolMessages?: boolean;
|
||||
includeHidden?: boolean;
|
||||
}
|
||||
|
||||
function visibleMessages(
|
||||
messages: Message[],
|
||||
options: ExportOptions,
|
||||
): Message[] {
|
||||
return messages.filter((message) => {
|
||||
if (!options.includeHidden && isHiddenFromUIMessage(message)) {
|
||||
return false;
|
||||
}
|
||||
if (!options.includeToolMessages && message.type === "tool") {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
function formatMessageContent(message: Message): string {
|
||||
const text = extractContentFromMessage(message);
|
||||
if (!text) return "";
|
||||
return stripUploadedFilesTag(text);
|
||||
// Defence-in-depth: even if a middleware-injected marker slipped through
|
||||
// the `hide_from_ui` filter, scrub every known internal tag before the
|
||||
// content lands in a user-visible export file.
|
||||
return stripInternalMarkers(text);
|
||||
}
|
||||
|
||||
function formatToolCalls(message: Message): string {
|
||||
@@ -26,6 +63,7 @@ function formatToolCalls(message: Message): string {
|
||||
export function formatThreadAsMarkdown(
|
||||
thread: AgentThread,
|
||||
messages: Message[],
|
||||
options: ExportOptions = {},
|
||||
): string {
|
||||
const title = titleOfThread(thread);
|
||||
const createdAt = thread.created_at
|
||||
@@ -41,16 +79,20 @@ export function formatThreadAsMarkdown(
|
||||
"",
|
||||
];
|
||||
|
||||
for (const message of messages) {
|
||||
for (const message of visibleMessages(messages, options)) {
|
||||
if (message.type === "human") {
|
||||
const content = formatMessageContent(message);
|
||||
if (content) {
|
||||
lines.push(`## 🧑 User`, "", content, "", "---", "");
|
||||
}
|
||||
} else if (message.type === "ai") {
|
||||
const reasoning = extractReasoningContentFromMessage(message);
|
||||
const reasoning = options.includeReasoning
|
||||
? extractReasoningContentFromMessage(message)
|
||||
: undefined;
|
||||
const content = formatMessageContent(message);
|
||||
const toolCalls = formatToolCalls(message);
|
||||
const toolCalls = options.includeToolCalls
|
||||
? formatToolCalls(message)
|
||||
: "";
|
||||
|
||||
if (!content && !toolCalls && !reasoning) continue;
|
||||
|
||||
@@ -83,23 +125,65 @@ export function formatThreadAsMarkdown(
|
||||
return lines.join("\n").trimEnd() + "\n";
|
||||
}
|
||||
|
||||
interface JSONExportMessage {
|
||||
type: Message["type"];
|
||||
id: string | undefined;
|
||||
content: string;
|
||||
reasoning?: string;
|
||||
tool_calls?: unknown;
|
||||
}
|
||||
|
||||
function buildJSONMessage(
|
||||
msg: Message,
|
||||
options: ExportOptions,
|
||||
): JSONExportMessage | null {
|
||||
// Run the same sanitiser the Markdown path uses so the JSON `content`
|
||||
// field never carries inline `<think>...</think>` wrappers, content-array
|
||||
// thinking blocks, `<uploaded_files>` markers, or other internal payloads.
|
||||
const content = formatMessageContent(msg);
|
||||
const reasoning =
|
||||
options.includeReasoning && msg.type === "ai"
|
||||
? (extractReasoningContentFromMessage(msg) ?? undefined)
|
||||
: undefined;
|
||||
const toolCalls =
|
||||
options.includeToolCalls &&
|
||||
msg.type === "ai" &&
|
||||
"tool_calls" in msg &&
|
||||
msg.tool_calls?.length
|
||||
? msg.tool_calls
|
||||
: undefined;
|
||||
|
||||
// Drop rows with no exportable payload (empty content + no opted-in
|
||||
// reasoning / tool_calls). Uses falsy semantics so `reasoning: ""` (the
|
||||
// empty string ``extractReasoningContentFromMessage`` can hand back) is
|
||||
// treated the same way Markdown's `!reasoning` continue does — otherwise
|
||||
// an opted-in but empty reasoning field would leak as `{reasoning: ""}`.
|
||||
if (!content && !reasoning && !toolCalls) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
type: msg.type,
|
||||
id: msg.id,
|
||||
content,
|
||||
...(reasoning !== undefined ? { reasoning } : {}),
|
||||
...(toolCalls !== undefined ? { tool_calls: toolCalls } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
export function formatThreadAsJSON(
|
||||
thread: AgentThread,
|
||||
messages: Message[],
|
||||
options: ExportOptions = {},
|
||||
): string {
|
||||
const exportData = {
|
||||
title: titleOfThread(thread),
|
||||
thread_id: thread.thread_id,
|
||||
created_at: thread.created_at,
|
||||
exported_at: new Date().toISOString(),
|
||||
messages: messages.map((msg) => ({
|
||||
type: msg.type,
|
||||
id: msg.id,
|
||||
content: typeof msg.content === "string" ? msg.content : msg.content,
|
||||
...(msg.type === "ai" && msg.tool_calls?.length
|
||||
? { tool_calls: msg.tool_calls }
|
||||
: {}),
|
||||
})),
|
||||
messages: visibleMessages(messages, options)
|
||||
.map((msg) => buildJSONMessage(msg, options))
|
||||
.filter((m): m is JSONExportMessage => m !== null),
|
||||
};
|
||||
return JSON.stringify(exportData, null, 2);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import { parseSubtaskResult } from "@/core/tasks/subtask-result";
|
||||
|
||||
describe("parseSubtaskResult", () => {
|
||||
it("recognises the standard success prefix", () => {
|
||||
const parsed = parseSubtaskResult(
|
||||
"Task Succeeded. Result: investigated and produced a 3-page report",
|
||||
);
|
||||
expect(parsed.status).toBe("completed");
|
||||
expect(parsed.result).toBe("investigated and produced a 3-page report");
|
||||
});
|
||||
|
||||
it("recognises the standard failure prefix", () => {
|
||||
const parsed = parseSubtaskResult(
|
||||
"Task failed. underlying tool raised RuntimeError",
|
||||
);
|
||||
expect(parsed.status).toBe("failed");
|
||||
expect(parsed.error).toBe("underlying tool raised RuntimeError");
|
||||
});
|
||||
|
||||
it("recognises the standard timeout prefix", () => {
|
||||
const parsed = parseSubtaskResult("Task timed out after 900s");
|
||||
expect(parsed.status).toBe("failed");
|
||||
expect(parsed.error).toBe("Task timed out after 900s");
|
||||
});
|
||||
|
||||
it("recognises the cancelled-by-user prefix", () => {
|
||||
// bytedance/deer-flow#3131 review: this is one of the five terminal
|
||||
// strings task_tool.py actually emits — the previous cut treated it as
|
||||
// unrecognised content and pushed the card back to in_progress.
|
||||
const parsed = parseSubtaskResult("Task cancelled by user.");
|
||||
expect(parsed.status).toBe("failed");
|
||||
expect(parsed.error).toBe("Task cancelled by user.");
|
||||
});
|
||||
|
||||
it("recognises the polling-timed-out prefix", () => {
|
||||
// Emitted by task_tool when the background polling loop runs out of
|
||||
// budget waiting for the subagent to reach a terminal state.
|
||||
const parsed = parseSubtaskResult(
|
||||
"Task polling timed out after 15 minutes. This may indicate the background task is stuck. Status: RUNNING",
|
||||
);
|
||||
expect(parsed.status).toBe("failed");
|
||||
expect(parsed.error).toContain("polling timed out");
|
||||
});
|
||||
|
||||
it("recognises polling-timed-out with different durations", () => {
|
||||
// `task_tool` emits `Task polling timed out after {N} minutes` where N
|
||||
// varies with the configured subagent timeout. Guard against the regex
|
||||
// accidentally being pinned to a specific number.
|
||||
for (const n of [1, 5, 60]) {
|
||||
const parsed = parseSubtaskResult(
|
||||
`Task polling timed out after ${n} minutes. Status: RUNNING`,
|
||||
);
|
||||
expect(parsed.status).toBe("failed");
|
||||
}
|
||||
});
|
||||
|
||||
it("trims whitespace around cancelled and polling-timed-out prefixes", () => {
|
||||
// Streaming chunks sometimes arrive with leading/trailing newlines.
|
||||
expect(parseSubtaskResult(" Task cancelled by user. \n").status).toBe(
|
||||
"failed",
|
||||
);
|
||||
expect(
|
||||
parseSubtaskResult("\n\nTask polling timed out after 3 minutes").status,
|
||||
).toBe("failed");
|
||||
});
|
||||
|
||||
it("recognises task_tool pre-execution Error: returns via the wrapper", () => {
|
||||
// `task_tool.py` returns three `Error:` strings for unknown subagent
|
||||
// type, host-bash disabled, and "task disappeared". They share the
|
||||
// ERROR_WRAPPER_PATTERN, not a dedicated prefix, so this guards
|
||||
// against a refactor splitting them off.
|
||||
for (const text of [
|
||||
"Error: Unknown subagent type 'foo'. Available: bash, general-purpose",
|
||||
"Error: Host bash subagent is disabled by configuration",
|
||||
"Error: Task 1234 disappeared from background tasks",
|
||||
]) {
|
||||
expect(parseSubtaskResult(text).status).toBe("failed");
|
||||
}
|
||||
});
|
||||
|
||||
it("treats middleware-wrapped tool errors as terminal failures", () => {
|
||||
// bytedance/deer-flow issue #3107 BUG-007: the parent-visible ToolMessage
|
||||
// produced by ToolErrorHandlingMiddleware never matches the three legacy
|
||||
// prefixes, so subtask cards stay stuck on "in_progress".
|
||||
const parsed = parseSubtaskResult(
|
||||
"Error: Tool 'task' failed with TypeError: 'AsyncCallbackManager' object is not iterable. Continue with available context, or choose an alternative tool.",
|
||||
);
|
||||
expect(parsed.status).toBe("failed");
|
||||
expect(parsed.error).toContain("AsyncCallbackManager");
|
||||
});
|
||||
|
||||
it("treats any other Error: prefix as a terminal failure", () => {
|
||||
const parsed = parseSubtaskResult("Error: subagent worker pool exhausted");
|
||||
expect(parsed.status).toBe("failed");
|
||||
});
|
||||
|
||||
it("keeps unrecognised non-error output as in_progress", () => {
|
||||
// Streaming partial chunks should not flip the card to terminal early.
|
||||
const parsed = parseSubtaskResult("Investigating ...");
|
||||
expect(parsed.status).toBe("in_progress");
|
||||
expect(parsed.error).toBeUndefined();
|
||||
expect(parsed.result).toBeUndefined();
|
||||
});
|
||||
|
||||
it("trims surrounding whitespace before matching prefixes", () => {
|
||||
const parsed = parseSubtaskResult(" Task Succeeded. Result: ok ");
|
||||
expect(parsed.status).toBe("completed");
|
||||
expect(parsed.result).toBe("ok");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,317 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import {
|
||||
formatThreadAsJSON,
|
||||
formatThreadAsMarkdown,
|
||||
} from "@/core/threads/export";
|
||||
import type { AgentThread } from "@/core/threads/types";
|
||||
|
||||
// Bytedance/deer-flow issue #3107 BUG-006: the chat export path bypasses the
|
||||
// UI-level hidden-message filter and emits reasoning content, tool calls, and
|
||||
// any other "internal" payload as if it were part of the user transcript.
|
||||
|
||||
function makeThread(): AgentThread {
|
||||
return {
|
||||
thread_id: "thread-1",
|
||||
created_at: "2026-05-21T00:00:00Z",
|
||||
updated_at: "2026-05-21T00:00:00Z",
|
||||
metadata: { title: "Demo thread" },
|
||||
status: "idle",
|
||||
values: { messages: [] },
|
||||
} as unknown as AgentThread;
|
||||
}
|
||||
|
||||
function human(content: string, extra: Partial<Message> = {}): Message {
|
||||
return {
|
||||
id: `h-${content}`,
|
||||
type: "human",
|
||||
content,
|
||||
...extra,
|
||||
} as Message;
|
||||
}
|
||||
|
||||
function ai(
|
||||
content: string,
|
||||
extra: Partial<Message> & { tool_calls?: unknown } = {},
|
||||
): Message {
|
||||
return {
|
||||
id: `a-${content}`,
|
||||
type: "ai",
|
||||
content,
|
||||
...extra,
|
||||
} as Message;
|
||||
}
|
||||
|
||||
function toolMsg(content: string): Message {
|
||||
return {
|
||||
id: `t-${content}`,
|
||||
type: "tool",
|
||||
content,
|
||||
name: "task",
|
||||
tool_call_id: "call-1",
|
||||
} as unknown as Message;
|
||||
}
|
||||
|
||||
describe("formatThreadAsMarkdown", () => {
|
||||
it("includes plain user and assistant text", () => {
|
||||
const md = formatThreadAsMarkdown(makeThread(), [
|
||||
human("hello"),
|
||||
ai("hi there"),
|
||||
]);
|
||||
expect(md).toContain("hello");
|
||||
expect(md).toContain("hi there");
|
||||
});
|
||||
|
||||
it("drops messages marked hide_from_ui", () => {
|
||||
const hidden = human("internal system reminder", {
|
||||
additional_kwargs: { hide_from_ui: true },
|
||||
} as Partial<Message>);
|
||||
const md = formatThreadAsMarkdown(makeThread(), [
|
||||
hidden,
|
||||
ai("public answer"),
|
||||
]);
|
||||
expect(md).not.toContain("internal system reminder");
|
||||
expect(md).toContain("public answer");
|
||||
});
|
||||
|
||||
it("does not emit reasoning_content by default", () => {
|
||||
const message = ai("final answer", {
|
||||
additional_kwargs: {
|
||||
reasoning_content: "secret chain of thought",
|
||||
},
|
||||
} as Partial<Message>);
|
||||
const md = formatThreadAsMarkdown(makeThread(), [message]);
|
||||
expect(md).not.toContain("secret chain of thought");
|
||||
expect(md).not.toContain("Thinking");
|
||||
});
|
||||
|
||||
it("does not emit tool calls by default", () => {
|
||||
const message = ai("ok", {
|
||||
tool_calls: [{ id: "1", name: "task", args: { description: "do work" } }],
|
||||
} as Partial<Message>);
|
||||
const md = formatThreadAsMarkdown(makeThread(), [message]);
|
||||
expect(md).not.toContain("**Tool:**");
|
||||
expect(md).not.toContain("`task`");
|
||||
});
|
||||
|
||||
it("drops tool result messages", () => {
|
||||
const md = formatThreadAsMarkdown(makeThread(), [
|
||||
ai("delegating"),
|
||||
toolMsg("Task Succeeded. Result: confidential"),
|
||||
]);
|
||||
expect(md).not.toContain("confidential");
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatThreadAsMarkdown opt-in flags", () => {
|
||||
it("emits reasoning when includeReasoning is true", () => {
|
||||
const message = ai("final answer", {
|
||||
additional_kwargs: {
|
||||
reasoning_content: "step-by-step chain of thought",
|
||||
},
|
||||
} as Partial<Message>);
|
||||
const md = formatThreadAsMarkdown(makeThread(), [message], {
|
||||
includeReasoning: true,
|
||||
});
|
||||
expect(md).toContain("step-by-step chain of thought");
|
||||
expect(md).toContain("Thinking");
|
||||
});
|
||||
|
||||
it("emits tool call rows when includeToolCalls is true", () => {
|
||||
const message = ai("ok", {
|
||||
tool_calls: [{ id: "1", name: "task", args: { description: "do work" } }],
|
||||
} as Partial<Message>);
|
||||
const md = formatThreadAsMarkdown(makeThread(), [message], {
|
||||
includeToolCalls: true,
|
||||
});
|
||||
expect(md).toContain("**Tool:**");
|
||||
expect(md).toContain("`task`");
|
||||
});
|
||||
|
||||
it("keeps hidden messages when includeHidden is true", () => {
|
||||
const hidden = human("internal reminder", {
|
||||
additional_kwargs: { hide_from_ui: true },
|
||||
} as Partial<Message>);
|
||||
const md = formatThreadAsMarkdown(makeThread(), [hidden], {
|
||||
includeHidden: true,
|
||||
});
|
||||
expect(md).toContain("internal reminder");
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatThreadAsJSON opt-in flags", () => {
|
||||
it("emits tool_calls field when includeToolCalls is true", () => {
|
||||
const message = ai("ok", {
|
||||
tool_calls: [{ id: "1", name: "task", args: { description: "x" } }],
|
||||
} as Partial<Message>);
|
||||
const raw = formatThreadAsJSON(makeThread(), [message], {
|
||||
includeToolCalls: true,
|
||||
});
|
||||
expect(raw).toContain("tool_calls");
|
||||
expect(raw).toContain('"task"');
|
||||
});
|
||||
|
||||
it("keeps tool messages when includeToolMessages is true", () => {
|
||||
const raw = formatThreadAsJSON(
|
||||
makeThread(),
|
||||
[toolMsg("Task Succeeded. Result: keep me")],
|
||||
{ includeToolMessages: true },
|
||||
);
|
||||
const parsed = JSON.parse(raw) as { messages: { type: string }[] };
|
||||
expect(parsed.messages.some((m) => m.type === "tool")).toBe(true);
|
||||
expect(raw).toContain("keep me");
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatThreadAsJSON", () => {
|
||||
it("strips hidden messages, tool messages, reasoning, and tool calls", () => {
|
||||
const messages = [
|
||||
human("hello"),
|
||||
human("secret reminder", {
|
||||
additional_kwargs: { hide_from_ui: true },
|
||||
} as Partial<Message>),
|
||||
ai("answer", {
|
||||
additional_kwargs: {
|
||||
reasoning_content: "secret reasoning",
|
||||
},
|
||||
tool_calls: [{ id: "1", name: "task", args: {} }],
|
||||
} as Partial<Message>),
|
||||
toolMsg("internal trace"),
|
||||
];
|
||||
const raw = formatThreadAsJSON(makeThread(), messages);
|
||||
const parsed = JSON.parse(raw) as {
|
||||
messages: { type: string; tool_calls?: unknown[] }[];
|
||||
};
|
||||
|
||||
expect(parsed.messages).toHaveLength(2);
|
||||
expect(parsed.messages.every((m) => m.type !== "tool")).toBe(true);
|
||||
expect(raw).not.toContain("secret reminder");
|
||||
expect(raw).not.toContain("secret reasoning");
|
||||
expect(raw).not.toContain("internal trace");
|
||||
expect(raw).not.toContain("tool_calls");
|
||||
});
|
||||
|
||||
it("strips inline <think>...</think> wrappers from content", () => {
|
||||
// bytedance/deer-flow#3131 review: JSON export must run the same
|
||||
// sanitiser the Markdown path uses so inline reasoning never leaks
|
||||
// even when `includeReasoning` is left at its default false.
|
||||
const message = ai("<think>internal monologue</think>visible answer", {
|
||||
id: "ai-1",
|
||||
} as Partial<Message>);
|
||||
const raw = formatThreadAsJSON(makeThread(), [message]);
|
||||
expect(raw).not.toContain("internal monologue");
|
||||
expect(raw).not.toContain("<think>");
|
||||
expect(raw).toContain("visible answer");
|
||||
});
|
||||
|
||||
it("strips content-array thinking blocks from content", () => {
|
||||
const message = ai("placeholder", {
|
||||
id: "ai-2",
|
||||
content: [
|
||||
{ type: "thinking", thinking: "hidden reasoning step" },
|
||||
{ type: "text", text: "final visible text" },
|
||||
],
|
||||
} as unknown as Partial<Message>);
|
||||
const raw = formatThreadAsJSON(makeThread(), [message]);
|
||||
expect(raw).not.toContain("hidden reasoning step");
|
||||
expect(raw).toContain("final visible text");
|
||||
});
|
||||
|
||||
it("strips <uploaded_files> markers from content", () => {
|
||||
const message = human(
|
||||
"real prompt\n<uploaded_files>\n/mnt/user-data/uploads/secret.pdf\n</uploaded_files>",
|
||||
{ id: "h-clean" } as Partial<Message>,
|
||||
);
|
||||
const raw = formatThreadAsJSON(makeThread(), [message]);
|
||||
expect(raw).not.toContain("<uploaded_files>");
|
||||
expect(raw).not.toContain("secret.pdf");
|
||||
expect(raw).toContain("real prompt");
|
||||
});
|
||||
|
||||
it("drops AI messages that sanitise to empty content", () => {
|
||||
// Pure-reasoning AI fragments (no visible text, no tool calls) should
|
||||
// not survive as `{content: ""}` rows in the export.
|
||||
const message = ai("<think>only thinking, no answer</think>", {
|
||||
id: "ai-3",
|
||||
} as Partial<Message>);
|
||||
const raw = formatThreadAsJSON(makeThread(), [message]);
|
||||
const parsed = JSON.parse(raw) as { messages: unknown[] };
|
||||
expect(parsed.messages).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("strips <system-reminder>/<memory>/<current_date> as defence in depth", () => {
|
||||
// Primary protection is `isHiddenFromUIMessage` filtering the whole
|
||||
// hidden HumanMessage. If a regression strips the `hide_from_ui` flag
|
||||
// (or the marker leaks into an otherwise-visible message), the
|
||||
// sanitiser must still scrub the payload before export.
|
||||
const leaky = human("real user text", {
|
||||
id: "leak-1",
|
||||
content:
|
||||
"<system-reminder>\n<memory>secret fact A</memory>\n<current_date>2026-01-01, Tuesday</current_date>\n</system-reminder>\nreal user text",
|
||||
// Deliberately *not* setting hide_from_ui to model the regression
|
||||
// case the defence-in-depth strip is guarding against.
|
||||
} as unknown as Partial<Message>);
|
||||
const raw = formatThreadAsJSON(makeThread(), [leaky]);
|
||||
expect(raw).not.toContain("<system-reminder>");
|
||||
expect(raw).not.toContain("<memory>");
|
||||
expect(raw).not.toContain("<current_date>");
|
||||
expect(raw).not.toContain("secret fact A");
|
||||
expect(raw).toContain("real user text");
|
||||
});
|
||||
|
||||
it("sanitises tool message content when includeToolMessages is true", () => {
|
||||
const message = {
|
||||
id: "t-leak",
|
||||
type: "tool",
|
||||
content:
|
||||
"Task Succeeded. Result: payload\n<uploaded_files>\n/mnt/user-data/uploads/secret.pdf\n</uploaded_files>",
|
||||
name: "task",
|
||||
tool_call_id: "call-leak",
|
||||
} as unknown as Message;
|
||||
|
||||
const raw = formatThreadAsJSON(makeThread(), [message], {
|
||||
includeToolMessages: true,
|
||||
});
|
||||
expect(raw).toContain("Task Succeeded");
|
||||
expect(raw).not.toContain("<uploaded_files>");
|
||||
expect(raw).not.toContain("secret.pdf");
|
||||
});
|
||||
|
||||
it("preserves text and image_url parts in mixed content arrays", () => {
|
||||
// `extractContentFromMessage` keeps `text` and `image_url` parts and
|
||||
// drops `thinking` parts. The JSON export must agree with that
|
||||
// contract.
|
||||
const message = ai("placeholder", {
|
||||
id: "ai-mixed",
|
||||
content: [
|
||||
{ type: "thinking", thinking: "internal reasoning" },
|
||||
{ type: "text", text: "user-visible answer" },
|
||||
{
|
||||
type: "image_url",
|
||||
image_url: { url: "https://example.invalid/cat.png" },
|
||||
},
|
||||
],
|
||||
} as unknown as Partial<Message>);
|
||||
const raw = formatThreadAsJSON(makeThread(), [message]);
|
||||
expect(raw).toContain("user-visible answer");
|
||||
expect(raw).toContain("https://example.invalid/cat.png");
|
||||
expect(raw).not.toContain("internal reasoning");
|
||||
});
|
||||
|
||||
it("drops opted-in empty reasoning rather than emit reasoning: ''", () => {
|
||||
// `extractReasoningContentFromMessage` can legitimately hand back ""
|
||||
// for an AI message that has no reasoning content. The export must
|
||||
// mirror the Markdown path's `!reasoning` `continue` and drop the row
|
||||
// instead of leaking `{reasoning: ""}`.
|
||||
const message = ai("", {
|
||||
id: "ai-empty-reasoning",
|
||||
additional_kwargs: { reasoning_content: "" },
|
||||
} as Partial<Message>);
|
||||
const raw = formatThreadAsJSON(makeThread(), [message], {
|
||||
includeReasoning: true,
|
||||
});
|
||||
const parsed = JSON.parse(raw) as { messages: unknown[] };
|
||||
expect(parsed.messages).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user