Merge branch 'main' into fix-3127

This commit is contained in:
Willem Jiang
2026-05-22 21:56:04 +08:00
committed by GitHub
57 changed files with 4981 additions and 195 deletions
+12
View File
@@ -184,6 +184,18 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state``lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
| Field | Why a restart is required |
|---|---|
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
Configuration priority:
1. Explicit `config_path` argument
2. `DEER_FLOW_CONFIG_PATH` environment variable
+11 -5
View File
@@ -161,10 +161,16 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
# Load config and check necessary environment variables at startup
# Load config and check necessary environment variables at startup.
# `startup_config` is a local snapshot used only for one-shot bootstrap
# work (logging level, langgraph_runtime engines, channels). Request-time
# config resolution always routes through `get_app_config()` in
# `app/gateway/deps.py::get_config()` so `config.yaml` edits become
# visible without a process restart. We deliberately do NOT cache this
# snapshot on `app.state` to keep that contract enforceable.
try:
app.state.config = get_app_config()
apply_logging_level(app.state.config.log_level)
startup_config = get_app_config()
apply_logging_level(startup_config.log_level)
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
@@ -174,7 +180,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
async with langgraph_runtime(app):
async with langgraph_runtime(app, startup_config):
logger.info("LangGraph runtime initialised")
# Check admin bootstrap state and migrate orphan threads after admin exists.
@@ -185,7 +191,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service(app.state.config)
channel_service = await start_channel_service(startup_config)
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")
+69 -17
View File
@@ -3,11 +3,21 @@
**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.
``AppConfig`` is intentionally *not* cached on ``app.state``. Routers and the
run path resolve it through :func:`deerflow.config.app_config.get_app_config`,
which performs mtime-based hot reload, so edits to ``config.yaml`` take
effect on the next request without a process restart. The engines created in
:func:`langgraph_runtime` (stream bridge, persistence, checkpointer, store,
run-event store) accept a ``startup_config`` snapshot — they are
restart-required by design and stay bound to that snapshot to keep the live
process consistent with itself.
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""
from __future__ import annotations
import logging
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, TypeVar, cast
@@ -15,12 +25,14 @@ from typing import TYPE_CHECKING, TypeVar, cast
from fastapi import FastAPI, HTTPException, Request
from langgraph.types import Checkpointer
from deerflow.config.app_config import AppConfig
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.runtime import RunContext, RunManager, StreamBridge
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
@@ -30,21 +42,55 @@ if TYPE_CHECKING:
T = TypeVar("T")
def get_config(request: Request) -> AppConfig:
"""Return the app-scoped ``AppConfig`` stored on ``app.state``."""
config = getattr(request.app.state, "config", None)
if config is None:
raise HTTPException(status_code=503, detail="Configuration not available")
return config
def get_config() -> AppConfig:
"""Return the freshest ``AppConfig`` for the current request.
Routes through :func:`deerflow.config.app_config.get_app_config`, which
honours runtime ``ContextVar`` overrides and reloads ``config.yaml`` from
disk when its mtime changes. ``AppConfig`` is not cached on ``app.state``
at all — the only startup-time snapshot lives as a local
``startup_config`` variable inside ``lifespan()`` and is passed
explicitly into :func:`langgraph_runtime` for the engines that are
restart-required by design. Routing every request through
:func:`get_app_config` closes the bytedance/deer-flow issue #3107 BUG-001
split-brain where the worker / lead-agent thread saw a stale startup
snapshot.
Any failure to materialise the config (missing file, permission denied,
YAML parse error, validation error) is reported as 503 — semantically
"the gateway cannot serve requests without a usable configuration" — and
logged with the original exception so operators have something to debug.
"""
try:
return get_app_config()
except Exception as exc: # noqa: BLE001 - request boundary: log and degrade gracefully
logger.exception("Failed to load AppConfig at request time")
raise HTTPException(status_code=503, detail="Configuration not available") from exc
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons.
``startup_config`` is the ``AppConfig`` snapshot taken once during
``lifespan()`` for one-shot infrastructure bootstrap. The engines and
stores constructed here (stream bridge, persistence engine, checkpointer,
store, run-event store) are restart-required by design — they hold live
connections, file handles, or singleton providers — so they bind to this
snapshot and survive across `config.yaml` edits. Request-time consumers
must still go through :func:`get_config` for any field that should be
hot-reloadable. See ``backend/CLAUDE.md`` "Config Hot-Reload Boundary".
The matching ``run_events_config`` is frozen onto ``app.state`` so
:func:`get_run_context` pairs a freshly-loaded ``AppConfig`` with the
*startup-time* run-events configuration the underlying ``event_store``
was built from — otherwise the runtime could end up combining a live
new ``run_events_config`` with an event store still bound to the
previous backend.
Usage in ``app.py``::
async with langgraph_runtime(app):
async with langgraph_runtime(app, startup_config):
yield
"""
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
@@ -53,9 +99,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack:
config = getattr(app.state, "config", None)
if config is None:
raise RuntimeError("langgraph_runtime() requires app.state.config to be initialized")
config = startup_config
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
@@ -84,8 +128,12 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
app.state.thread_store = make_thread_store(sf, app.state.store)
# Run event store (has its own factory with config-driven backend selection)
# Run event store. The store and the matching ``run_events_config`` are
# both frozen at startup so ``get_run_context`` does not combine a
# freshly-reloaded ``AppConfig.run_events`` with a store still bound to
# the previous backend.
run_events_config = getattr(config, "run_events", None)
app.state.run_events_config = run_events_config
app.state.run_event_store = make_run_event_store(run_events_config)
# RunManager with store backing for persistence
@@ -139,16 +187,20 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies.
Returns a *base* context with infrastructure dependencies. The
``app_config`` field is resolved live so per-run fields (e.g.
``models[*].max_tokens``) follow ``config.yaml`` edits; the
``event_store`` / ``run_events_config`` pair stays frozen to the snapshot
captured in :func:`langgraph_runtime` so callers never see a store bound
to one backend paired with a config pointing at another.
"""
config = get_config(request)
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(config, "run_events", None),
run_events_config=getattr(request.app.state, "run_events_config", None),
thread_store=get_thread_store(request),
app_config=config,
app_config=get_config(),
)
+25 -2
View File
@@ -66,6 +66,14 @@ class RunResponse(BaseModel):
multitask_strategy: str = "reject"
created_at: str = ""
updated_at: str = ""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
class ThreadTokenUsageModelBreakdown(BaseModel):
@@ -111,6 +119,14 @@ def _record_to_response(record: RunRecord) -> RunResponse:
multitask_strategy=record.multitask_strategy,
created_at=record.created_at,
updated_at=record.updated_at,
total_input_tokens=record.total_input_tokens,
total_output_tokens=record.total_output_tokens,
total_tokens=record.total_tokens,
llm_call_count=record.llm_call_count,
lead_agent_tokens=record.lead_agent_tokens,
subagent_tokens=record.subagent_tokens,
middleware_tokens=record.middleware_tokens,
message_count=record.message_count,
)
@@ -402,8 +418,15 @@ async def list_run_events(
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
async def thread_token_usage(
thread_id: str,
request: Request,
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
) -> ThreadTokenUsageResponse:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
if include_active:
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
else:
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
+27 -12
View File
@@ -15,7 +15,8 @@ from collections.abc import Mapping
from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_to_messages
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.utils import sanitize_log_param
@@ -76,21 +77,35 @@ def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
"""Convert LangGraph Platform input format to LangChain state dict."""
"""Convert LangGraph Platform input format to LangChain state dict.
Delegates dict→message coercion to ``langchain_core.messages.utils.convert_to_messages``
so that ``additional_kwargs`` (e.g. uploaded-file metadata — gh #3132), ``id``,
``name``, and non-human roles (ai/system/tool) survive unchanged. An earlier
hand-rolled version only forwarded ``content`` and collapsed every role to
``HumanMessage``, which silently stripped frontend-supplied attachments.
Malformed message dicts (missing ``role``/``type``/``content``, unsupported
role, etc.) raise ``HTTPException(400)`` with the offending index, instead
of bubbling up as a 500. The gateway is a system boundary, so per-entry
validation errors are the right shape for clients to retry against.
"""
if raw_input is None:
return {}
messages = raw_input.get("messages")
if messages and isinstance(messages, list):
converted = []
for msg in messages:
if isinstance(msg, dict):
role = msg.get("role", msg.get("type", "user"))
content = msg.get("content", "")
if role in ("user", "human"):
converted.append(HumanMessage(content=content))
else:
# TODO: handle other message types (system, ai, tool)
converted.append(HumanMessage(content=content))
converted: list[Any] = []
for index, msg in enumerate(messages):
if isinstance(msg, BaseMessage):
converted.append(msg)
elif isinstance(msg, dict):
try:
converted.extend(convert_to_messages([msg]))
except (ValueError, TypeError, NotImplementedError) as exc:
raise HTTPException(
status_code=400,
detail=f"Invalid message at input.messages[{index}]: {exc}",
) from exc
else:
converted.append(msg)
return {**raw_input, "messages": converted}
-7
View File
@@ -241,13 +241,6 @@ GET /api/mcp/config
"GITHUB_TOKEN": "***"
},
"description": "GitHub operations"
},
"filesystem": {
"enabled": false,
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem"],
"description": "File system access"
}
}
}
+14 -2
View File
@@ -14,6 +14,19 @@ DeerFlow supports configurable MCP servers and skills to extend its capabilities
3. Configure each servers command, arguments, and environment variables as needed.
4. Restart the application to load and register MCP tools.
## Filesystem MCP Servers
DeerFlow already provides built-in file tools for thread-scoped workspace access.
Do not add an MCP filesystem server for the same DeerFlow workspace. The
overlapping file tools use different path semantics, which can make LLM tool
selection and file access behavior unstable.
DeerFlow does not currently adapt the MCP Roots mode for filesystem servers. In
particular, it does not publish per-thread MCP roots or map DeerFlow sandbox
paths such as `/mnt/user-data/...` to paths accepted by
`@modelcontextprotocol/server-filesystem`. Use DeerFlow's built-in file tools
for DeerFlow workspace files.
## OAuth Support (HTTP/SSE MCP Servers)
For `http` and `sse` MCP servers, DeerFlow supports OAuth token acquisition and automatic token refresh.
@@ -88,7 +101,6 @@ MCP servers expose tools that are automatically discovered and integrated into D
MCP servers can provide access to:
- **File systems**
- **Databases** (e.g., PostgreSQL)
- **External APIs** (e.g., GitHub, Brave Search)
- **Browser automation** (e.g., Puppeteer)
@@ -97,4 +109,4 @@ MCP servers can provide access to:
## Learn More
For detailed documentation about the Model Context Protocol, visit:
https://modelcontextprotocol.io
https://modelcontextprotocol.io
@@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
@@ -338,6 +339,15 @@ def _build_middlewares(
if custom_middlewares:
middlewares.extend(custom_middlewares)
# SafetyFinishReasonMiddleware — suppress tool execution when the provider
# safety-terminated the response. Registered after custom middlewares so
# that LangChain's reverse-order after_model dispatch runs Safety first;
# cleared tool_calls then flow through Loop/Subagent accounting without
# firing extra alarms. See safety_finish_reason_middleware.py docstring.
safety_config = resolved_app_config.safety_finish_reason
if safety_config.enabled:
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
# ClarificationMiddleware should always be last
middlewares.append(ClarificationMiddleware())
return middlewares
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
import json
import logging
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable
from typing import override
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged.
"""
tool_messages_by_id: dict[str, ToolMessage] = {}
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
for msg in messages:
if isinstance(msg, ToolMessage):
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
tool_messages_by_id[msg.tool_call_id].append(msg)
tool_call_ids: set[str] = set()
for msg in messages:
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
tool_call_ids.add(tc_id)
patched: list = []
consumed_tool_msg_ids: set[str] = set()
patch_count = 0
for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if not tc_id or tc_id in consumed_tool_msg_ids:
if not tc_id:
continue
existing_tool_msg = tool_messages_by_id.get(tc_id)
tool_msg_queue = tool_messages_by_id.get(tc_id)
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
if existing_tool_msg is not None:
patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else:
patched.append(
ToolMessage(
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error",
)
)
consumed_tool_msg_ids.add(tc_id)
patch_count += 1
if patched == messages:
@@ -0,0 +1,317 @@
"""Suppress tool execution when the provider safety-terminated the response.
Background — see issue bytedance/deer-flow#3028.
Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic
``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop
generation mid-stream while still returning partially-formed ``tool_calls``.
LangChain's tool router treats any AIMessage with a non-empty ``tool_calls``
field as "go execute these", so half-truncated arguments — e.g. a markdown
``write_file`` that stops in the middle of a sentence — get dispatched as if
they were complete. The agent then sees the truncated file, tries to fix it,
gets filtered again, and loops.
This middleware sits at ``after_model`` and gates that behaviour: when a
configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries
tool calls, we strip the tool calls (both structured and raw provider
payloads), append a user-facing explanation, and stash observability fields
in ``additional_kwargs.safety_termination`` so logs, traces, and SSE
consumers can see what happened.
Hook choice: ``after_model`` (not ``wrap_model_call``) because the response
is a *normal* return — not an exception — and we want to participate in the
same after-model chain as ``LoopDetectionMiddleware``, with which we share
the same tool-call-suppression mechanic but a different trigger.
Placement: register *after* ``LoopDetectionMiddleware`` in the middleware
list. LangChain factory wires ``after_model`` edges in reverse list order
(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``,
then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is
the *first* to observe the model output. Registering Safety after Loop
means Safety sees the raw response first, clears tool calls if it fires,
and Loop then accounts against the cleaned message.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
if TYPE_CHECKING:
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
logger = logging.getLogger(__name__)
_USER_FACING_MESSAGE = (
"The model provider stopped this response with a safety-related signal "
"({reason_field}={reason_value!r}, detector={detector!r}). Any tool "
"calls produced in this turn were suppressed because their arguments "
"may be truncated and unsafe to execute. Please rephrase the request "
"or ask for a narrower output."
)
class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]):
"""Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector."""
def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None:
super().__init__()
# Copy so caller mutations after construction don't leak into us.
self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors()
@classmethod
def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware:
"""Construct from validated Pydantic config, honouring the
reflection-loaded detector list when provided.
An explicit empty list is intentionally rejected — it would silently
disable detection while leaving the middleware in the chain, which
is the worst of both worlds. Use ``enabled: false`` instead.
"""
if config.detectors is None:
return cls()
if not config.detectors:
raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.")
from deerflow.reflection import resolve_variable
detectors: list[SafetyTerminationDetector] = []
for entry in config.detectors:
detector_cls = resolve_variable(entry.use)
kwargs = dict(entry.config) if entry.config else {}
detector = detector_cls(**kwargs)
if not isinstance(detector, SafetyTerminationDetector):
raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method")
detectors.append(detector)
return cls(detectors=detectors)
# ----- detection -------------------------------------------------------
def _detect(self, message: AIMessage) -> SafetyTermination | None:
for detector in self._detectors:
try:
hit = detector.detect(message)
except Exception: # noqa: BLE001 - never let a buggy detector break the agent run
logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__))
continue
if hit is not None:
return hit
return None
# ----- message rewriting ----------------------------------------------
@staticmethod
def _append_user_message(content: object, text: str) -> str | list:
"""Append a plain-text explanation to AIMessage content.
Mirrors ``LoopDetectionMiddleware._append_text`` so list-content
responses (Anthropic thinking blocks, vLLM reasoning splits) keep
their structure instead of being string-coerced into a TypeError.
"""
if content is None or content == "":
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
return str(content) + f"\n\n{text}"
def _build_suppressed_message(
self,
message: AIMessage,
termination: SafetyTermination,
) -> AIMessage:
suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])]
explanation = _USER_FACING_MESSAGE.format(
reason_field=termination.reason_field,
reason_value=termination.reason_value,
detector=termination.detector,
)
new_content = self._append_user_message(message.content, explanation)
# clone_ai_message_with_tool_calls handles structured tool_calls,
# raw additional_kwargs.tool_calls, and function_call in one shot.
# It only rewrites finish_reason when the old value was "tool_calls",
# which is not our case — content_filter / refusal / SAFETY stay put
# so downstream SSE / converters keep seeing the real provider reason.
cleared = clone_ai_message_with_tool_calls(message, [], content=new_content)
# Re-clone additional_kwargs so we don't accidentally mutate the
# dict returned by clone_ai_message_with_tool_calls (which already
# made a shallow copy, but downstream model_copy still references
# it). Then stamp the observability record.
kwargs = dict(getattr(cleared, "additional_kwargs", None) or {})
kwargs["safety_termination"] = {
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(suppressed_names),
"suppressed_tool_call_names": suppressed_names,
"extras": dict(termination.extras) if termination.extras else {},
}
return cleared.model_copy(update={"additional_kwargs": kwargs})
# ----- observability ---------------------------------------------------
def _emit_event(
self,
termination: SafetyTermination,
suppressed_names: list[str],
runtime: Runtime,
) -> None:
"""Notify SSE consumers (e.g. the web UI) that a tool turn was
suppressed so they can reconcile any "tool starting..." placeholders
already streamed to the user. Failures are logged at debug and
ignored — this is a best-effort signal."""
try:
from langgraph.config import get_stream_writer
writer = get_stream_writer()
except Exception: # noqa: BLE001
logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True)
return
thread_id = None
if runtime is not None and getattr(runtime, "context", None):
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
try:
writer(
{
"type": "safety_termination",
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(suppressed_names),
"suppressed_tool_call_names": suppressed_names,
"thread_id": thread_id,
}
)
except Exception: # noqa: BLE001
logger.debug("Failed to emit safety_termination stream event", exc_info=True)
def _record_audit_event(
self,
termination: SafetyTermination,
message,
tool_calls: list[dict],
runtime: Runtime,
) -> None:
"""Write a ``middleware:safety_termination`` record to RunEventStore
for post-run auditability.
The custom stream event in ``_emit_event`` is consumed by live SSE
clients and disappears after the run; this event is persisted so an
operator can answer "which runs were safety-suppressed today?" from
a single SQL query without joining the message body. Worker exposes
the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``;
absent in unit-test / subagent / no-event-store paths, in which case
we silently skip.
Tool **arguments** are deliberately **not** recorded — those are the
very content the provider filtered; persisting them would defeat the
purpose of the safety filter. Names / count / ids are sufficient for
audit and debugging (issue #3028 review).
"""
journal = None
if runtime is not None and getattr(runtime, "context", None):
context = runtime.context
if isinstance(context, dict):
journal = context.get("__run_journal")
if journal is None:
return
suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls]
suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
changes = {
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(tool_calls),
"suppressed_tool_call_names": suppressed_names,
"suppressed_tool_call_ids": suppressed_ids,
"message_id": getattr(message, "id", None),
"extras": dict(termination.extras) if termination.extras else {},
}
try:
journal.record_middleware(
tag="safety_termination",
name=type(self).__name__,
hook="after_model",
action="suppress_tool_calls",
changes=changes,
)
except Exception: # noqa: BLE001
# Audit-event persistence must never break agent execution.
logger.debug("Failed to record middleware:safety_termination event", exc_info=True)
# ----- main apply ------------------------------------------------------
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
# Issue scope: only intervene when there's something to suppress.
# ``content_filter`` without tool_calls is allowed through unchanged
# so the partial text response (if any) reaches the user naturally.
tool_calls = last.tool_calls
if not tool_calls:
return None
termination = self._detect(last)
if termination is None:
return None
patched = self._build_suppressed_message(last, termination)
thread_id = None
if runtime is not None and getattr(runtime, "context", None):
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
logger.warning(
"Provider safety termination detected — suppressed %d tool call(s)",
len(tool_calls),
extra={
"thread_id": thread_id,
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_names": [tc.get("name") for tc in tool_calls],
},
)
self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime)
self._record_audit_event(termination, last, list(tool_calls), runtime)
return {"messages": [patched]}
# ----- hooks -----------------------------------------------------------
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@@ -0,0 +1,237 @@
"""Detectors for provider-side safety termination signals.
Different LLM providers signal "I stopped this response for safety reasons"
through different fields with different values. This module defines a small
strategy interface and three built-in detectors that cover the major
providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock
adapters, in-house gateways, ...) can be added by implementing
``SafetyTerminationDetector`` and wiring it through
``config.yaml: safety_finish_reason.detectors``.
The middleware that consumes these detectors lives in
``safety_finish_reason_middleware.py``.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable
from langchain_core.messages import AIMessage
@dataclass(frozen=True)
class SafetyTermination:
"""A detected safety-related termination signal.
Attributes:
detector: Name of the detector that produced this result. Used for
observability so operators can see which provider rule fired.
reason_field: The message metadata field that carried the signal
(e.g. ``finish_reason``, ``stop_reason``).
reason_value: The actual value of that field
(e.g. ``content_filter``, ``refusal``, ``SAFETY``).
extras: Provider-specific metadata that may help downstream
consumers (e.g. Azure OpenAI content_filter_results, Gemini
safety_ratings). Detectors are free to populate or skip this.
"""
detector: str
reason_field: str
reason_value: str
extras: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class SafetyTerminationDetector(Protocol):
"""Strategy interface for provider safety termination detection."""
name: str
def detect(self, message: AIMessage) -> SafetyTermination | None:
"""Return a SafetyTermination if *message* indicates provider safety
termination, otherwise return ``None``.
Implementations must be side-effect free and tolerant of missing or
oddly-typed metadata — detectors run on every model response.
"""
...
def _get_metadata_value(message: AIMessage, field_name: str) -> str | None:
"""Read a string-typed value from either ``response_metadata`` or
``additional_kwargs``.
LangChain provider adapters are inconsistent about where they stash
provider stop signals. Most modern adapters use ``response_metadata``,
but some legacy / passthrough paths still surface them via
``additional_kwargs``. We check both, in that order, and only accept
string values — Pydantic enums or dicts are ignored so we never raise
on malformed inputs.
"""
for container_name in ("response_metadata", "additional_kwargs"):
container = getattr(message, container_name, None) or {}
if not isinstance(container, dict):
continue
value = container.get(field_name)
if isinstance(value, str) and value:
return value
return None
class OpenAICompatibleContentFilterDetector:
"""OpenAI-compatible content_filter signal.
Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM,
Qwen (OpenAI-compatible mode), and any other adapter that follows the
OpenAI ``finish_reason`` convention.
Some Chinese providers ship custom OpenAI-compatible gateways that use
alternative tokens like ``sensitive`` or ``violation``. Extend the set
via the ``finish_reasons`` kwarg in config.
"""
name = "openai_compatible_content_filter"
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = finish_reasons if finish_reasons is not None else ("content_filter",)
self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "finish_reason")
if value is None or value.lower() not in self._finish_reasons:
return None
extras: dict[str, Any] = {}
# Azure OpenAI ships a structured content_filter_results block; carry it
# through so operators can see *what* was filtered without re-tracing.
response_metadata = getattr(message, "response_metadata", None) or {}
if isinstance(response_metadata, dict):
filter_results = response_metadata.get("content_filter_results")
if filter_results:
extras["content_filter_results"] = filter_results
return SafetyTermination(
detector=self.name,
reason_field="finish_reason",
reason_value=value,
extras=extras,
)
class AnthropicRefusalDetector:
"""Anthropic ``stop_reason == "refusal"`` signal.
Anthropic models surface safety refusals via a dedicated ``stop_reason``
rather than ``finish_reason``. See:
https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals
"""
name = "anthropic_refusal"
def __init__(self, stop_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = stop_reasons if stop_reasons is not None else ("refusal",)
self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "stop_reason")
if value is None or value.lower() not in self._stop_reasons:
return None
return SafetyTermination(
detector=self.name,
reason_field="stop_reason",
reason_value=value,
)
class GeminiSafetyDetector:
"""Gemini / Vertex AI safety-related finish reasons.
Gemini uses the same ``finish_reason`` field as OpenAI but with an
enumerated upper-case taxonomy. The default set covers every Gemini
finish_reason that means "the model stopped because the content/image
tripped a safety, blocklist, recitation, or PII filter" — i.e. cases
where any tool_calls returned alongside are likely truncated/
unreliable. Full enum:
https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason
Intentionally **excluded** from the default set:
- ``STOP`` — normal termination.
- ``MAX_TOKENS`` — output length truncation, not safety
(same root failure mode as
content_filter, but issue #3028
scopes it out; expose separately if
desired).
- ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to
safety; tool_calls would be absent
anyway.
- ``MALFORMED_FUNCTION_CALL`` /
``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The
tool_calls are *also* unreliable
here, but the failure category is
distinct from safety filtering;
handle in a dedicated detector to
keep observability records honest.
- ``OTHER`` / ``IMAGE_OTHER`` /
``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default;
opt in via ``finish_reasons=`` if
your provider abuses these.
"""
name = "gemini_safety"
_DEFAULT_FINISH_REASONS = (
# Text safety
"SAFETY",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"RECITATION",
# Image safety (multimodal generation)
"IMAGE_SAFETY",
"IMAGE_PROHIBITED_CONTENT",
"IMAGE_RECITATION",
)
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS
self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "finish_reason")
if value is None or value.upper() not in self._finish_reasons:
return None
extras: dict[str, Any] = {}
response_metadata = getattr(message, "response_metadata", None) or {}
if isinstance(response_metadata, dict):
# Gemini surfaces per-category scoring under safety_ratings.
ratings = response_metadata.get("safety_ratings")
if ratings:
extras["safety_ratings"] = ratings
return SafetyTermination(
detector=self.name,
reason_field="finish_reason",
reason_value=value,
extras=extras,
)
def default_detectors() -> list[SafetyTerminationDetector]:
"""Built-in detector set used when no custom detectors are configured."""
return [
OpenAICompatibleContentFilterDetector(),
AnthropicRefusalDetector(),
GeminiSafetyDetector(),
]
__all__ = [
"AnthropicRefusalDetector",
"GeminiSafetyDetector",
"OpenAICompatibleContentFilterDetector",
"SafetyTermination",
"SafetyTerminationDetector",
"default_detectors",
]
@@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares(
middlewares.append(ViewImageMiddleware())
# Same provider safety-termination guard the lead agent uses — subagents
# are equally exposed to truncated tool_calls returned with
# finish_reason=content_filter (and friends), and the bad call would then
# propagate back to the lead agent via the task tool result.
safety_config = app_config.safety_finish_reason
if safety_config.enabled:
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
return middlewares
@@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.runtime_paths import existing_project_file
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
@@ -102,6 +103,7 @@ class AppConfig(BaseModel):
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -0,0 +1,47 @@
"""Configuration for SafetyFinishReasonMiddleware.
Mirrors the shape of GuardrailsConfig: detectors are loaded by class path
through ``deerflow.reflection.resolve_variable`` (same loader the
``guardrails.provider`` config uses) so users can drop in custom provider
detectors without modifying core code.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
class SafetyDetectorConfig(BaseModel):
"""One detector entry under ``safety_finish_reason.detectors``."""
use: str = Field(
description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."),
)
config: dict = Field(
default_factory=dict,
description="Constructor kwargs passed to the detector class.",
)
class SafetyFinishReasonConfig(BaseModel):
"""Configuration for the SafetyFinishReasonMiddleware.
The middleware intercepts AIMessages where the provider signaled a
safety-related termination (e.g. OpenAI ``finish_reason='content_filter'``)
while still returning tool calls, and suppresses those tool calls so the
half-truncated arguments never execute.
"""
enabled: bool = Field(
default=True,
description="Master switch for the SafetyFinishReasonMiddleware.",
)
detectors: list[SafetyDetectorConfig] | None = Field(
default=None,
description=(
"Custom detector list. Leave unset (None) to use the built-in "
"set covering OpenAI-compatible content_filter, Anthropic "
"refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/"
"RECITATION. Provide a non-null list to fully override."
),
)
@@ -134,9 +134,25 @@ def reset_mcp_tools_cache() -> None:
"""Reset the MCP tools cache.
This is useful for testing or when you want to reload MCP tools.
Also closes all persistent MCP sessions so they are recreated on
the next tool load.
"""
global _mcp_tools_cache, _cache_initialized, _config_mtime
_mcp_tools_cache = None
_cache_initialized = False
_config_mtime = None
# Close persistent sessions they will be recreated by the next
# get_mcp_tools() call with the (possibly updated) connection config.
try:
from deerflow.mcp.session_pool import get_session_pool
pool = get_session_pool()
pool.close_all_sync()
except Exception:
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
from deerflow.mcp.session_pool import reset_session_pool
reset_session_pool()
logger.info("MCP tools cache reset")
@@ -0,0 +1,198 @@
"""Persistent MCP session pool for stateful tool calls.
When MCP tools are loaded via langchain-mcp-adapters with ``session=None``,
each tool call creates a new MCP session. For stateful servers like Playwright,
this means browser state (opened pages, filled forms) is lost between calls.
This module provides a session pool that maintains persistent MCP sessions,
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
so that consecutive tool calls share the same session and server-side state.
Sessions are evicted in LRU order when the pool reaches capacity.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from collections import OrderedDict
from typing import Any
from mcp import ClientSession
logger = logging.getLogger(__name__)
class MCPSessionPool:
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
MAX_SESSIONS = 256
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
def __init__(self) -> None:
self._entries: OrderedDict[
tuple[str, str],
tuple[ClientSession, asyncio.AbstractEventLoop],
] = OrderedDict()
self._context_managers: dict[tuple[str, str], Any] = {}
# threading.Lock is not bound to any event loop, so it is safe to
# acquire from both async paths and sync/worker-thread paths.
self._lock = threading.Lock()
async def get_session(
self,
server_name: str,
scope_key: str,
connection: dict[str, Any],
) -> ClientSession:
"""Get or create a persistent MCP session.
If an existing session was created in a different event loop (e.g.
the sync-wrapper path), it is closed and replaced with a fresh one
in the current loop.
Args:
server_name: MCP server name.
scope_key: Isolation key (typically thread_id).
connection: Connection configuration for ``create_session``.
Returns:
An initialized ``ClientSession``.
"""
key = (server_name, scope_key)
current_loop = asyncio.get_running_loop()
# Phase 1: inspect/mutate the registry under the thread lock (no awaits).
cms_to_close: list[tuple[tuple[str, str], Any]] = []
with self._lock:
if key in self._entries:
session, loop = self._entries[key]
if loop is current_loop:
self._entries.move_to_end(key)
return session
# Session belongs to a different event loop evict it.
cm = self._context_managers.pop(key, None)
self._entries.pop(key)
if cm is not None:
cms_to_close.append((key, cm))
# Evict LRU entries when at capacity.
while len(self._entries) >= self.MAX_SESSIONS:
oldest_key = next(iter(self._entries))
cm = self._context_managers.pop(oldest_key, None)
self._entries.pop(oldest_key)
if cm is not None:
cms_to_close.append((oldest_key, cm))
# Phase 2: async cleanup outside the lock so we never await while holding it.
for close_key, cm in cms_to_close:
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", close_key, exc_info=True)
from langchain_mcp_adapters.sessions import create_session
cm = create_session(connection)
session = await cm.__aenter__()
await session.initialize()
# Phase 3: register the new session under the lock.
with self._lock:
self._entries[key] = (session, current_loop)
self._context_managers[key] = cm
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
return session
# ------------------------------------------------------------------
# Cleanup helpers
# ------------------------------------------------------------------
async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
"""Close a single context manager (must be called WITHOUT the lock)."""
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", key, exc_info=True)
async def close_scope(self, scope_key: str) -> None:
"""Close all sessions for a given scope (e.g. thread_id)."""
with self._lock:
keys = [k for k in self._entries if k[1] == scope_key]
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_server(self, server_name: str) -> None:
"""Close all sessions for a given server."""
with self._lock:
keys = [k for k in self._entries if k[0] == server_name]
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_all(self) -> None:
"""Close every managed session."""
with self._lock:
cms = list(self._context_managers.items())
self._context_managers.clear()
self._entries.clear()
for key, cm in cms:
await self._close_cm(key, cm)
def close_all_sync(self) -> None:
"""Close all sessions using their owning event loops (synchronous).
Each session is closed on the loop it was created in, avoiding
cross-loop resource leaks. Safe to call from any thread without an
active event loop.
"""
with self._lock:
entries = list(self._entries.items())
cms = dict(self._context_managers)
self._entries.clear()
self._context_managers.clear()
for key, (_, loop) in entries:
cm = cms.get(key)
if cm is None or loop.is_closed():
continue
try:
if loop.is_running():
# Schedule on the owning loop from this (different) thread.
future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
else:
loop.run_until_complete(cm.__aexit__(None, None, None))
except Exception:
logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
# ------------------------------------------------------------------
# Module-level singleton
# ------------------------------------------------------------------
_pool: MCPSessionPool | None = None
_pool_lock = threading.Lock()
def get_session_pool() -> MCPSessionPool:
"""Return the global session-pool singleton."""
global _pool
if _pool is None:
with _pool_lock:
if _pool is None:
_pool = MCPSessionPool()
return _pool
def reset_session_pool() -> None:
"""Reset the singleton (for tests)."""
global _pool
_pool = None
+190 -8
View File
@@ -1,21 +1,181 @@
"""Load MCP tools using langchain-mcp-adapters."""
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.config import get_config
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.mcp.session_pool import get_session_pool
from deerflow.reflection import resolve_variable
from deerflow.tools.sync import make_sync_tool_wrapper
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _extract_thread_id(runtime: Runtime | None) -> str:
"""Extract thread_id from the injected tool runtime or LangGraph config."""
if runtime is not None:
tid = runtime.context.get("thread_id") if runtime.context else None
if tid is not None:
return str(tid)
config = runtime.config or {}
tid = config.get("configurable", {}).get("thread_id")
if tid is not None:
return str(tid)
try:
tid = get_config().get("configurable", {}).get("thread_id")
return str(tid) if tid is not None else "default"
except RuntimeError:
return "default"
def _convert_call_tool_result(call_tool_result: Any) -> Any:
"""Convert an MCP CallToolResult to the LangChain ``content_and_artifact`` format.
Implements the same conversion logic as the adapter without relying on
the private ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol.
"""
from langchain_core.messages import ToolMessage
from langchain_core.messages.content import create_file_block, create_image_block, create_text_block
from langchain_core.tools import ToolException
from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents
# Pass ToolMessage through directly (interceptor short-circuit).
if isinstance(call_tool_result, ToolMessage):
return call_tool_result, None
# Pass LangGraph Command through directly when langgraph is installed.
try:
from langgraph.types import Command
if isinstance(call_tool_result, Command):
return call_tool_result, None
except ImportError:
# langgraph is optional; if unavailable, continue with standard MCP content conversion.
pass
# Convert MCP content blocks to LangChain content blocks.
lc_content = []
for item in call_tool_result.content:
if isinstance(item, TextContent):
lc_content.append(create_text_block(text=item.text))
elif isinstance(item, ImageContent):
lc_content.append(create_image_block(base64=item.data, mime_type=item.mimeType))
elif isinstance(item, ResourceLink):
mime = item.mimeType or None
if mime and mime.startswith("image/"):
lc_content.append(create_image_block(url=str(item.uri), mime_type=mime))
else:
lc_content.append(create_file_block(url=str(item.uri), mime_type=mime))
elif isinstance(item, EmbeddedResource):
from mcp.types import BlobResourceContents
res = item.resource
if isinstance(res, TextResourceContents):
lc_content.append(create_text_block(text=res.text))
elif isinstance(res, BlobResourceContents):
mime = res.mimeType or None
if mime and mime.startswith("image/"):
lc_content.append(create_image_block(base64=res.blob, mime_type=mime))
else:
lc_content.append(create_file_block(base64=res.blob, mime_type=mime))
else:
lc_content.append(create_text_block(text=str(res)))
else:
lc_content.append(create_text_block(text=str(item)))
if call_tool_result.isError:
error_parts = [item["text"] for item in lc_content if isinstance(item, dict) and item.get("type") == "text"]
raise ToolException("\n".join(error_parts) if error_parts else str(lc_content))
artifact = None
if call_tool_result.structuredContent is not None:
artifact = {"structured_content": call_tool_result.structuredContent}
return lc_content, artifact
def _make_session_pool_tool(
tool: BaseTool,
server_name: str,
connection: dict[str, Any],
tool_interceptors: list[Any] | None = None,
) -> BaseTool:
"""Wrap an MCP tool so it reuses a persistent session from the pool.
Replaces the per-call session creation with pool-managed sessions scoped
by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
Playwright) keep their state across tool calls within the same thread.
The configured ``tool_interceptors`` (OAuth, custom) are preserved and
applied on every call before invoking the pooled session.
"""
# Strip the server-name prefix to recover the original MCP tool name.
original_name = tool.name
prefix = f"{server_name}_"
if original_name.startswith(prefix):
original_name = original_name[len(prefix) :]
pool = get_session_pool()
async def call_with_persistent_session(
runtime: Runtime | None = None,
**arguments: Any,
) -> Any:
thread_id = _extract_thread_id(runtime)
session = await pool.get_session(server_name, thread_id, connection)
if tool_interceptors:
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
async def base_handler(request: MCPToolCallRequest) -> Any:
return await session.call_tool(request.name, request.args)
handler = base_handler
for interceptor in reversed(tool_interceptors):
outer = handler
async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any:
return await _i(req, _h)
handler = wrapped
request = MCPToolCallRequest(
name=original_name,
args=arguments,
server_name=server_name,
runtime=runtime,
)
call_tool_result = await handler(request)
else:
call_tool_result = await session.call_tool(original_name, arguments)
return _convert_call_tool_result(call_tool_result)
return StructuredTool(
name=tool.name,
description=tool.description,
args_schema=tool.args_schema,
coroutine=call_with_persistent_session,
response_format="content_and_artifact",
metadata=tool.metadata,
)
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
Tools are wrapped with persistent-session logic so that consecutive
calls within the same thread reuse the same MCP session.
Returns:
List of LangChain tools from all enabled MCP servers.
"""
@@ -50,7 +210,7 @@ async def get_mcp_tools() -> list[BaseTool]:
existing_headers["Authorization"] = auth_header
servers_config[server_name]["headers"] = existing_headers
tool_interceptors = []
tool_interceptors: list[Any] = []
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
@@ -74,20 +234,42 @@ async def get_mcp_tools() -> list[BaseTool]:
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
logger.warning(
f"Failed to load MCP interceptor {interceptor_path}: {e}",
exc_info=True,
)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
client = MultiServerMCPClient(
servers_config,
tool_interceptors=tool_interceptors,
tool_name_prefix=True,
)
# Get all tools from all servers
# Get all tools from all servers (discovers tool definitions via
# temporary sessions the persistent-session wrapping is applied below).
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
# Wrap each tool with persistent-session logic.
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
for name in servers_config:
if tool.name.startswith(f"{name}_"):
tool_server = name
break
if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
else:
wrapped_tools.append(tool)
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in wrapped_tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
return wrapped_tools
except Exception as e:
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
@@ -227,9 +227,48 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Update token usage + convenience fields while a run is still active."""
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
optional_counters = {
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
for key, value in optional_counters.items():
if value is not None:
values[key] = value
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
statuses = ("success", "error", "running") if include_active else ("success", "error")
_completed = RunRow.status.in_(statuses)
_thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
*,
track_token_usage: bool = True,
flush_threshold: int = 20,
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
progress_flush_interval: float = 5.0,
):
super().__init__()
self.run_id = run_id
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
self._store = event_store
self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold
self._progress_reporter = progress_reporter
self._progress_flush_interval = progress_flush_interval
# Write buffer
self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
self._pending_progress_task: asyncio.Task[None] | None = None
self._pending_progress_delayed = False
self._progress_dirty = False
self._last_progress_flush = 0.0
# Token accumulators
self._total_input_tokens = 0
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
else:
self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
if messages:
self._counted_message_llm_run_ids.add(str(run_id))
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
else:
self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
"""Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._pending_progress_task is not None and not self._pending_progress_task.done():
if self._pending_progress_delayed:
self._pending_progress_task.cancel()
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
self._progress_dirty = False
self._pending_progress_delayed = False
break
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
while self._buffer:
batch = self._buffer[: self._flush_threshold]
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
self._buffer = batch + self._buffer
raise
def _schedule_progress_flush(self) -> None:
"""Best-effort throttled progress snapshot for active run visibility."""
if self._progress_reporter is None:
return
now = time.monotonic()
elapsed = now - self._last_progress_flush
if elapsed < self._progress_flush_interval:
self._progress_dirty = True
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
return
if self._pending_progress_task is not None and not self._pending_progress_task.done():
self._progress_dirty = True
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
self._progress_dirty = False
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
def _schedule_delayed_progress_flush(self, delay: float) -> None:
if self._pending_progress_task is not None and not self._pending_progress_task.done():
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
delay = max(0.0, delay)
self._pending_progress_delayed = delay > 0
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
if self._progress_reporter is None:
return
if delay > 0:
self._pending_progress_delayed = True
await asyncio.sleep(delay)
self._pending_progress_delayed = False
dirty_before_write = self._progress_dirty
self._progress_dirty = False
snapshot_to_write = snapshot or self.get_completion_data()
try:
await self._progress_reporter(snapshot_to_write)
self._last_progress_flush = time.monotonic()
except Exception:
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
if dirty_before_write or self._progress_dirty:
self._progress_dirty = False
self._pending_progress_task = None
self._schedule_delayed_progress_flush(self._progress_flush_interval)
def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion."""
return {
@@ -38,6 +38,16 @@ class RunRecord:
error: str | None = None
model_name: str | None = None
store_only: bool = False
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
last_ai_message: str | None = None
first_human_message: str | None = None
class RunManager:
@@ -102,16 +112,53 @@ class RunManager:
error=row.get("error"),
model_name=row.get("model_name"),
store_only=True,
total_input_tokens=row.get("total_input_tokens") or 0,
total_output_tokens=row.get("total_output_tokens") or 0,
total_tokens=row.get("total_tokens") or 0,
llm_call_count=row.get("llm_call_count") or 0,
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
subagent_tokens=row.get("subagent_tokens") or 0,
middleware_tokens=row.get("middleware_tokens") or 0,
message_count=row.get("message_count") or 0,
last_ai_message=row.get("last_ai_message"),
first_human_message=row.get("first_human_message"),
)
async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store."""
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
for key, value in kwargs.items():
if key == "status":
continue
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if self._store is not None:
try:
await self._store.update_run_completion(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
"""Persist a running token/message snapshot without changing status."""
should_persist = True
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
should_persist = record.status == RunStatus.running
if record is not None and should_persist:
for key, value in kwargs.items():
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if should_persist and self._store is not None:
try:
await self._store.update_run_progress(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
async def create(
self,
thread_id: str,
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
) -> None:
pass
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Persist a best-effort running snapshot without changing run status."""
return None
@abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens,
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
by_model: dict[str, dict] = {}
for r in completed:
model = r.get("model_name") or "unknown"
@@ -153,8 +153,6 @@ async def run_agent(
journal = None
journal = None
# Track whether "events" was requested but skipped
if "events" in requested_modes:
logger.info(
@@ -177,6 +175,7 @@ async def run_agent(
thread_id=thread_id,
event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True),
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
)
# 1. Mark running
@@ -219,6 +218,12 @@ async def run_agent(
# manually here because we drive the graph through ``agent.astream(config=...)``
# without passing the official ``context=`` parameter.
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
# Expose the run-scoped journal under a sentinel key so middleware can
# write audit events (e.g. SafetyFinishReasonMiddleware recording
# suppressed tool calls). Double-underscore prefix marks it as a
# runtime-internal channel; user code must not depend on the key name.
if journal is not None:
runtime_ctx["__run_journal"] = journal
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
@@ -42,6 +42,7 @@ _DEFAULT_GLOB_MAX_RESULTS = 200
_MAX_GLOB_MAX_RESULTS = 1000
_DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS = 2000
_LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"}
_LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"}
_LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"}
@@ -435,6 +436,42 @@ def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
return msg
def _truncate_write_file_error_detail(detail: str, max_chars: int) -> str:
"""Middle-truncate write_file error details, preserving the head and tail."""
if max_chars == 0:
return detail
if len(detail) <= max_chars:
return detail
total = len(detail)
marker_max_len = len(f"\n... [write_file error truncated: {total} chars skipped] ...\n")
kept = max(0, max_chars - marker_max_len)
if kept == 0:
return detail[:max_chars]
head_len = kept // 2
tail_len = kept - head_len
skipped = total - kept
marker = f"\n... [write_file error truncated: {skipped} chars skipped] ...\n"
return f"{detail[:head_len]}{marker}{detail[-tail_len:] if tail_len > 0 else ''}"
def _format_write_file_error(
requested_path: str,
error: Exception,
runtime: Runtime | None = None,
*,
max_chars: int = _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
) -> str:
"""Return a bounded, sanitized error string for write_file failures."""
header = f"Error: Failed to write file '{requested_path}'"
detail = _sanitize_error(error, runtime)
if max_chars == 0:
return f"{header}: {detail}"
detail_budget = max_chars - len(header) - 2
if detail_budget <= 0:
return _truncate_write_file_error_detail(f"{header}: {detail}", max_chars)
return f"{header}: {_truncate_write_file_error_detail(detail, detail_budget)}"
def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
"""Replace virtual /mnt/user-data paths with actual thread data paths.
@@ -1651,9 +1688,9 @@ def write_file_tool(
append: Whether to append content to the end of the file instead of overwriting it. Defaults to false.
"""
try:
requested_path = path
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
@@ -1664,15 +1701,21 @@ def write_file_tool(
sandbox.write_file(path, content, append)
return "OK"
except SandboxError as e:
return f"Error: {e}"
return _format_write_file_error(requested_path, e, runtime)
except PermissionError:
return f"Error: Permission denied writing to file: {requested_path}"
return _truncate_write_file_error_detail(
f"Error: Permission denied writing to file: {requested_path}",
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
)
except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}"
return _truncate_write_file_error_detail(
f"Error: Path is a directory, not a file: {requested_path}",
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
)
except OSError as e:
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}"
return _format_write_file_error(requested_path, e, runtime)
except Exception as e:
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
return _format_write_file_error(requested_path, e, runtime)
async def _write_file_tool_async(
@@ -7,6 +7,7 @@ from dataclasses import replace
from typing import TYPE_CHECKING, Annotated, Any, cast
from langchain.tools import InjectedToolCallId, tool
from langchain_core.callbacks import BaseCallbackManager
from langgraph.config import get_stream_writer
from deerflow.config import get_app_config
@@ -99,15 +100,31 @@ def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls:
def _find_usage_recorder(runtime: Any) -> Any | None:
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config.
LangChain may pass ``config["callbacks"]`` in three different shapes:
- ``None`` (no callbacks registered): no recorder.
- A plain ``list[BaseCallbackHandler]``: iterate it directly.
- A ``BaseCallbackManager`` instance (e.g. ``AsyncCallbackManager`` on async
tool runs): managers are not iterable, so we unwrap ``.handlers`` first.
Any other shape (e.g. a single handler object accidentally passed without a
list wrapper) cannot be iterated safely; treat it as "no recorder" rather
than raise.
"""
if runtime is None:
return None
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return None
callbacks = config.get("callbacks", [])
callbacks = config.get("callbacks")
if isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.handlers
if not callbacks:
return None
if not isinstance(callbacks, list):
return None
for cb in callbacks:
if hasattr(cb, "record_external_llm_usage_records"):
return cb
@@ -0,0 +1,206 @@
"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent.
What it proves
--------------
- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full
18-middleware chain, sandbox, tools, etc.).
- A model that returns ``finish_reason='content_filter'`` + ``tool_calls``
triggers SafetyFinishReasonMiddleware.
- LangChain's tool router never invokes ``write_file`` — the truncated
arguments do **not** reach the sandbox.
- A ``safety_termination`` custom event is emitted on the stream and the
final AIMessage carries the observability stamp.
Run from backend/ directory:
PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py
"""
from __future__ import annotations
import sys
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
# ---------------------------------------------------------------------------
# Fake provider that mimics Moonshot's content_filter behaviour
# ---------------------------------------------------------------------------
class _ContentFilteredFakeModel(BaseChatModel):
"""First call returns finish_reason=content_filter + truncated write_file
tool_call. Subsequent calls return a normal stop response so the agent
can terminate (the middleware should make a second call unnecessary by
clearing tool_calls, but we keep this safety net in case loop-detection
or anything else triggers another model invocation)."""
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-content-filtered"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
msg = AIMessage(
content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
tool_calls=[
{
"id": "call_truncated_write",
"name": "write_file",
"args": {
"path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
"content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
},
}
],
response_metadata={
"finish_reason": "content_filter",
"model_name": "kimi-k2.6",
"model_provider": "openai",
},
)
else:
msg = AIMessage(
content="(secondary call, should not be needed)",
response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"},
)
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------
def main() -> int:
# Inject the fake model BEFORE constructing the client. Both the
# client module and the lead-agent module bind ``create_chat_model``
# at import time via ``from deerflow.models import create_chat_model``,
# so we patch both attribute slots — the source-of-truth patch on
# ``factory.create_chat_model`` doesn't propagate back into already-
# imported names.
import deerflow.agents.lead_agent.agent as lead_agent_module
import deerflow.client as client_module
fake = _ContentFilteredFakeModel()
originals = {
"lead": lead_agent_module.create_chat_model,
"client": client_module.create_chat_model,
}
def fake_create_chat_model(*args, **kwargs):
return fake
lead_agent_module.create_chat_model = fake_create_chat_model
client_module.create_chat_model = fake_create_chat_model
from deerflow.client import DeerFlowClient
try:
client = DeerFlowClient()
print("\n=== Streaming a turn through the real lead-agent ===")
events: list[dict[str, Any]] = []
for event in client.stream(
"帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
thread_id="e2e-safety-1",
):
events.append({"type": event.type, "data": event.data})
# ---- Assertions ----
safety_event = next(
(e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"),
None,
)
final_values = next(
(e for e in reversed(events) if e["type"] == "values"),
None,
)
tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"]
ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")]
print(f"\n[stats] total stream events: {len(events)}")
print(f"[stats] model call count: {fake.call_count}")
print(f"[stats] tool messages on stream: {len(tool_messages)}")
print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}")
print("\n[event] safety_termination custom event:")
if safety_event is None:
print(" *** NOT FOUND ***")
return 1
for k, v in safety_event["data"].items():
print(f" {k}: {v}")
print("\n[state] final AIMessage from last values snapshot:")
if final_values is None:
print(" *** no values snapshot ***")
return 1
# `values` event carries `_serialize_message` dicts, not Message objects.
final_messages = final_values["data"].get("messages") or []
last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None)
if last_ai is None:
print(" *** no AIMessage in final state ***")
print(f" message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}")
return 1
tool_calls = last_ai.get("tool_calls") or []
additional_kwargs = last_ai.get("additional_kwargs") or {}
response_metadata = last_ai.get("response_metadata") or {}
content = last_ai.get("content")
print(f" tool_calls (must be empty): {tool_calls}")
print(f" additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}")
content_preview = (content if isinstance(content, str) else str(content))[:200]
print(f" content[:200]: {content_preview!r}")
print(f" response_metadata.finish_reason: {response_metadata.get('finish_reason')}")
# NOTE: `client._serialize_message` does not include `response_metadata`
# in the values-event payload (client-layer behaviour, unrelated to the
# middleware). The middleware *does* preserve finish_reason on the
# AIMessage object — see test_safety_finish_reason_middleware.py::
# TestMessageRewrite::test_preserves_response_metadata_finish_reason.
# Here we assert on the observability stamp, which carries the same
# evidence and is in the serialized payload.
stamp = additional_kwargs.get("safety_termination") or {}
failures = []
if tool_calls:
failures.append("final AIMessage still has tool_calls — middleware did NOT clear them")
if not stamp:
failures.append("final AIMessage missing safety_termination observability stamp")
if tool_messages:
failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream")
if stamp.get("reason_value") != "content_filter":
failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'")
if safety_event is None:
failures.append("safety_termination custom event was not emitted on the stream")
if failures:
print("\n=== FAIL ===")
for f in failures:
print(f" - {f}")
return 1
print("\n=== PASS ===")
print(" - tool_calls cleared on final AIMessage")
print(" - tool node never invoked (no ToolMessage on stream)")
print(" - safety_termination custom event emitted")
print(" - observability stamp written to additional_kwargs")
print(" - response_metadata.finish_reason preserved for downstream SSE")
return 0
finally:
lead_agent_module.create_chat_model = originals["lead"]
client_module.create_chat_model = originals["client"]
if __name__ == "__main__":
sys.exit(main())
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_tool_msg("web_search:11", "web_search"),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "web_search:11"
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
result = _tool_msg("web_search:11", "web_search")
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
result,
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1] is result
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [
@@ -0,0 +1,189 @@
"""Regression tests for gateway config freshness on the request hot path.
Bytedance/deer-flow issue #3107 BUG-001: the worker and lead-agent path
captured ``app.state.config`` at gateway startup. ``config.yaml`` edits during
runtime were therefore ignored — ``get_app_config()``'s mtime-based reload
existed but was bypassed because the snapshot object was passed through
explicitly.
These tests pin the desired behaviour: a request-time ``get_config`` call must
observe the most recent on-disk ``config.yaml`` (mtime reload), and the
runtime ``ContextVar`` override must keep working for per-request injection.
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway import deps as gateway_deps
from app.gateway.deps import get_config
from deerflow.config.app_config import (
AppConfig,
pop_current_app_config,
push_current_app_config,
reset_app_config,
set_app_config,
)
from deerflow.config.sandbox_config import SandboxConfig
@pytest.fixture(autouse=True)
def _isolate_app_config_singleton():
"""Ensure each test starts with a clean module-level cache."""
reset_app_config()
yield
reset_app_config()
def _write_config_yaml(path: Path, *, log_level: str) -> None:
path.write_text(
f"""
sandbox:
use: deerflow.sandbox.local.provider:LocalSandboxProvider
log_level: {log_level}
""".strip()
+ "\n",
encoding="utf-8",
)
def _build_app() -> FastAPI:
app = FastAPI()
@app.get("/probe")
def probe(cfg: AppConfig = Depends(get_config)):
return {"log_level": cfg.log_level}
return app
def test_get_config_reflects_file_mtime_reload(tmp_path, monkeypatch):
"""Editing config.yaml at runtime must be visible to /probe without restart.
This is the literal repro for the issue: the gateway must not freeze the
config to whatever was on disk when the process started.
"""
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "info"}
# Edit the file and bump its mtime — simulating a maintainer changing
# max_tokens / model settings in production while the gateway is live.
_write_config_yaml(config_file, log_level="debug")
future_mtime = config_file.stat().st_mtime + 5
os.utime(config_file, (future_mtime, future_mtime))
assert client.get("/probe").json() == {"log_level": "debug"}
def test_get_config_respects_runtime_context_override(tmp_path, monkeypatch):
"""Per-request ``push_current_app_config`` injection must still win."""
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
override = AppConfig(sandbox=SandboxConfig(use="test"), log_level="trace")
push_current_app_config(override)
try:
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "trace"}
finally:
pop_current_app_config()
def test_get_config_respects_test_set_app_config():
"""``set_app_config`` (used by upload/skills router tests) keeps working."""
injected = AppConfig(sandbox=SandboxConfig(use="test"), log_level="warning")
set_app_config(injected)
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "warning"}
def test_run_context_app_config_reflects_yaml_edit(tmp_path, monkeypatch):
"""``RunContext.app_config`` must follow live `config.yaml` edits.
BUG-001 review feedback: the run-context that feeds worker / lead-agent
factories must observe the same mtime reload that `get_config()` does;
otherwise stale config slips back in through the run path even after the
request dependency is fixed.
"""
from unittest.mock import MagicMock
from app.gateway.deps import get_run_context
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
app = FastAPI()
# Sentinel values for the rest of the RunContext wiring — we only care
# about ``ctx.app_config`` for this assertion.
app.state.checkpointer = MagicMock()
app.state.store = MagicMock()
app.state.run_event_store = MagicMock()
app.state.run_events_config = {"frozen": "startup"}
app.state.thread_store = MagicMock()
@app.get("/run-ctx-log-level")
def probe(ctx=Depends(get_run_context)):
return {
"log_level": ctx.app_config.log_level,
"run_events_config": ctx.run_events_config,
}
client = TestClient(app)
first = client.get("/run-ctx-log-level").json()
assert first == {"log_level": "info", "run_events_config": {"frozen": "startup"}}
_write_config_yaml(config_file, log_level="debug")
future_mtime = config_file.stat().st_mtime + 5
os.utime(config_file, (future_mtime, future_mtime))
second = client.get("/run-ctx-log-level").json()
# app_config follows the edit; run_events_config stays frozen to the
# startup snapshot we wrote onto app.state above.
assert second == {"log_level": "debug", "run_events_config": {"frozen": "startup"}}
@pytest.mark.parametrize(
"exception",
[
FileNotFoundError("config.yaml not found"),
PermissionError("config.yaml not readable"),
ValueError("invalid config"),
RuntimeError("yaml parse error"),
],
)
def test_get_config_returns_503_on_any_load_failure(monkeypatch, exception):
"""Any failure to materialise the config must surface as 503, not 500.
Bytedance/deer-flow issue #3107 BUG-001 review: the original snapshot
contract returned 503 when ``app.state.config is None``. The first cut of
this fix only mapped ``FileNotFoundError`` to 503, which left
``PermissionError`` / ``yaml.YAMLError`` / ``ValidationError`` etc. bubbling
up as 500. Catch every load failure at the request boundary.
"""
def _broken_get_app_config():
raise exception
monkeypatch.setattr(gateway_deps, "get_app_config", _broken_get_app_config)
app = _build_app()
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/probe")
assert response.status_code == 503
assert response.json() == {"detail": "Configuration not available"}
-41
View File
@@ -1,41 +0,0 @@
from __future__ import annotations
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def test_get_config_returns_app_state_config():
"""get_config should return the exact AppConfig stored on app.state."""
app = FastAPI()
config = AppConfig(sandbox=SandboxConfig(use="test"))
app.state.config = config
@app.get("/probe")
def probe(cfg: AppConfig = Depends(get_config)):
return {"same_identity": cfg is config, "log_level": cfg.log_level}
client = TestClient(app)
response = client.get("/probe")
assert response.status_code == 200
assert response.json() == {"same_identity": True, "log_level": "info"}
def test_get_config_reads_updated_app_state():
"""Swapping app.state.config should be visible to the dependency."""
app = FastAPI()
app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
@app.get("/log-level")
def log_level(cfg: AppConfig = Depends(get_config)):
return {"level": cfg.log_level}
client = TestClient(app)
assert client.get("/log-level").json() == {"level": "info"}
app.state.config = app.state.config.model_copy(update={"log_level": "debug"})
assert client.get("/log-level").json() == {"level": "debug"}
@@ -17,7 +17,7 @@ from fastapi import FastAPI
@asynccontextmanager
async def _noop_langgraph_runtime(_app):
async def _noop_langgraph_runtime(_app, _startup_config):
yield
+88
View File
@@ -81,6 +81,94 @@ def test_normalize_input_passthrough():
assert result == {"custom_key": "value"}
def test_normalize_input_preserves_additional_kwargs_and_id():
"""Regression: gh #3132 — frontend ships uploaded-file metadata in
additional_kwargs.files (and a client-side message id). The gateway must
not strip them before the graph runs, otherwise UploadsMiddleware reports
"(empty)" for new uploads and the frontend message loses its file chip.
"""
from langchain_core.messages import HumanMessage
from app.gateway.services import normalize_input
files = [{"filename": "a.csv", "size": 100, "path": "/mnt/user-data/uploads/a.csv", "status": "uploaded"}]
result = normalize_input(
{
"messages": [
{
"type": "human",
"id": "client-msg-1",
"name": "user-input",
"content": [{"type": "text", "text": "clean it"}],
"additional_kwargs": {"files": files, "custom": "keep-me"},
}
]
}
)
assert len(result["messages"]) == 1
msg = result["messages"][0]
assert isinstance(msg, HumanMessage)
assert msg.id == "client-msg-1"
assert msg.name == "user-input"
assert msg.content == [{"type": "text", "text": "clean it"}]
assert msg.additional_kwargs == {"files": files, "custom": "keep-me"}
def test_normalize_input_passes_through_basemessage_instances():
from langchain_core.messages import HumanMessage
from app.gateway.services import normalize_input
msg = HumanMessage(content="hello", id="m-1", additional_kwargs={"files": [{"filename": "x"}]})
result = normalize_input({"messages": [msg]})
assert result["messages"][0] is msg
def test_normalize_input_rejects_malformed_message_with_400():
"""Boundary validation: ``convert_to_messages`` raises ``ValueError`` when a
message dict is missing ``role``/``type``/``content``. ``normalize_input``
runs inside the gateway HTTP boundary, so a malformed payload should surface
as a 400 referencing the offending entry — not bubble up as a 500.
Raised after the Copilot review on PR #3136.
"""
import pytest
from fastapi import HTTPException
from app.gateway.services import normalize_input
with pytest.raises(HTTPException) as excinfo:
normalize_input({"messages": [{"role": "human", "content": "ok"}, {"oops": "no role here"}]})
assert excinfo.value.status_code == 400
assert "input.messages[1]" in excinfo.value.detail
def test_normalize_input_handles_non_human_roles():
"""The previous implementation collapsed every role to HumanMessage with a
`# TODO: handle other message types` comment. Resuming a thread with prior
AI/tool messages would silently rewrite them as human turns — corrupting
the conversation. Use langchain's standard conversion so ai/system/tool
roles round-trip correctly.
"""
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from app.gateway.services import normalize_input
result = normalize_input(
{
"messages": [
{"role": "system", "content": "sys"},
{"role": "ai", "content": "hi", "id": "ai-1"},
{"role": "tool", "content": "result", "tool_call_id": "call-1"},
]
}
)
types = [type(m) for m in result["messages"]]
assert types == [SystemMessage, AIMessage, ToolMessage]
assert result["messages"][1].id == "ai-1"
assert result["messages"][2].tool_call_id == "call-1"
def test_build_run_config_basic():
from app.gateway.services import build_run_config
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
# verify the custom middleware is injected correctly.
# Chain tail order after the custom middleware is:
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
# so the custom mock sits at index [-3].
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
+409
View File
@@ -0,0 +1,409 @@
"""Tests for the MCP persistent-session pool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
@pytest.fixture(autouse=True)
def _reset_pool():
reset_session_pool()
yield
reset_session_pool()
# ---------------------------------------------------------------------------
# MCPSessionPool unit tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_session_creates_new():
"""First call for a key creates a new session."""
pool = MCPSessionPool()
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
assert session is mock_session
mock_session.initialize.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_session_reuses_existing():
"""Second call for the same key returns the cached session."""
pool = MCPSessionPool()
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
assert s1 is s2
# Only one session should have been created.
assert mock_cm.__aenter__.await_count == 1
@pytest.mark.asyncio
async def test_different_scope_creates_different_session():
"""Different scope keys get different sessions."""
pool = MCPSessionPool()
sessions = [AsyncMock(), AsyncMock()]
idx = 0
class CmFactory:
def __init__(self):
self.enter_count = 0
async def __aenter__(self):
nonlocal idx
s = sessions[idx]
idx += 1
self.enter_count += 1
return s
async def __aexit__(self, *args):
return False
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()):
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []})
assert s1 is not s2
assert s1 is sessions[0]
assert s2 is sessions[1]
@pytest.mark.asyncio
async def test_lru_eviction():
"""Oldest entries are evicted when the pool is full."""
pool = MCPSessionPool()
pool.MAX_SESSIONS = 2
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
# Pool is full (2). Adding t3 should evict t1.
await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []})
assert cms[0].closed is True
assert cms[1].closed is False
assert cms[2].closed is False
@pytest.mark.asyncio
async def test_close_scope():
"""close_scope shuts down sessions for a specific scope key."""
pool = MCPSessionPool()
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
await pool.close_scope("t1")
assert cms[0].closed is True
assert cms[1].closed is False
# t2 session still exists.
assert ("s", "t2") in pool._entries
@pytest.mark.asyncio
async def test_close_all():
"""close_all shuts down every session."""
pool = MCPSessionPool()
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []})
await pool.close_all()
assert all(cm.closed for cm in cms)
assert len(pool._entries) == 0
# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------
def test_get_session_pool_singleton():
"""get_session_pool returns the same instance."""
p1 = get_session_pool()
p2 = get_session_pool()
assert p1 is p2
def test_reset_session_pool():
"""reset_session_pool clears the singleton."""
p1 = get_session_pool()
reset_session_pool()
p2 = get_session_pool()
assert p1 is not p2
# ---------------------------------------------------------------------------
# Integration: _make_session_pool_tool uses the pool
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_pool_tool_wrapping():
"""The wrapper tool delegates to a pool-managed session."""
# Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
url: str = Field(..., description="url")
original_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
connection = {"transport": "stdio", "command": "pw", "args": []}
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
# Simulate a tool call with a runtime context containing thread_id.
mock_runtime = MagicMock()
mock_runtime.context = {"thread_id": "thread-42"}
mock_runtime.config = {}
await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
"""Thread ID is extracted from runtime.config when not in context."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
mock_runtime = MagicMock()
mock_runtime.context = {}
mock_runtime.config = {"configurable": {"thread_id": "from-config"}}
await wrapped.coroutine(runtime=mock_runtime, x=1)
# Verify the session was created with the correct scope key.
pool = get_session_pool()
assert ("server", "from-config") in pool._entries
@pytest.mark.asyncio
async def test_session_pool_tool_default_scope():
"""When no thread_id is available, 'default' is used as scope key."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
# No thread_id in runtime at all.
await wrapped.coroutine(runtime=None, x=1)
pool = get_session_pool()
assert ("server", "default") in pool._entries
@pytest.mark.asyncio
async def test_session_pool_tool_get_config_fallback():
"""When runtime is None, get_config() provides thread_id as fallback."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
fake_config = {"configurable": {"thread_id": "from-langgraph-config"}}
with (
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
patch("deerflow.mcp.tools.get_config", return_value=fake_config),
):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
# runtime=None — get_config() fallback should provide thread_id
await wrapped.coroutine(runtime=None, x=1)
pool = get_session_pool()
assert ("server", "from-langgraph-config") in pool._entries
def test_session_pool_tool_sync_wrapper_path_is_safe():
"""Sync wrapper (tool.func) invocation doesn't crash on cross-loop access."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
from deerflow.tools.sync import make_sync_tool_wrapper
class Args(BaseModel):
url: str = Field(..., description="url")
original_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
connection = {"transport": "stdio", "command": "pw", "args": []}
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
# Attach the sync wrapper exactly as get_mcp_tools() does.
wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name)
# Call via the sync path (asyncio.run in a worker thread).
# runtime is not supplied so _extract_thread_id falls back to "default".
wrapped.func(url="https://example.com")
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
+104
View File
@@ -714,6 +714,110 @@ class TestExternalUsageRecords:
assert j._subagent_tokens == 0
class TestProgressSnapshots:
@pytest.mark.anyio
async def test_on_llm_end_reports_progress_snapshot(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0,
)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
await j.flush()
assert snapshots
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["message_count"] == 1
assert snapshots[-1]["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
snapshots: list[dict] = []
trailing_seen = asyncio.Event()
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
if snapshot["total_tokens"] == 45:
trailing_seen.set()
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0.01,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
await j.flush()
assert len(snapshots) >= 2
assert snapshots[-1]["total_tokens"] == 45
assert snapshots[-1]["llm_call_count"] == 2
assert snapshots[-1]["last_ai_message"] == "Second"
@pytest.mark.anyio
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=10.0,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.sleep(0)
assert snapshots[-1]["total_tokens"] == 15
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(j.flush(), timeout=0.2)
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["last_ai_message"] == "First"
class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message."""
+122
View File
@@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository
from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.base import RunStore
async def _make_repo(tmp_path):
@@ -26,6 +27,42 @@ async def _cleanup():
await close_engine()
class _CustomRunStoreWithoutProgress(RunStore):
async def put(self, *args, **kwargs):
return None
async def get(self, *args, **kwargs):
return None
async def list_by_thread(self, *args, **kwargs):
return []
async def update_status(self, *args, **kwargs):
return None
async def delete(self, *args, **kwargs):
return None
async def update_model_name(self, *args, **kwargs):
return None
async def update_run_completion(self, *args, **kwargs):
return None
async def list_pending(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
store = _CustomRunStoreWithoutProgress()
await store.update_run_progress("r1", total_tokens=1)
class TestRunRepository:
@pytest.mark.anyio
async def test_put_and_get(self, tmp_path):
@@ -170,6 +207,69 @@ class TestRunRepository:
assert row["total_tokens"] == 100
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
message_count=2,
last_ai_message="partial answer",
)
row = await repo.get("r1")
assert row["status"] == "running"
assert row["total_tokens"] == 50
assert row["llm_call_count"] == 1
assert row["message_count"] == 2
assert row["last_ai_message"] == "partial answer"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
lead_agent_tokens=30,
subagent_tokens=20,
message_count=2,
)
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
row = await repo.get("r1")
assert row["total_input_tokens"] == 40
assert row["total_output_tokens"] == 10
assert row["total_tokens"] == 60
assert row["llm_call_count"] == 1
assert row["lead_agent_tokens"] == 30
assert row["subagent_tokens"] == 20
assert row["message_count"] == 2
assert row["last_ai_message"] == "updated"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
row = await repo.get("r1")
assert row["status"] == "success"
assert row["total_tokens"] == 100
assert row["llm_call_count"] == 1
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -225,6 +325,28 @@ class TestRunRepository:
}
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
without_active = await repo.aggregate_tokens_by_thread("t1")
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
assert without_active["total_tokens"] == 100
assert without_active["total_runs"] == 1
assert with_active["total_tokens"] == 125
assert with_active["total_runs"] == 2
assert with_active["by_caller"] == {
"lead_agent": 120,
"subagent": 5,
"middleware": 0,
}
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first."""
@@ -0,0 +1,225 @@
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
Unit tests prove ``_apply`` does the right thing on a synthetic state.
This test does one level up: builds a real ``langchain.agents.create_agent``
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
1. The tool node is **not** invoked (the dangerous truncated tool call
is suppressed).
2. The final AIMessage in graph state has ``tool_calls == []``.
3. The observability ``safety_termination`` record is attached.
4. The user-facing explanation is appended to the message content.
This is the closest we can get to the issue's failure mode without a live
Moonshot key, and it proves the middleware actually gates LangChain's
tool router — not just rewrites state in isolation.
"""
from __future__ import annotations
from typing import Any
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import tool
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
@tool
def write_file(path: str, content: str) -> str:
"""Pretend to write *content* to *path*. Records the call for assertion."""
_TOOL_INVOCATIONS.append({"path": path, "content": content})
return f"wrote {len(content)} bytes to {path}"
class _ContentFilteredModel(BaseChatModel):
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
First call returns finish_reason='content_filter' + a tool_call whose
arguments are visibly truncated. Second call (if reached) returns a
normal text completion so the agent can terminate cleanly.
"""
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-content-filtered"
def bind_tools(self, tools, **kwargs):
# create_agent binds tools onto the model; we don't actually need
# to bind anything since responses are hard-coded, but the method
# must not raise.
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
message = AIMessage(
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
tool_calls=[
{
"id": "call_truncated_1",
"name": "write_file",
"args": {
"path": "/mnt/user-data/outputs/report.md",
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
},
}
],
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
)
else:
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
class _InspectMiddleware(AgentMiddleware):
"""Captures the messages list at every model entry so we can assert
no synthetic tool result was injected back into the conversation."""
def __init__(self) -> None:
super().__init__()
self.observed: list[list[Any]] = []
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
self.observed.append(list(request.messages))
return handler(request)
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
_TOOL_INVOCATIONS.clear()
inspector = _InspectMiddleware()
agent = create_agent(
model=_ContentFilteredModel(),
tools=[write_file],
# Inspector first so its after_model is registered; Safety last in
# the list so it executes first under LIFO (matches production wiring).
middleware=[inspector, SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
# Critical assertion: the dangerous truncated tool call must NOT have
# been executed. This is the entire point of the middleware.
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
# Final AIMessage has no tool calls left.
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
assert final_ai.tool_calls == []
# Observability stamp is present.
record = final_ai.additional_kwargs.get("safety_termination")
assert record is not None
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 1
assert record["suppressed_tool_call_names"] == ["write_file"]
# User-facing explanation is appended.
assert "safety-related signal" in final_ai.content
# Original partial text preserved (we don't throw away what the user
# already saw in the stream — see middleware docstring).
assert "Weekly Politics" in final_ai.content
# finish_reason on response_metadata is preserved (so SSE / converters
# downstream still see the real provider reason).
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
def test_content_filter_without_tool_calls_passes_through_unchanged():
"""No tool calls => issue scope says don't intervene; the partial
response should be delivered as-is so the user sees what they got."""
_TOOL_INVOCATIONS.clear()
class _NoToolModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-no-tool"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
msg = AIMessage(
content="Partial answer truncated by safety filter",
response_metadata={"finish_reason": "content_filter"},
)
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_NoToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
# Content untouched.
assert final_ai.content == "Partial answer truncated by safety filter"
# No safety_termination stamp because we didn't intervene.
assert "safety_termination" not in final_ai.additional_kwargs
# tool node never ran (there were no tool calls in the first place).
assert _TOOL_INVOCATIONS == []
def test_normal_tool_call_round_trip_is_not_affected():
"""Regression: a healthy finish_reason='tool_calls' response must still
execute the tool. The middleware must not over-fire."""
_TOOL_INVOCATIONS.clear()
class _HealthyToolModel(BaseChatModel):
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-healthy"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
msg = AIMessage(
content="",
tool_calls=[
{
"id": "call_ok",
"name": "write_file",
"args": {"path": "/tmp/ok", "content": "complete content"},
}
],
response_metadata={"finish_reason": "tool_calls"},
)
else:
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_HealthyToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
agent.invoke({"messages": [HumanMessage(content="write")]})
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
@@ -0,0 +1,651 @@
"""Unit tests for SafetyFinishReasonMiddleware."""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
)
from deerflow.config.safety_finish_reason_config import (
SafetyDetectorConfig,
SafetyFinishReasonConfig,
)
def _runtime(thread_id="t-1"):
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
return runtime
def _ai(
*,
content="",
tool_calls=None,
response_metadata=None,
additional_kwargs=None,
):
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
def _write_call(idx=1, content_text="半截"):
return {
"id": f"call_write_{idx}",
"name": "write_file",
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
}
class AlwaysHitDetector:
"""Test fixture: always reports the given termination."""
name = "always_hit"
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
self.reason_field = reason_field
self.reason_value = reason_value
self.extras = extras or {}
def detect(self, message):
return SafetyTermination(
detector=self.name,
reason_field=self.reason_field,
reason_value=self.reason_value,
extras=self.extras,
)
class NeverHitDetector:
name = "never_hit"
def detect(self, message):
return None
class RaisingDetector:
name = "raising"
def detect(self, message):
raise RuntimeError("boom")
# ---------------------------------------------------------------------------
# Core trigger behaviour
# ---------------------------------------------------------------------------
class TestTriggerCriteria:
def test_content_filter_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
patched = result["messages"][0]
assert patched.tool_calls == []
def test_content_filter_without_tool_calls_passes_through(self):
"""issue scope: when there are no tool calls the partial text is a
legitimate final response and should not be rewritten."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial response",
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_tool_calls_pass_through(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_stop_with_tool_calls_pass_through(self):
# Some providers report finish_reason='stop' for tool-call messages.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "stop"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_empty_message_list_passes_through(self):
mw = SafetyFinishReasonMiddleware()
assert mw._apply({"messages": []}, _runtime()) is None
def test_non_ai_last_message_passes_through(self):
mw = SafetyFinishReasonMiddleware()
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
assert mw._apply(state, _runtime()) is None
def test_anthropic_refusal_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"stop_reason": "refusal"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
def test_gemini_safety_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "SAFETY"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
# ---------------------------------------------------------------------------
# Message rewriting
# ---------------------------------------------------------------------------
class TestMessageRewrite:
def test_clears_structured_tool_calls(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert patched.tool_calls == []
def test_clears_raw_additional_kwargs_tool_calls(self):
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
tool calls from additional_kwargs.tool_calls if we forget them, which
would re-emit a synthetic ToolMessage downstream and confuse the
model. We must wipe both."""
mw = SafetyFinishReasonMiddleware()
raw_tool_calls = [
{
"id": "call_write_1",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
}
]
state = {
"messages": [
_ai(
tool_calls=[_write_call(1)],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"tool_calls": raw_tool_calls,
"function_call": {"name": "write_file", "arguments": "{}"},
},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert "tool_calls" not in patched.additional_kwargs
assert "function_call" not in patched.additional_kwargs
def test_preserves_other_additional_kwargs(self):
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
# may carry other provider-specific keys. They must not be wiped.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"reasoning": "thinking text",
"custom_provider_field": {"x": 1},
},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["reasoning"] == "thinking text"
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
def test_writes_observability_field(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
record = patched.additional_kwargs["safety_termination"]
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 2
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
def test_preserves_response_metadata_finish_reason(self):
"""Downstream SSE converters read response_metadata.finish_reason —
we want them to see the *real* provider reason, not 'stop'."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.response_metadata["finish_reason"] == "content_filter"
assert patched.response_metadata["model_name"] == "kimi-k2"
def test_appends_user_facing_explanation_to_str_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="some partial text",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert patched.content.startswith("some partial text")
assert "safety-related signal" in patched.content
def test_handles_empty_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert "safety-related signal" in patched.content
def test_handles_list_content_thinking_blocks(self):
"""Anthropic thinking / vLLM reasoning models emit content blocks.
Naively concatenating a string would raise TypeError."""
mw = SafetyFinishReasonMiddleware()
thinking_blocks = [
{"type": "thinking", "text": "let me consider..."},
{"type": "text", "text": "partial answer"},
]
state = {
"messages": [
_ai(
content=thinking_blocks,
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, list)
assert patched.content[:2] == thinking_blocks
assert patched.content[-1]["type"] == "text"
assert "safety-related signal" in patched.content[-1]["text"]
def test_idempotent_on_already_cleared_message(self):
# Re-running the middleware on a message we already cleared must not
# re-trigger (tool_calls is now empty → fast passthrough).
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
first = mw._apply(state, _runtime())
state2 = {"messages": [first["messages"][0]]}
second = mw._apply(state2, _runtime())
assert second is None
def test_preserves_message_id_for_add_messages_replacement(self):
"""LangGraph's add_messages reducer treats same-id messages as
replacements. model_copy keeps id by default."""
mw = SafetyFinishReasonMiddleware()
original = _ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
# AIMessage auto-generates id; capture it
original_id = original.id
state = {"messages": [original]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.id == original_id
# ---------------------------------------------------------------------------
# Detector wiring
# ---------------------------------------------------------------------------
class TestDetectorWiring:
def test_iterates_detectors_in_order(self):
first = AlwaysHitDetector(reason_value="first")
second = AlwaysHitDetector(reason_value="second")
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
def test_returns_none_when_no_detector_matches(self):
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_buggy_detector_does_not_break_run(self):
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
def test_constructor_copies_detectors(self):
"""Caller mutation after construction must not leak into us."""
detectors = [AlwaysHitDetector()]
mw = SafetyFinishReasonMiddleware(detectors=detectors)
detectors.clear()
state = {"messages": [_ai(tool_calls=[_write_call()])]}
assert mw._apply(state, _runtime()) is not None
# ---------------------------------------------------------------------------
# from_config
# ---------------------------------------------------------------------------
class TestFromConfig:
def test_default_config_uses_builtin_detectors(self):
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
assert len(mw._detectors) == 3
names = {d.name for d in mw._detectors}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_custom_detectors_loaded_via_reflection(self):
cfg = SafetyFinishReasonConfig(
detectors=[
SafetyDetectorConfig(
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
config={"finish_reasons": ["custom_filter"]},
),
]
)
mw = SafetyFinishReasonMiddleware.from_config(cfg)
assert len(mw._detectors) == 1
# Confirm the kwargs propagated.
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "custom_filter"},
)
]
}
assert mw._apply(state, _runtime()) is not None
# Default token no longer matches.
state2 = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state2, _runtime()) is None
def test_empty_detector_list_rejected(self):
cfg = SafetyFinishReasonConfig(detectors=[])
with pytest.raises(ValueError, match="enabled=false"):
SafetyFinishReasonMiddleware.from_config(cfg)
def test_non_detector_class_rejected(self):
cfg = SafetyFinishReasonConfig(
detectors=[SafetyDetectorConfig(use="builtins:dict")],
)
with pytest.raises(TypeError):
SafetyFinishReasonMiddleware.from_config(cfg)
# ---------------------------------------------------------------------------
# Stream event
# ---------------------------------------------------------------------------
class TestAuditEvent:
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
audit event via RunJournal.record_middleware when the run-scoped journal is
exposed under runtime.context["__run_journal"].
Background: review on PR #3035 — SSE custom event handles live consumers,
but post-run audit needs a row in run_events that can be queried with one
SQL statement (no JOIN against message body).
"""
def _runtime_with_journal(self, journal):
runtime = MagicMock()
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
return runtime
def test_records_audit_event_when_journal_present(self):
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
tc = _write_call(1)
state = {
"messages": [
_ai(
content="partial",
tool_calls=[tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
journal.record_middleware.assert_called_once()
call = journal.record_middleware.call_args
# tag is positional or kwarg depending on call style; we use kwargs.
assert call.kwargs["tag"] == "safety_termination"
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
assert call.kwargs["hook"] == "after_model"
assert call.kwargs["action"] == "suppress_tool_calls"
changes = call.kwargs["changes"]
assert changes["detector"] == "openai_compatible_content_filter"
assert changes["reason_field"] == "finish_reason"
assert changes["reason_value"] == "content_filter"
assert changes["suppressed_tool_call_count"] == 1
assert changes["suppressed_tool_call_names"] == ["write_file"]
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
assert "message_id" in changes
assert isinstance(changes["extras"], dict)
def test_audit_event_never_carries_tool_arguments(self):
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
and must NOT be persisted to run_events under any circumstance."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
sensitive_tc = {
"id": "call_x",
"name": "write_file",
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
}
state = {
"messages": [
_ai(
tool_calls=[sensitive_tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, self._runtime_with_journal(journal))
flat = repr(journal.record_middleware.call_args)
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
def test_no_journal_in_runtime_context_is_silently_skipped(self):
"""Subagent runtime / unit tests / no-event-store paths have no journal.
Middleware must still intervene and clear tool_calls — only the audit
event is skipped."""
mw = SafetyFinishReasonMiddleware()
runtime = MagicMock()
runtime.context = {"thread_id": "t-noj"} # no __run_journal
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise; should still clear tool_calls.
result = mw._apply(state, runtime)
assert result is not None
assert result["messages"][0].tool_calls == []
def test_journal_record_exception_does_not_break_run(self):
"""Buggy journal must never propagate an exception into the agent loop."""
journal = MagicMock()
journal.record_middleware.side_effect = RuntimeError("db down")
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Must not raise.
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
assert result["messages"][0].tool_calls == []
def test_no_record_when_passthrough(self):
"""When the middleware does NOT intervene, no audit event is written."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"}, # healthy
)
]
}
assert mw._apply(state, self._runtime_with_journal(journal)) is None
journal.record_middleware.assert_not_called()
class TestStreamEvent:
def test_emits_event_when_writer_available(self, monkeypatch):
captured: list = []
def fake_writer(payload):
captured.append(payload)
# Patch get_stream_writer at the symbol-resolution site.
import langgraph.config
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, _runtime("t-stream"))
assert len(captured) == 1
payload = captured[0]
assert payload["type"] == "safety_termination"
assert payload["detector"] == "openai_compatible_content_filter"
assert payload["reason_field"] == "finish_reason"
assert payload["reason_value"] == "content_filter"
assert payload["suppressed_tool_call_count"] == 1
assert payload["suppressed_tool_call_names"] == ["write_file"]
assert payload["thread_id"] == "t-stream"
def test_writer_unavailable_does_not_break(self, monkeypatch):
import langgraph.config
def boom():
raise LookupError("not in a stream context")
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise.
result = mw._apply(state, _runtime())
assert result is not None
@@ -0,0 +1,176 @@
"""Unit tests for SafetyTerminationDetector built-ins."""
from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.safety_termination_detectors import (
AnthropicRefusalDetector,
GeminiSafetyDetector,
OpenAICompatibleContentFilterDetector,
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
class TestOpenAICompatibleContentFilterDetector:
def test_default_matches_content_filter(self):
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
assert hit is not None
assert hit.detector == "openai_compatible_content_filter"
assert hit.reason_field == "finish_reason"
assert hit.reason_value == "content_filter"
def test_case_insensitive_match(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
def test_other_finish_reasons_pass_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
def test_missing_metadata_passes_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai()) is None
def test_non_string_finish_reason_passes_through(self):
# Some adapters may stash an enum or dict — must not raise.
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
def test_falls_back_to_additional_kwargs(self):
# Legacy adapters surface finish_reason via additional_kwargs.
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
assert hit is not None
def test_configurable_extra_values(self):
# Chinese providers sometimes use bespoke tokens.
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
# Original token still matches.
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
def test_carries_azure_content_filter_results(self):
d = OpenAICompatibleContentFilterDetector()
filter_results = {"hate": {"filtered": True, "severity": "high"}}
hit = d.detect(
_ai(
response_metadata={
"finish_reason": "content_filter",
"content_filter_results": filter_results,
},
)
)
assert hit is not None
assert hit.extras["content_filter_results"] == filter_results
class TestAnthropicRefusalDetector:
def test_default_matches_refusal(self):
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
assert hit is not None
assert hit.reason_field == "stop_reason"
assert hit.reason_value == "refusal"
def test_other_stop_reasons_pass_through(self):
d = AnthropicRefusalDetector()
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
def test_anthropic_does_not_steal_finish_reason(self):
# An OpenAI message must not accidentally trip the Anthropic detector.
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
class TestGeminiSafetyDetector:
def test_default_set_covers_documented_reasons(self):
d = GeminiSafetyDetector()
for reason in (
# text safety
"SAFETY",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"RECITATION",
# image safety
"IMAGE_SAFETY",
"IMAGE_PROHIBITED_CONTENT",
"IMAGE_RECITATION",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
def test_normal_termination_passes_through(self):
d = GeminiSafetyDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
# excluded from the default set — they are either normal termination,
# capability mismatches, too broad (OTHER), or tool-call protocol
# errors. See GeminiSafetyDetector docstring.
for reason in (
"MAX_TOKENS",
"LANGUAGE",
"NO_IMAGE",
"OTHER",
"IMAGE_OTHER",
"MALFORMED_FUNCTION_CALL",
"UNEXPECTED_TOOL_CALL",
"FINISH_REASON_UNSPECIFIED",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
def test_carries_safety_ratings(self):
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
hit = GeminiSafetyDetector().detect(
_ai(
response_metadata={
"finish_reason": "SAFETY",
"safety_ratings": ratings,
},
)
)
assert hit is not None
assert hit.extras["safety_ratings"] == ratings
class TestDefaultDetectorSet:
def test_default_set_returns_three_detectors(self):
dets = default_detectors()
names = {d.name for d in dets}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_default_set_returns_fresh_list(self):
# Caller mutation must not affect later calls.
first = default_detectors()
first.clear()
second = default_detectors()
assert len(second) == 3
class TestProtocolConformance:
def test_builtins_satisfy_protocol(self):
for d in default_detectors():
assert isinstance(d, SafetyTerminationDetector)
def test_safety_termination_is_frozen(self):
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
try:
t.detector = "y" # type: ignore[misc]
except Exception:
return
raise AssertionError("SafetyTermination should be frozen")
@@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from deerflow.sandbox.exceptions import SandboxError
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
@@ -1140,6 +1141,170 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
assert sandbox.content == "ALPHA\ntail\n"
def test_write_file_tool_bounds_large_oserror_and_masks_local_paths(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-large-oserror"
def write_file(self, path: str, content: str, append: bool = False) -> None:
host_path = f"{_THREAD_DATA['workspace_path']}/nested/output.txt"
raise OSError(f"write failed at {host_path}\n{'A' * 12000}\nremote tail marker")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: True)
monkeypatch.setattr("deerflow.sandbox.tools.get_thread_data", lambda runtime: _THREAD_DATA)
monkeypatch.setattr("deerflow.sandbox.tools.validate_local_tool_path", lambda path, thread_data: None)
monkeypatch.setattr(
"deerflow.sandbox.tools._resolve_and_validate_user_data_path",
lambda path, thread_data: f"{_THREAD_DATA['workspace_path']}/output.txt",
)
result = write_file_tool.func(
runtime=runtime,
description="写入大文件失败",
path="/mnt/user-data/workspace/output.txt",
content="report body",
)
assert len(result) <= 2000
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "/tmp/deer-flow/threads/t1/user-data/workspace" not in result
assert "/mnt/user-data/workspace/nested/output.txt" in result
assert "remote tail marker" in result
assert "[write_file error truncated:" in result
def test_write_file_tool_preserves_short_oserror_without_truncation(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-short-oserror"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise OSError("disk quota exceeded")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="写入失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert result == "Error: Failed to write file '/mnt/user-data/workspace/output.txt': OSError: disk quota exceeded"
assert "[write_file error truncated:" not in result
def test_write_file_tool_bounds_large_sandbox_error(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-large-sandbox-error"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise SandboxError(f"remote write rejected {'B' * 12000} final detail")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="远端写入失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert len(result) <= 2000
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "SandboxError: remote write rejected" in result
assert "final detail" in result
assert "[write_file error truncated:" in result
@pytest.mark.parametrize(
("raised_error", "expected_fragment"),
[
pytest.param(
PermissionError("permission denied"),
"Error: Permission denied writing to file: /mnt/user-data/workspace/output.txt",
id="permission",
),
pytest.param(
IsADirectoryError("target is a directory"),
"Error: Path is a directory, not a file: /mnt/user-data/workspace/output.txt",
id="directory",
),
pytest.param(
Exception("remote sandbox timeout"),
"Exception: remote sandbox timeout",
id="generic",
),
],
)
def test_write_file_tool_formats_all_other_failure_branches(
monkeypatch,
raised_error: Exception,
expected_fragment: str,
) -> None:
class FailingSandbox:
id = "sandbox-write-other-failure"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise raised_error
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="验证错误分支格式化",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert "/mnt/user-data/workspace/output.txt" in result
assert expected_fragment in result
assert "[write_file error truncated:" not in result
def test_write_file_tool_handles_sandbox_init_failure(monkeypatch) -> None:
"""Regression for #3133 review: SandboxError raised during sandbox
initialization (before the local `requested_path` assignment) must still
surface as a bounded tool error rather than an UnboundLocalError.
"""
def raise_sandbox_error(runtime):
raise SandboxError("sandbox missing")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", raise_sandbox_error)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="sandbox 初始化失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "SandboxError: sandbox missing" in result
assert "[write_file error truncated:" not in result
def test_file_operation_lock_memory_cleanup() -> None:
"""Verify that released locks are eventually cleaned up by WeakValueDictionary.
+3 -1
View File
@@ -7,6 +7,7 @@ from types import SimpleNamespace
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from app.gateway.routers import skills as skills_router
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import Skill
@@ -38,7 +39,8 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
def _make_test_app(config) -> FastAPI:
app = FastAPI()
app.state.config = config
app.state.config = config # kept for any startup-style reads
app.dependency_overrides[get_config] = lambda: config
app.include_router(skills_router.router)
return app
@@ -0,0 +1,91 @@
"""Regression tests for _find_usage_recorder callback shape handling.
Bytedance issue #3107 BUG-002: When LangChain passes ``config["callbacks"]`` as
an ``AsyncCallbackManager`` (instead of a plain list), the previous
``for cb in callbacks`` loop raised ``TypeError: 'AsyncCallbackManager' object
is not iterable``. ToolErrorHandlingMiddleware then converted the entire ``task``
tool call into an error ToolMessage, losing the subagent result.
"""
from types import SimpleNamespace
from langchain_core.callbacks import AsyncCallbackManager, CallbackManager
from deerflow.tools.builtins.task_tool import _find_usage_recorder
class _RecorderHandler:
def record_external_llm_usage_records(self, records):
self.records = records
class _OtherHandler:
pass
def _make_runtime(callbacks):
return SimpleNamespace(config={"callbacks": callbacks})
def test_find_usage_recorder_with_plain_list():
recorder = _RecorderHandler()
runtime = _make_runtime([_OtherHandler(), recorder])
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_with_async_callback_manager():
"""LangChain wraps callbacks in AsyncCallbackManager for async tool runs.
The old implementation raised TypeError here. The recorder lives on
``manager.handlers``; we must look there too.
"""
recorder = _RecorderHandler()
manager = AsyncCallbackManager(handlers=[_OtherHandler(), recorder])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_with_sync_callback_manager():
"""Sync flavor of the same wrapper used by some langchain code paths."""
recorder = _RecorderHandler()
manager = CallbackManager(handlers=[recorder])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_returns_none_when_no_recorder():
manager = AsyncCallbackManager(handlers=[_OtherHandler()])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_handles_empty_manager():
manager = AsyncCallbackManager(handlers=[])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_for_none_runtime():
assert _find_usage_recorder(None) is None
def test_find_usage_recorder_returns_none_when_callbacks_is_none():
runtime = _make_runtime(None)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_for_single_handler_object():
"""A single handler instance (not wrapped in a list or manager) should not crash.
LangChain's contract is that ``config["callbacks"]`` is a list-or-manager,
but we treat any other shape defensively rather than letting a ``for`` loop
blow up at runtime.
"""
runtime = _make_runtime(_RecorderHandler())
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_when_config_not_dict():
"""Defensive: a runtime without a dict-shaped config should not raise."""
runtime = SimpleNamespace(config="not-a-dict")
assert _find_usage_recorder(runtime) is None
+27
View File
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
},
}
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
def test_thread_token_usage_can_include_active_runs():
run_store = MagicMock()
run_store.aggregate_tokens_by_thread = AsyncMock(
return_value={
"total_tokens": 175,
"total_input_tokens": 120,
"total_output_tokens": 55,
"total_runs": 3,
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
"by_caller": {
"lead_agent": 145,
"subagent": 25,
"middleware": 5,
},
},
)
app = _make_app(run_store)
with TestClient(app) as client:
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
assert response.status_code == 200
assert response.json()["total_tokens"] == 175
assert response.json()["total_runs"] == 3
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
assert captured["app_config"] is app_config
assert len(middlewares) == 6
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
# (enabled by default — see SafetyFinishReasonConfig).
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
assert len(middlewares) == 7
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
def test_wrap_tool_call_passthrough_on_success():
+2
View File
@@ -11,6 +11,7 @@ from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi import HTTPException, UploadFile
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from app.gateway.routers import uploads
@@ -687,6 +688,7 @@ def test_upload_limits_endpoint_requires_thread_access():
cfg.uploads = {}
app = make_authed_test_app(owner_check_passes=False)
app.state.config = cfg
app.dependency_overrides[get_config] = lambda: cfg
app.include_router(uploads.router)
with TestClient(app) as client: