mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
Merge branch 'main' into fix-3127
This commit is contained in:
@@ -184,6 +184,18 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
|
||||
|
||||
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart.
|
||||
|
||||
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**:
|
||||
|
||||
| Field | Why a restart is required |
|
||||
|---|---|
|
||||
| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. |
|
||||
| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. |
|
||||
| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. |
|
||||
| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. |
|
||||
| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. |
|
||||
| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. |
|
||||
| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. |
|
||||
|
||||
Configuration priority:
|
||||
1. Explicit `config_path` argument
|
||||
2. `DEER_FLOW_CONFIG_PATH` environment variable
|
||||
|
||||
@@ -161,10 +161,16 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan handler."""
|
||||
|
||||
# Load config and check necessary environment variables at startup
|
||||
# Load config and check necessary environment variables at startup.
|
||||
# `startup_config` is a local snapshot used only for one-shot bootstrap
|
||||
# work (logging level, langgraph_runtime engines, channels). Request-time
|
||||
# config resolution always routes through `get_app_config()` in
|
||||
# `app/gateway/deps.py::get_config()` so `config.yaml` edits become
|
||||
# visible without a process restart. We deliberately do NOT cache this
|
||||
# snapshot on `app.state` to keep that contract enforceable.
|
||||
try:
|
||||
app.state.config = get_app_config()
|
||||
apply_logging_level(app.state.config.log_level)
|
||||
startup_config = get_app_config()
|
||||
apply_logging_level(startup_config.log_level)
|
||||
logger.info("Configuration loaded successfully")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||
@@ -174,7 +180,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
||||
|
||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||
async with langgraph_runtime(app):
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
logger.info("LangGraph runtime initialised")
|
||||
|
||||
# Check admin bootstrap state and migrate orphan threads after admin exists.
|
||||
@@ -185,7 +191,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
try:
|
||||
from app.channels.service import start_channel_service
|
||||
|
||||
channel_service = await start_channel_service(app.state.config)
|
||||
channel_service = await start_channel_service(startup_config)
|
||||
logger.info("Channel service started: %s", channel_service.get_status())
|
||||
except Exception:
|
||||
logger.exception("No IM channels configured or channel service failed to start")
|
||||
|
||||
+69
-17
@@ -3,11 +3,21 @@
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
|
||||
``AppConfig`` is intentionally *not* cached on ``app.state``. Routers and the
|
||||
run path resolve it through :func:`deerflow.config.app_config.get_app_config`,
|
||||
which performs mtime-based hot reload, so edits to ``config.yaml`` take
|
||||
effect on the next request without a process restart. The engines created in
|
||||
:func:`langgraph_runtime` (stream bridge, persistence, checkpointer, store,
|
||||
run-event store) accept a ``startup_config`` snapshot — they are
|
||||
restart-required by design and stay bound to that snapshot to keep the live
|
||||
process consistent with itself.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
@@ -15,12 +25,14 @@ from typing import TYPE_CHECKING, TypeVar, cast
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.app_config import AppConfig, get_app_config
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
@@ -30,21 +42,55 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_config(request: Request) -> AppConfig:
|
||||
"""Return the app-scoped ``AppConfig`` stored on ``app.state``."""
|
||||
config = getattr(request.app.state, "config", None)
|
||||
if config is None:
|
||||
raise HTTPException(status_code=503, detail="Configuration not available")
|
||||
return config
|
||||
def get_config() -> AppConfig:
|
||||
"""Return the freshest ``AppConfig`` for the current request.
|
||||
|
||||
Routes through :func:`deerflow.config.app_config.get_app_config`, which
|
||||
honours runtime ``ContextVar`` overrides and reloads ``config.yaml`` from
|
||||
disk when its mtime changes. ``AppConfig`` is not cached on ``app.state``
|
||||
at all — the only startup-time snapshot lives as a local
|
||||
``startup_config`` variable inside ``lifespan()`` and is passed
|
||||
explicitly into :func:`langgraph_runtime` for the engines that are
|
||||
restart-required by design. Routing every request through
|
||||
:func:`get_app_config` closes the bytedance/deer-flow issue #3107 BUG-001
|
||||
split-brain where the worker / lead-agent thread saw a stale startup
|
||||
snapshot.
|
||||
|
||||
Any failure to materialise the config (missing file, permission denied,
|
||||
YAML parse error, validation error) is reported as 503 — semantically
|
||||
"the gateway cannot serve requests without a usable configuration" — and
|
||||
logged with the original exception so operators have something to debug.
|
||||
"""
|
||||
try:
|
||||
return get_app_config()
|
||||
except Exception as exc: # noqa: BLE001 - request boundary: log and degrade gracefully
|
||||
logger.exception("Failed to load AppConfig at request time")
|
||||
raise HTTPException(status_code=503, detail="Configuration not available") from exc
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
|
||||
``startup_config`` is the ``AppConfig`` snapshot taken once during
|
||||
``lifespan()`` for one-shot infrastructure bootstrap. The engines and
|
||||
stores constructed here (stream bridge, persistence engine, checkpointer,
|
||||
store, run-event store) are restart-required by design — they hold live
|
||||
connections, file handles, or singleton providers — so they bind to this
|
||||
snapshot and survive across `config.yaml` edits. Request-time consumers
|
||||
must still go through :func:`get_config` for any field that should be
|
||||
hot-reloadable. See ``backend/CLAUDE.md`` "Config Hot-Reload Boundary".
|
||||
|
||||
The matching ``run_events_config`` is frozen onto ``app.state`` so
|
||||
:func:`get_run_context` pairs a freshly-loaded ``AppConfig`` with the
|
||||
*startup-time* run-events configuration the underlying ``event_store``
|
||||
was built from — otherwise the runtime could end up combining a live
|
||||
new ``run_events_config`` with an event store still bound to the
|
||||
previous backend.
|
||||
|
||||
Usage in ``app.py``::
|
||||
|
||||
async with langgraph_runtime(app):
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
yield
|
||||
"""
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
@@ -53,9 +99,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
config = getattr(app.state, "config", None)
|
||||
if config is None:
|
||||
raise RuntimeError("langgraph_runtime() requires app.state.config to be initialized")
|
||||
config = startup_config
|
||||
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
|
||||
|
||||
@@ -84,8 +128,12 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
app.state.thread_store = make_thread_store(sf, app.state.store)
|
||||
|
||||
# Run event store (has its own factory with config-driven backend selection)
|
||||
# Run event store. The store and the matching ``run_events_config`` are
|
||||
# both frozen at startup so ``get_run_context`` does not combine a
|
||||
# freshly-reloaded ``AppConfig.run_events`` with a store still bound to
|
||||
# the previous backend.
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
app.state.run_events_config = run_events_config
|
||||
app.state.run_event_store = make_run_event_store(run_events_config)
|
||||
|
||||
# RunManager with store backing for persistence
|
||||
@@ -139,16 +187,20 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||
|
||||
Returns a *base* context with infrastructure dependencies.
|
||||
Returns a *base* context with infrastructure dependencies. The
|
||||
``app_config`` field is resolved live so per-run fields (e.g.
|
||||
``models[*].max_tokens``) follow ``config.yaml`` edits; the
|
||||
``event_store`` / ``run_events_config`` pair stays frozen to the snapshot
|
||||
captured in :func:`langgraph_runtime` so callers never see a store bound
|
||||
to one backend paired with a config pointing at another.
|
||||
"""
|
||||
config = get_config(request)
|
||||
return RunContext(
|
||||
checkpointer=get_checkpointer(request),
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(config, "run_events", None),
|
||||
run_events_config=getattr(request.app.state, "run_events_config", None),
|
||||
thread_store=get_thread_store(request),
|
||||
app_config=config,
|
||||
app_config=get_config(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -66,6 +66,14 @@ class RunResponse(BaseModel):
|
||||
multitask_strategy: str = "reject"
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
llm_call_count: int = 0
|
||||
lead_agent_tokens: int = 0
|
||||
subagent_tokens: int = 0
|
||||
middleware_tokens: int = 0
|
||||
message_count: int = 0
|
||||
|
||||
|
||||
class ThreadTokenUsageModelBreakdown(BaseModel):
|
||||
@@ -111,6 +119,14 @@ def _record_to_response(record: RunRecord) -> RunResponse:
|
||||
multitask_strategy=record.multitask_strategy,
|
||||
created_at=record.created_at,
|
||||
updated_at=record.updated_at,
|
||||
total_input_tokens=record.total_input_tokens,
|
||||
total_output_tokens=record.total_output_tokens,
|
||||
total_tokens=record.total_tokens,
|
||||
llm_call_count=record.llm_call_count,
|
||||
lead_agent_tokens=record.lead_agent_tokens,
|
||||
subagent_tokens=record.subagent_tokens,
|
||||
middleware_tokens=record.middleware_tokens,
|
||||
message_count=record.message_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -402,8 +418,15 @@ async def list_run_events(
|
||||
|
||||
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
|
||||
async def thread_token_usage(
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
|
||||
) -> ThreadTokenUsageResponse:
|
||||
"""Thread-level token usage aggregation."""
|
||||
run_store = get_run_store(request)
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||
if include_active:
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
|
||||
else:
|
||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
|
||||
|
||||
@@ -15,7 +15,8 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_to_messages
|
||||
|
||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
@@ -76,21 +77,35 @@ def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
|
||||
|
||||
|
||||
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Convert LangGraph Platform input format to LangChain state dict."""
|
||||
"""Convert LangGraph Platform input format to LangChain state dict.
|
||||
|
||||
Delegates dict→message coercion to ``langchain_core.messages.utils.convert_to_messages``
|
||||
so that ``additional_kwargs`` (e.g. uploaded-file metadata — gh #3132), ``id``,
|
||||
``name``, and non-human roles (ai/system/tool) survive unchanged. An earlier
|
||||
hand-rolled version only forwarded ``content`` and collapsed every role to
|
||||
``HumanMessage``, which silently stripped frontend-supplied attachments.
|
||||
|
||||
Malformed message dicts (missing ``role``/``type``/``content``, unsupported
|
||||
role, etc.) raise ``HTTPException(400)`` with the offending index, instead
|
||||
of bubbling up as a 500. The gateway is a system boundary, so per-entry
|
||||
validation errors are the right shape for clients to retry against.
|
||||
"""
|
||||
if raw_input is None:
|
||||
return {}
|
||||
messages = raw_input.get("messages")
|
||||
if messages and isinstance(messages, list):
|
||||
converted = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", msg.get("type", "user"))
|
||||
content = msg.get("content", "")
|
||||
if role in ("user", "human"):
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
# TODO: handle other message types (system, ai, tool)
|
||||
converted.append(HumanMessage(content=content))
|
||||
converted: list[Any] = []
|
||||
for index, msg in enumerate(messages):
|
||||
if isinstance(msg, BaseMessage):
|
||||
converted.append(msg)
|
||||
elif isinstance(msg, dict):
|
||||
try:
|
||||
converted.extend(convert_to_messages([msg]))
|
||||
except (ValueError, TypeError, NotImplementedError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid message at input.messages[{index}]: {exc}",
|
||||
) from exc
|
||||
else:
|
||||
converted.append(msg)
|
||||
return {**raw_input, "messages": converted}
|
||||
|
||||
@@ -241,13 +241,6 @@ GET /api/mcp/config
|
||||
"GITHUB_TOKEN": "***"
|
||||
},
|
||||
"description": "GitHub operations"
|
||||
},
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
"description": "File system access"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,19 @@ DeerFlow supports configurable MCP servers and skills to extend its capabilities
|
||||
3. Configure each server’s command, arguments, and environment variables as needed.
|
||||
4. Restart the application to load and register MCP tools.
|
||||
|
||||
## Filesystem MCP Servers
|
||||
|
||||
DeerFlow already provides built-in file tools for thread-scoped workspace access.
|
||||
Do not add an MCP filesystem server for the same DeerFlow workspace. The
|
||||
overlapping file tools use different path semantics, which can make LLM tool
|
||||
selection and file access behavior unstable.
|
||||
|
||||
DeerFlow does not currently adapt the MCP Roots mode for filesystem servers. In
|
||||
particular, it does not publish per-thread MCP roots or map DeerFlow sandbox
|
||||
paths such as `/mnt/user-data/...` to paths accepted by
|
||||
`@modelcontextprotocol/server-filesystem`. Use DeerFlow's built-in file tools
|
||||
for DeerFlow workspace files.
|
||||
|
||||
## OAuth Support (HTTP/SSE MCP Servers)
|
||||
|
||||
For `http` and `sse` MCP servers, DeerFlow supports OAuth token acquisition and automatic token refresh.
|
||||
@@ -88,7 +101,6 @@ MCP servers expose tools that are automatically discovered and integrated into D
|
||||
|
||||
MCP servers can provide access to:
|
||||
|
||||
- **File systems**
|
||||
- **Databases** (e.g., PostgreSQL)
|
||||
- **External APIs** (e.g., GitHub, Brave Search)
|
||||
- **Browser automation** (e.g., Puppeteer)
|
||||
@@ -97,4 +109,4 @@ MCP servers can provide access to:
|
||||
## Learn More
|
||||
|
||||
For detailed documentation about the Model Context Protocol, visit:
|
||||
https://modelcontextprotocol.io
|
||||
https://modelcontextprotocol.io
|
||||
|
||||
@@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
@@ -338,6 +339,15 @@ def _build_middlewares(
|
||||
if custom_middlewares:
|
||||
middlewares.extend(custom_middlewares)
|
||||
|
||||
# SafetyFinishReasonMiddleware — suppress tool execution when the provider
|
||||
# safety-terminated the response. Registered after custom middlewares so
|
||||
# that LangChain's reverse-order after_model dispatch runs Safety first;
|
||||
# cleared tool_calls then flow through Loop/Subagent accounting without
|
||||
# firing extra alarms. See safety_finish_reason_middleware.py docstring.
|
||||
safety_config = resolved_app_config.safety_finish_reason
|
||||
if safety_config.enabled:
|
||||
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
+6
-7
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
This normalizes model-bound causal order before provider serialization while
|
||||
preserving already-valid transcripts unchanged.
|
||||
"""
|
||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||
tool_messages_by_id[msg.tool_call_id].append(msg)
|
||||
|
||||
tool_call_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
tool_call_ids.add(tc_id)
|
||||
|
||||
patched: list = []
|
||||
consumed_tool_msg_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||
if not tc_id:
|
||||
continue
|
||||
|
||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||
tool_msg_queue = tool_messages_by_id.get(tc_id)
|
||||
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
|
||||
if existing_tool_msg is not None:
|
||||
patched.append(existing_tool_msg)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
else:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
if patched == messages:
|
||||
|
||||
+317
@@ -0,0 +1,317 @@
|
||||
"""Suppress tool execution when the provider safety-terminated the response.
|
||||
|
||||
Background — see issue bytedance/deer-flow#3028.
|
||||
|
||||
Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic
|
||||
``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop
|
||||
generation mid-stream while still returning partially-formed ``tool_calls``.
|
||||
LangChain's tool router treats any AIMessage with a non-empty ``tool_calls``
|
||||
field as "go execute these", so half-truncated arguments — e.g. a markdown
|
||||
``write_file`` that stops in the middle of a sentence — get dispatched as if
|
||||
they were complete. The agent then sees the truncated file, tries to fix it,
|
||||
gets filtered again, and loops.
|
||||
|
||||
This middleware sits at ``after_model`` and gates that behaviour: when a
|
||||
configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries
|
||||
tool calls, we strip the tool calls (both structured and raw provider
|
||||
payloads), append a user-facing explanation, and stash observability fields
|
||||
in ``additional_kwargs.safety_termination`` so logs, traces, and SSE
|
||||
consumers can see what happened.
|
||||
|
||||
Hook choice: ``after_model`` (not ``wrap_model_call``) because the response
|
||||
is a *normal* return — not an exception — and we want to participate in the
|
||||
same after-model chain as ``LoopDetectionMiddleware``, with which we share
|
||||
the same tool-call-suppression mechanic but a different trigger.
|
||||
|
||||
Placement: register *after* ``LoopDetectionMiddleware`` in the middleware
|
||||
list. LangChain factory wires ``after_model`` edges in reverse list order
|
||||
(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``,
|
||||
then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is
|
||||
the *first* to observe the model output. Registering Safety after Loop
|
||||
means Safety sees the raw response first, clears tool calls if it fires,
|
||||
and Loop then accounts against the cleaned message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_USER_FACING_MESSAGE = (
|
||||
"The model provider stopped this response with a safety-related signal "
|
||||
"({reason_field}={reason_value!r}, detector={detector!r}). Any tool "
|
||||
"calls produced in this turn were suppressed because their arguments "
|
||||
"may be truncated and unsafe to execute. Please rephrase the request "
|
||||
"or ask for a narrower output."
|
||||
)
|
||||
|
||||
|
||||
class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector."""
|
||||
|
||||
def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None:
|
||||
super().__init__()
|
||||
# Copy so caller mutations after construction don't leak into us.
|
||||
self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware:
|
||||
"""Construct from validated Pydantic config, honouring the
|
||||
reflection-loaded detector list when provided.
|
||||
|
||||
An explicit empty list is intentionally rejected — it would silently
|
||||
disable detection while leaving the middleware in the chain, which
|
||||
is the worst of both worlds. Use ``enabled: false`` instead.
|
||||
"""
|
||||
if config.detectors is None:
|
||||
return cls()
|
||||
|
||||
if not config.detectors:
|
||||
raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.")
|
||||
|
||||
from deerflow.reflection import resolve_variable
|
||||
|
||||
detectors: list[SafetyTerminationDetector] = []
|
||||
for entry in config.detectors:
|
||||
detector_cls = resolve_variable(entry.use)
|
||||
kwargs = dict(entry.config) if entry.config else {}
|
||||
detector = detector_cls(**kwargs)
|
||||
if not isinstance(detector, SafetyTerminationDetector):
|
||||
raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method")
|
||||
detectors.append(detector)
|
||||
return cls(detectors=detectors)
|
||||
|
||||
# ----- detection -------------------------------------------------------
|
||||
|
||||
def _detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
for detector in self._detectors:
|
||||
try:
|
||||
hit = detector.detect(message)
|
||||
except Exception: # noqa: BLE001 - never let a buggy detector break the agent run
|
||||
logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__))
|
||||
continue
|
||||
if hit is not None:
|
||||
return hit
|
||||
return None
|
||||
|
||||
# ----- message rewriting ----------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _append_user_message(content: object, text: str) -> str | list:
|
||||
"""Append a plain-text explanation to AIMessage content.
|
||||
|
||||
Mirrors ``LoopDetectionMiddleware._append_text`` so list-content
|
||||
responses (Anthropic thinking blocks, vLLM reasoning splits) keep
|
||||
their structure instead of being string-coerced into a TypeError.
|
||||
"""
|
||||
if content is None or content == "":
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _build_suppressed_message(
|
||||
self,
|
||||
message: AIMessage,
|
||||
termination: SafetyTermination,
|
||||
) -> AIMessage:
|
||||
suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])]
|
||||
explanation = _USER_FACING_MESSAGE.format(
|
||||
reason_field=termination.reason_field,
|
||||
reason_value=termination.reason_value,
|
||||
detector=termination.detector,
|
||||
)
|
||||
new_content = self._append_user_message(message.content, explanation)
|
||||
|
||||
# clone_ai_message_with_tool_calls handles structured tool_calls,
|
||||
# raw additional_kwargs.tool_calls, and function_call in one shot.
|
||||
# It only rewrites finish_reason when the old value was "tool_calls",
|
||||
# which is not our case — content_filter / refusal / SAFETY stay put
|
||||
# so downstream SSE / converters keep seeing the real provider reason.
|
||||
cleared = clone_ai_message_with_tool_calls(message, [], content=new_content)
|
||||
|
||||
# Re-clone additional_kwargs so we don't accidentally mutate the
|
||||
# dict returned by clone_ai_message_with_tool_calls (which already
|
||||
# made a shallow copy, but downstream model_copy still references
|
||||
# it). Then stamp the observability record.
|
||||
kwargs = dict(getattr(cleared, "additional_kwargs", None) or {})
|
||||
kwargs["safety_termination"] = {
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(suppressed_names),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"extras": dict(termination.extras) if termination.extras else {},
|
||||
}
|
||||
return cleared.model_copy(update={"additional_kwargs": kwargs})
|
||||
|
||||
# ----- observability ---------------------------------------------------
|
||||
|
||||
def _emit_event(
|
||||
self,
|
||||
termination: SafetyTermination,
|
||||
suppressed_names: list[str],
|
||||
runtime: Runtime,
|
||||
) -> None:
|
||||
"""Notify SSE consumers (e.g. the web UI) that a tool turn was
|
||||
suppressed so they can reconcile any "tool starting..." placeholders
|
||||
already streamed to the user. Failures are logged at debug and
|
||||
ignored — this is a best-effort signal."""
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True)
|
||||
return
|
||||
|
||||
thread_id = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
|
||||
|
||||
try:
|
||||
writer(
|
||||
{
|
||||
"type": "safety_termination",
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(suppressed_names),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("Failed to emit safety_termination stream event", exc_info=True)
|
||||
|
||||
def _record_audit_event(
|
||||
self,
|
||||
termination: SafetyTermination,
|
||||
message,
|
||||
tool_calls: list[dict],
|
||||
runtime: Runtime,
|
||||
) -> None:
|
||||
"""Write a ``middleware:safety_termination`` record to RunEventStore
|
||||
for post-run auditability.
|
||||
|
||||
The custom stream event in ``_emit_event`` is consumed by live SSE
|
||||
clients and disappears after the run; this event is persisted so an
|
||||
operator can answer "which runs were safety-suppressed today?" from
|
||||
a single SQL query without joining the message body. Worker exposes
|
||||
the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``;
|
||||
absent in unit-test / subagent / no-event-store paths, in which case
|
||||
we silently skip.
|
||||
|
||||
Tool **arguments** are deliberately **not** recorded — those are the
|
||||
very content the provider filtered; persisting them would defeat the
|
||||
purpose of the safety filter. Names / count / ids are sufficient for
|
||||
audit and debugging (issue #3028 review).
|
||||
"""
|
||||
journal = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
context = runtime.context
|
||||
if isinstance(context, dict):
|
||||
journal = context.get("__run_journal")
|
||||
if journal is None:
|
||||
return
|
||||
|
||||
suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls]
|
||||
suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||
|
||||
changes = {
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(tool_calls),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"suppressed_tool_call_ids": suppressed_ids,
|
||||
"message_id": getattr(message, "id", None),
|
||||
"extras": dict(termination.extras) if termination.extras else {},
|
||||
}
|
||||
|
||||
try:
|
||||
journal.record_middleware(
|
||||
tag="safety_termination",
|
||||
name=type(self).__name__,
|
||||
hook="after_model",
|
||||
action="suppress_tool_calls",
|
||||
changes=changes,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
# Audit-event persistence must never break agent execution.
|
||||
logger.debug("Failed to record middleware:safety_termination event", exc_info=True)
|
||||
|
||||
# ----- main apply ------------------------------------------------------
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
# Issue scope: only intervene when there's something to suppress.
|
||||
# ``content_filter`` without tool_calls is allowed through unchanged
|
||||
# so the partial text response (if any) reaches the user naturally.
|
||||
tool_calls = last.tool_calls
|
||||
if not tool_calls:
|
||||
return None
|
||||
|
||||
termination = self._detect(last)
|
||||
if termination is None:
|
||||
return None
|
||||
|
||||
patched = self._build_suppressed_message(last, termination)
|
||||
|
||||
thread_id = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
|
||||
|
||||
logger.warning(
|
||||
"Provider safety termination detected — suppressed %d tool call(s)",
|
||||
len(tool_calls),
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_names": [tc.get("name") for tc in tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime)
|
||||
self._record_audit_event(termination, last, list(tool_calls), runtime)
|
||||
|
||||
return {"messages": [patched]}
|
||||
|
||||
# ----- hooks -----------------------------------------------------------
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
@@ -0,0 +1,237 @@
|
||||
"""Detectors for provider-side safety termination signals.
|
||||
|
||||
Different LLM providers signal "I stopped this response for safety reasons"
|
||||
through different fields with different values. This module defines a small
|
||||
strategy interface and three built-in detectors that cover the major
|
||||
providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock
|
||||
adapters, in-house gateways, ...) can be added by implementing
|
||||
``SafetyTerminationDetector`` and wiring it through
|
||||
``config.yaml: safety_finish_reason.detectors``.
|
||||
|
||||
The middleware that consumes these detectors lives in
|
||||
``safety_finish_reason_middleware.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SafetyTermination:
|
||||
"""A detected safety-related termination signal.
|
||||
|
||||
Attributes:
|
||||
detector: Name of the detector that produced this result. Used for
|
||||
observability so operators can see which provider rule fired.
|
||||
reason_field: The message metadata field that carried the signal
|
||||
(e.g. ``finish_reason``, ``stop_reason``).
|
||||
reason_value: The actual value of that field
|
||||
(e.g. ``content_filter``, ``refusal``, ``SAFETY``).
|
||||
extras: Provider-specific metadata that may help downstream
|
||||
consumers (e.g. Azure OpenAI content_filter_results, Gemini
|
||||
safety_ratings). Detectors are free to populate or skip this.
|
||||
"""
|
||||
|
||||
detector: str
|
||||
reason_field: str
|
||||
reason_value: str
|
||||
extras: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SafetyTerminationDetector(Protocol):
|
||||
"""Strategy interface for provider safety termination detection."""
|
||||
|
||||
name: str
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
"""Return a SafetyTermination if *message* indicates provider safety
|
||||
termination, otherwise return ``None``.
|
||||
|
||||
Implementations must be side-effect free and tolerant of missing or
|
||||
oddly-typed metadata — detectors run on every model response.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def _get_metadata_value(message: AIMessage, field_name: str) -> str | None:
|
||||
"""Read a string-typed value from either ``response_metadata`` or
|
||||
``additional_kwargs``.
|
||||
|
||||
LangChain provider adapters are inconsistent about where they stash
|
||||
provider stop signals. Most modern adapters use ``response_metadata``,
|
||||
but some legacy / passthrough paths still surface them via
|
||||
``additional_kwargs``. We check both, in that order, and only accept
|
||||
string values — Pydantic enums or dicts are ignored so we never raise
|
||||
on malformed inputs.
|
||||
"""
|
||||
for container_name in ("response_metadata", "additional_kwargs"):
|
||||
container = getattr(message, container_name, None) or {}
|
||||
if not isinstance(container, dict):
|
||||
continue
|
||||
value = container.get(field_name)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class OpenAICompatibleContentFilterDetector:
|
||||
"""OpenAI-compatible content_filter signal.
|
||||
|
||||
Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM,
|
||||
Qwen (OpenAI-compatible mode), and any other adapter that follows the
|
||||
OpenAI ``finish_reason`` convention.
|
||||
|
||||
Some Chinese providers ship custom OpenAI-compatible gateways that use
|
||||
alternative tokens like ``sensitive`` or ``violation``. Extend the set
|
||||
via the ``finish_reasons`` kwarg in config.
|
||||
"""
|
||||
|
||||
name = "openai_compatible_content_filter"
|
||||
|
||||
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = finish_reasons if finish_reasons is not None else ("content_filter",)
|
||||
self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "finish_reason")
|
||||
if value is None or value.lower() not in self._finish_reasons:
|
||||
return None
|
||||
|
||||
extras: dict[str, Any] = {}
|
||||
# Azure OpenAI ships a structured content_filter_results block; carry it
|
||||
# through so operators can see *what* was filtered without re-tracing.
|
||||
response_metadata = getattr(message, "response_metadata", None) or {}
|
||||
if isinstance(response_metadata, dict):
|
||||
filter_results = response_metadata.get("content_filter_results")
|
||||
if filter_results:
|
||||
extras["content_filter_results"] = filter_results
|
||||
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="finish_reason",
|
||||
reason_value=value,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicRefusalDetector:
|
||||
"""Anthropic ``stop_reason == "refusal"`` signal.
|
||||
|
||||
Anthropic models surface safety refusals via a dedicated ``stop_reason``
|
||||
rather than ``finish_reason``. See:
|
||||
https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals
|
||||
"""
|
||||
|
||||
name = "anthropic_refusal"
|
||||
|
||||
def __init__(self, stop_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = stop_reasons if stop_reasons is not None else ("refusal",)
|
||||
self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "stop_reason")
|
||||
if value is None or value.lower() not in self._stop_reasons:
|
||||
return None
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="stop_reason",
|
||||
reason_value=value,
|
||||
)
|
||||
|
||||
|
||||
class GeminiSafetyDetector:
|
||||
"""Gemini / Vertex AI safety-related finish reasons.
|
||||
|
||||
Gemini uses the same ``finish_reason`` field as OpenAI but with an
|
||||
enumerated upper-case taxonomy. The default set covers every Gemini
|
||||
finish_reason that means "the model stopped because the content/image
|
||||
tripped a safety, blocklist, recitation, or PII filter" — i.e. cases
|
||||
where any tool_calls returned alongside are likely truncated/
|
||||
unreliable. Full enum:
|
||||
https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason
|
||||
|
||||
Intentionally **excluded** from the default set:
|
||||
- ``STOP`` — normal termination.
|
||||
- ``MAX_TOKENS`` — output length truncation, not safety
|
||||
(same root failure mode as
|
||||
content_filter, but issue #3028
|
||||
scopes it out; expose separately if
|
||||
desired).
|
||||
- ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to
|
||||
safety; tool_calls would be absent
|
||||
anyway.
|
||||
- ``MALFORMED_FUNCTION_CALL`` /
|
||||
``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The
|
||||
tool_calls are *also* unreliable
|
||||
here, but the failure category is
|
||||
distinct from safety filtering;
|
||||
handle in a dedicated detector to
|
||||
keep observability records honest.
|
||||
- ``OTHER`` / ``IMAGE_OTHER`` /
|
||||
``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default;
|
||||
opt in via ``finish_reasons=`` if
|
||||
your provider abuses these.
|
||||
"""
|
||||
|
||||
name = "gemini_safety"
|
||||
|
||||
_DEFAULT_FINISH_REASONS = (
|
||||
# Text safety
|
||||
"SAFETY",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"RECITATION",
|
||||
# Image safety (multimodal generation)
|
||||
"IMAGE_SAFETY",
|
||||
"IMAGE_PROHIBITED_CONTENT",
|
||||
"IMAGE_RECITATION",
|
||||
)
|
||||
|
||||
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS
|
||||
self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "finish_reason")
|
||||
if value is None or value.upper() not in self._finish_reasons:
|
||||
return None
|
||||
|
||||
extras: dict[str, Any] = {}
|
||||
response_metadata = getattr(message, "response_metadata", None) or {}
|
||||
if isinstance(response_metadata, dict):
|
||||
# Gemini surfaces per-category scoring under safety_ratings.
|
||||
ratings = response_metadata.get("safety_ratings")
|
||||
if ratings:
|
||||
extras["safety_ratings"] = ratings
|
||||
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="finish_reason",
|
||||
reason_value=value,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
|
||||
def default_detectors() -> list[SafetyTerminationDetector]:
|
||||
"""Built-in detector set used when no custom detectors are configured."""
|
||||
return [
|
||||
OpenAICompatibleContentFilterDetector(),
|
||||
AnthropicRefusalDetector(),
|
||||
GeminiSafetyDetector(),
|
||||
]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AnthropicRefusalDetector",
|
||||
"GeminiSafetyDetector",
|
||||
"OpenAICompatibleContentFilterDetector",
|
||||
"SafetyTermination",
|
||||
"SafetyTerminationDetector",
|
||||
"default_detectors",
|
||||
]
|
||||
+10
@@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares(
|
||||
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
|
||||
# Same provider safety-termination guard the lead agent uses — subagents
|
||||
# are equally exposed to truncated tool_calls returned with
|
||||
# finish_reason=content_filter (and friends), and the bad call would then
|
||||
# propagate back to the lead agent via the task tool result.
|
||||
safety_config = app_config.safety_finish_reason
|
||||
if safety_config.enabled:
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
|
||||
|
||||
return middlewares
|
||||
|
||||
@@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
from deerflow.config.runtime_paths import existing_project_file
|
||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
|
||||
from deerflow.config.skills_config import SkillsConfig
|
||||
@@ -102,6 +103,7 @@ class AppConfig(BaseModel):
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Configuration for SafetyFinishReasonMiddleware.
|
||||
|
||||
Mirrors the shape of GuardrailsConfig: detectors are loaded by class path
|
||||
through ``deerflow.reflection.resolve_variable`` (same loader the
|
||||
``guardrails.provider`` config uses) so users can drop in custom provider
|
||||
detectors without modifying core code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SafetyDetectorConfig(BaseModel):
|
||||
"""One detector entry under ``safety_finish_reason.detectors``."""
|
||||
|
||||
use: str = Field(
|
||||
description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."),
|
||||
)
|
||||
config: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Constructor kwargs passed to the detector class.",
|
||||
)
|
||||
|
||||
|
||||
class SafetyFinishReasonConfig(BaseModel):
|
||||
"""Configuration for the SafetyFinishReasonMiddleware.
|
||||
|
||||
The middleware intercepts AIMessages where the provider signaled a
|
||||
safety-related termination (e.g. OpenAI ``finish_reason='content_filter'``)
|
||||
while still returning tool calls, and suppresses those tool calls so the
|
||||
half-truncated arguments never execute.
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Master switch for the SafetyFinishReasonMiddleware.",
|
||||
)
|
||||
detectors: list[SafetyDetectorConfig] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Custom detector list. Leave unset (None) to use the built-in "
|
||||
"set covering OpenAI-compatible content_filter, Anthropic "
|
||||
"refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/"
|
||||
"RECITATION. Provide a non-null list to fully override."
|
||||
),
|
||||
)
|
||||
@@ -134,9 +134,25 @@ def reset_mcp_tools_cache() -> None:
|
||||
"""Reset the MCP tools cache.
|
||||
|
||||
This is useful for testing or when you want to reload MCP tools.
|
||||
Also closes all persistent MCP sessions so they are recreated on
|
||||
the next tool load.
|
||||
"""
|
||||
global _mcp_tools_cache, _cache_initialized, _config_mtime
|
||||
_mcp_tools_cache = None
|
||||
_cache_initialized = False
|
||||
_config_mtime = None
|
||||
|
||||
# Close persistent sessions – they will be recreated by the next
|
||||
# get_mcp_tools() call with the (possibly updated) connection config.
|
||||
try:
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
|
||||
pool = get_session_pool()
|
||||
pool.close_all_sync()
|
||||
except Exception:
|
||||
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
|
||||
|
||||
from deerflow.mcp.session_pool import reset_session_pool
|
||||
|
||||
reset_session_pool()
|
||||
logger.info("MCP tools cache reset")
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
"""Persistent MCP session pool for stateful tool calls.
|
||||
|
||||
When MCP tools are loaded via langchain-mcp-adapters with ``session=None``,
|
||||
each tool call creates a new MCP session. For stateful servers like Playwright,
|
||||
this means browser state (opened pages, filled forms) is lost between calls.
|
||||
|
||||
This module provides a session pool that maintains persistent MCP sessions,
|
||||
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
|
||||
so that consecutive tool calls share the same session and server-side state.
|
||||
Sessions are evicted in LRU order when the pool reaches capacity.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPSessionPool:
|
||||
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
|
||||
|
||||
MAX_SESSIONS = 256
|
||||
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: OrderedDict[
|
||||
tuple[str, str],
|
||||
tuple[ClientSession, asyncio.AbstractEventLoop],
|
||||
] = OrderedDict()
|
||||
self._context_managers: dict[tuple[str, str], Any] = {}
|
||||
# threading.Lock is not bound to any event loop, so it is safe to
|
||||
# acquire from both async paths and sync/worker-thread paths.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
async def get_session(
|
||||
self,
|
||||
server_name: str,
|
||||
scope_key: str,
|
||||
connection: dict[str, Any],
|
||||
) -> ClientSession:
|
||||
"""Get or create a persistent MCP session.
|
||||
|
||||
If an existing session was created in a different event loop (e.g.
|
||||
the sync-wrapper path), it is closed and replaced with a fresh one
|
||||
in the current loop.
|
||||
|
||||
Args:
|
||||
server_name: MCP server name.
|
||||
scope_key: Isolation key (typically thread_id).
|
||||
connection: Connection configuration for ``create_session``.
|
||||
|
||||
Returns:
|
||||
An initialized ``ClientSession``.
|
||||
"""
|
||||
key = (server_name, scope_key)
|
||||
current_loop = asyncio.get_running_loop()
|
||||
|
||||
# Phase 1: inspect/mutate the registry under the thread lock (no awaits).
|
||||
cms_to_close: list[tuple[tuple[str, str], Any]] = []
|
||||
with self._lock:
|
||||
if key in self._entries:
|
||||
session, loop = self._entries[key]
|
||||
if loop is current_loop:
|
||||
self._entries.move_to_end(key)
|
||||
return session
|
||||
# Session belongs to a different event loop – evict it.
|
||||
cm = self._context_managers.pop(key, None)
|
||||
self._entries.pop(key)
|
||||
if cm is not None:
|
||||
cms_to_close.append((key, cm))
|
||||
|
||||
# Evict LRU entries when at capacity.
|
||||
while len(self._entries) >= self.MAX_SESSIONS:
|
||||
oldest_key = next(iter(self._entries))
|
||||
cm = self._context_managers.pop(oldest_key, None)
|
||||
self._entries.pop(oldest_key)
|
||||
if cm is not None:
|
||||
cms_to_close.append((oldest_key, cm))
|
||||
|
||||
# Phase 2: async cleanup outside the lock so we never await while holding it.
|
||||
for close_key, cm in cms_to_close:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session %s", close_key, exc_info=True)
|
||||
|
||||
from langchain_mcp_adapters.sessions import create_session
|
||||
|
||||
cm = create_session(connection)
|
||||
session = await cm.__aenter__()
|
||||
await session.initialize()
|
||||
|
||||
# Phase 3: register the new session under the lock.
|
||||
with self._lock:
|
||||
self._entries[key] = (session, current_loop)
|
||||
self._context_managers[key] = cm
|
||||
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
|
||||
return session
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
|
||||
"""Close a single context manager (must be called WITHOUT the lock)."""
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session %s", key, exc_info=True)
|
||||
|
||||
async def close_scope(self, scope_key: str) -> None:
|
||||
"""Close all sessions for a given scope (e.g. thread_id)."""
|
||||
with self._lock:
|
||||
keys = [k for k in self._entries if k[1] == scope_key]
|
||||
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
|
||||
for k in keys:
|
||||
self._entries.pop(k, None)
|
||||
for key, cm in cms:
|
||||
if cm is not None:
|
||||
await self._close_cm(key, cm)
|
||||
|
||||
async def close_server(self, server_name: str) -> None:
|
||||
"""Close all sessions for a given server."""
|
||||
with self._lock:
|
||||
keys = [k for k in self._entries if k[0] == server_name]
|
||||
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
|
||||
for k in keys:
|
||||
self._entries.pop(k, None)
|
||||
for key, cm in cms:
|
||||
if cm is not None:
|
||||
await self._close_cm(key, cm)
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close every managed session."""
|
||||
with self._lock:
|
||||
cms = list(self._context_managers.items())
|
||||
self._context_managers.clear()
|
||||
self._entries.clear()
|
||||
for key, cm in cms:
|
||||
await self._close_cm(key, cm)
|
||||
|
||||
def close_all_sync(self) -> None:
|
||||
"""Close all sessions using their owning event loops (synchronous).
|
||||
|
||||
Each session is closed on the loop it was created in, avoiding
|
||||
cross-loop resource leaks. Safe to call from any thread without an
|
||||
active event loop.
|
||||
"""
|
||||
with self._lock:
|
||||
entries = list(self._entries.items())
|
||||
cms = dict(self._context_managers)
|
||||
self._entries.clear()
|
||||
self._context_managers.clear()
|
||||
|
||||
for key, (_, loop) in entries:
|
||||
cm = cms.get(key)
|
||||
if cm is None or loop.is_closed():
|
||||
continue
|
||||
try:
|
||||
if loop.is_running():
|
||||
# Schedule on the owning loop from this (different) thread.
|
||||
future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
|
||||
future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
|
||||
else:
|
||||
loop.run_until_complete(cm.__aexit__(None, None, None))
|
||||
except Exception:
|
||||
logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_pool: MCPSessionPool | None = None
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_session_pool() -> MCPSessionPool:
|
||||
"""Return the global session-pool singleton."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
with _pool_lock:
|
||||
if _pool is None:
|
||||
_pool = MCPSessionPool()
|
||||
return _pool
|
||||
|
||||
|
||||
def reset_session_pool() -> None:
|
||||
"""Reset the singleton (for tests)."""
|
||||
global _pool
|
||||
_pool = None
|
||||
@@ -1,21 +1,181 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters."""
|
||||
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langgraph.config import get_config
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.mcp.client import build_servers_config
|
||||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_thread_id(runtime: Runtime | None) -> str:
|
||||
"""Extract thread_id from the injected tool runtime or LangGraph config."""
|
||||
if runtime is not None:
|
||||
tid = runtime.context.get("thread_id") if runtime.context else None
|
||||
if tid is not None:
|
||||
return str(tid)
|
||||
config = runtime.config or {}
|
||||
tid = config.get("configurable", {}).get("thread_id")
|
||||
if tid is not None:
|
||||
return str(tid)
|
||||
|
||||
try:
|
||||
tid = get_config().get("configurable", {}).get("thread_id")
|
||||
return str(tid) if tid is not None else "default"
|
||||
except RuntimeError:
|
||||
return "default"
|
||||
|
||||
|
||||
def _convert_call_tool_result(call_tool_result: Any) -> Any:
|
||||
"""Convert an MCP CallToolResult to the LangChain ``content_and_artifact`` format.
|
||||
|
||||
Implements the same conversion logic as the adapter without relying on
|
||||
the private ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol.
|
||||
"""
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.messages.content import create_file_block, create_image_block, create_text_block
|
||||
from langchain_core.tools import ToolException
|
||||
from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents
|
||||
|
||||
# Pass ToolMessage through directly (interceptor short-circuit).
|
||||
if isinstance(call_tool_result, ToolMessage):
|
||||
return call_tool_result, None
|
||||
|
||||
# Pass LangGraph Command through directly when langgraph is installed.
|
||||
try:
|
||||
from langgraph.types import Command
|
||||
|
||||
if isinstance(call_tool_result, Command):
|
||||
return call_tool_result, None
|
||||
except ImportError:
|
||||
# langgraph is optional; if unavailable, continue with standard MCP content conversion.
|
||||
pass
|
||||
|
||||
# Convert MCP content blocks to LangChain content blocks.
|
||||
lc_content = []
|
||||
for item in call_tool_result.content:
|
||||
if isinstance(item, TextContent):
|
||||
lc_content.append(create_text_block(text=item.text))
|
||||
elif isinstance(item, ImageContent):
|
||||
lc_content.append(create_image_block(base64=item.data, mime_type=item.mimeType))
|
||||
elif isinstance(item, ResourceLink):
|
||||
mime = item.mimeType or None
|
||||
if mime and mime.startswith("image/"):
|
||||
lc_content.append(create_image_block(url=str(item.uri), mime_type=mime))
|
||||
else:
|
||||
lc_content.append(create_file_block(url=str(item.uri), mime_type=mime))
|
||||
elif isinstance(item, EmbeddedResource):
|
||||
from mcp.types import BlobResourceContents
|
||||
|
||||
res = item.resource
|
||||
if isinstance(res, TextResourceContents):
|
||||
lc_content.append(create_text_block(text=res.text))
|
||||
elif isinstance(res, BlobResourceContents):
|
||||
mime = res.mimeType or None
|
||||
if mime and mime.startswith("image/"):
|
||||
lc_content.append(create_image_block(base64=res.blob, mime_type=mime))
|
||||
else:
|
||||
lc_content.append(create_file_block(base64=res.blob, mime_type=mime))
|
||||
else:
|
||||
lc_content.append(create_text_block(text=str(res)))
|
||||
else:
|
||||
lc_content.append(create_text_block(text=str(item)))
|
||||
|
||||
if call_tool_result.isError:
|
||||
error_parts = [item["text"] for item in lc_content if isinstance(item, dict) and item.get("type") == "text"]
|
||||
raise ToolException("\n".join(error_parts) if error_parts else str(lc_content))
|
||||
|
||||
artifact = None
|
||||
if call_tool_result.structuredContent is not None:
|
||||
artifact = {"structured_content": call_tool_result.structuredContent}
|
||||
|
||||
return lc_content, artifact
|
||||
|
||||
|
||||
def _make_session_pool_tool(
|
||||
tool: BaseTool,
|
||||
server_name: str,
|
||||
connection: dict[str, Any],
|
||||
tool_interceptors: list[Any] | None = None,
|
||||
) -> BaseTool:
|
||||
"""Wrap an MCP tool so it reuses a persistent session from the pool.
|
||||
|
||||
Replaces the per-call session creation with pool-managed sessions scoped
|
||||
by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
|
||||
Playwright) keep their state across tool calls within the same thread.
|
||||
|
||||
The configured ``tool_interceptors`` (OAuth, custom) are preserved and
|
||||
applied on every call before invoking the pooled session.
|
||||
"""
|
||||
# Strip the server-name prefix to recover the original MCP tool name.
|
||||
original_name = tool.name
|
||||
prefix = f"{server_name}_"
|
||||
if original_name.startswith(prefix):
|
||||
original_name = original_name[len(prefix) :]
|
||||
|
||||
pool = get_session_pool()
|
||||
|
||||
async def call_with_persistent_session(
|
||||
runtime: Runtime | None = None,
|
||||
**arguments: Any,
|
||||
) -> Any:
|
||||
thread_id = _extract_thread_id(runtime)
|
||||
session = await pool.get_session(server_name, thread_id, connection)
|
||||
|
||||
if tool_interceptors:
|
||||
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
|
||||
|
||||
async def base_handler(request: MCPToolCallRequest) -> Any:
|
||||
return await session.call_tool(request.name, request.args)
|
||||
|
||||
handler = base_handler
|
||||
for interceptor in reversed(tool_interceptors):
|
||||
outer = handler
|
||||
|
||||
async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any:
|
||||
return await _i(req, _h)
|
||||
|
||||
handler = wrapped
|
||||
|
||||
request = MCPToolCallRequest(
|
||||
name=original_name,
|
||||
args=arguments,
|
||||
server_name=server_name,
|
||||
runtime=runtime,
|
||||
)
|
||||
call_tool_result = await handler(request)
|
||||
else:
|
||||
call_tool_result = await session.call_tool(original_name, arguments)
|
||||
|
||||
return _convert_call_tool_result(call_tool_result)
|
||||
|
||||
return StructuredTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
args_schema=tool.args_schema,
|
||||
coroutine=call_with_persistent_session,
|
||||
response_format="content_and_artifact",
|
||||
metadata=tool.metadata,
|
||||
)
|
||||
|
||||
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
|
||||
Tools are wrapped with persistent-session logic so that consecutive
|
||||
calls within the same thread reuse the same MCP session.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
"""
|
||||
@@ -50,7 +210,7 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
existing_headers["Authorization"] = auth_header
|
||||
servers_config[server_name]["headers"] = existing_headers
|
||||
|
||||
tool_interceptors = []
|
||||
tool_interceptors: list[Any] = []
|
||||
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
|
||||
if oauth_interceptor is not None:
|
||||
tool_interceptors.append(oauth_interceptor)
|
||||
@@ -74,20 +234,42 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
elif interceptor is not None:
|
||||
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
|
||||
logger.warning(
|
||||
f"Failed to load MCP interceptor {interceptor_path}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
|
||||
client = MultiServerMCPClient(
|
||||
servers_config,
|
||||
tool_interceptors=tool_interceptors,
|
||||
tool_name_prefix=True,
|
||||
)
|
||||
|
||||
# Get all tools from all servers
|
||||
# Get all tools from all servers (discovers tool definitions via
|
||||
# temporary sessions – the persistent-session wrapping is applied below).
|
||||
tools = await client.get_tools()
|
||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
# Wrap each tool with persistent-session logic.
|
||||
wrapped_tools: list[BaseTool] = []
|
||||
for tool in tools:
|
||||
tool_server: str | None = None
|
||||
for name in servers_config:
|
||||
if tool.name.startswith(f"{name}_"):
|
||||
tool_server = name
|
||||
break
|
||||
|
||||
if tool_server is not None:
|
||||
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
|
||||
else:
|
||||
wrapped_tools.append(tool)
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
for tool in wrapped_tools:
|
||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
|
||||
return tools
|
||||
return wrapped_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
|
||||
|
||||
@@ -227,9 +227,48 @@ class RunRepository(RunStore):
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
async def update_run_progress(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
total_input_tokens: int | None = None,
|
||||
total_output_tokens: int | None = None,
|
||||
total_tokens: int | None = None,
|
||||
llm_call_count: int | None = None,
|
||||
lead_agent_tokens: int | None = None,
|
||||
subagent_tokens: int | None = None,
|
||||
middleware_tokens: int | None = None,
|
||||
message_count: int | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
) -> None:
|
||||
"""Update token usage + convenience fields while a run is still active."""
|
||||
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
|
||||
optional_counters = {
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"llm_call_count": llm_call_count,
|
||||
"lead_agent_tokens": lead_agent_tokens,
|
||||
"subagent_tokens": subagent_tokens,
|
||||
"middleware_tokens": middleware_tokens,
|
||||
"message_count": message_count,
|
||||
}
|
||||
for key, value in optional_counters.items():
|
||||
if value is not None:
|
||||
values[key] = value
|
||||
if last_ai_message is not None:
|
||||
values["last_ai_message"] = last_ai_message[:2000]
|
||||
if first_human_message is not None:
|
||||
values["first_human_message"] = first_human_message[:2000]
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||
_completed = RunRow.status.in_(("success", "error"))
|
||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
||||
_completed = RunRow.status.in_(statuses)
|
||||
_thread = RunRow.thread_id == thread_id
|
||||
model_name = func.coalesce(RunRow.model_name, "unknown")
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
*,
|
||||
track_token_usage: bool = True,
|
||||
flush_threshold: int = 20,
|
||||
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
|
||||
progress_flush_interval: float = 5.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.run_id = run_id
|
||||
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._store = event_store
|
||||
self._track_tokens = track_token_usage
|
||||
self._flush_threshold = flush_threshold
|
||||
self._progress_reporter = progress_reporter
|
||||
self._progress_flush_interval = progress_flush_interval
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||
self._pending_progress_task: asyncio.Task[None] | None = None
|
||||
self._pending_progress_delayed = False
|
||||
self._progress_dirty = False
|
||||
self._last_progress_flush = 0.0
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
self._schedule_progress_flush()
|
||||
|
||||
if messages:
|
||||
self._counted_message_llm_run_ids.add(str(run_id))
|
||||
|
||||
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
self._schedule_progress_flush()
|
||||
|
||||
def set_first_human_message(self, content: str) -> None:
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._pending_flush_tasks:
|
||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||
while self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
if self._pending_progress_delayed:
|
||||
self._pending_progress_task.cancel()
|
||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_delayed = False
|
||||
break
|
||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
||||
|
||||
while self._buffer:
|
||||
batch = self._buffer[: self._flush_threshold]
|
||||
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._buffer = batch + self._buffer
|
||||
raise
|
||||
|
||||
def _schedule_progress_flush(self) -> None:
|
||||
"""Best-effort throttled progress snapshot for active run visibility."""
|
||||
if self._progress_reporter is None:
|
||||
return
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_progress_flush
|
||||
if elapsed < self._progress_flush_interval:
|
||||
self._progress_dirty = True
|
||||
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
|
||||
return
|
||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
self._progress_dirty = True
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
|
||||
|
||||
def _schedule_delayed_progress_flush(self, delay: float) -> None:
|
||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
delay = max(0.0, delay)
|
||||
self._pending_progress_delayed = delay > 0
|
||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
|
||||
|
||||
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
|
||||
if self._progress_reporter is None:
|
||||
return
|
||||
if delay > 0:
|
||||
self._pending_progress_delayed = True
|
||||
await asyncio.sleep(delay)
|
||||
self._pending_progress_delayed = False
|
||||
dirty_before_write = self._progress_dirty
|
||||
self._progress_dirty = False
|
||||
snapshot_to_write = snapshot or self.get_completion_data()
|
||||
try:
|
||||
await self._progress_reporter(snapshot_to_write)
|
||||
self._last_progress_flush = time.monotonic()
|
||||
except Exception:
|
||||
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
|
||||
if dirty_before_write or self._progress_dirty:
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_task = None
|
||||
self._schedule_delayed_progress_flush(self._progress_flush_interval)
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
return {
|
||||
|
||||
@@ -38,6 +38,16 @@ class RunRecord:
|
||||
error: str | None = None
|
||||
model_name: str | None = None
|
||||
store_only: bool = False
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
llm_call_count: int = 0
|
||||
lead_agent_tokens: int = 0
|
||||
subagent_tokens: int = 0
|
||||
middleware_tokens: int = 0
|
||||
message_count: int = 0
|
||||
last_ai_message: str | None = None
|
||||
first_human_message: str | None = None
|
||||
|
||||
|
||||
class RunManager:
|
||||
@@ -102,16 +112,53 @@ class RunManager:
|
||||
error=row.get("error"),
|
||||
model_name=row.get("model_name"),
|
||||
store_only=True,
|
||||
total_input_tokens=row.get("total_input_tokens") or 0,
|
||||
total_output_tokens=row.get("total_output_tokens") or 0,
|
||||
total_tokens=row.get("total_tokens") or 0,
|
||||
llm_call_count=row.get("llm_call_count") or 0,
|
||||
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
|
||||
subagent_tokens=row.get("subagent_tokens") or 0,
|
||||
middleware_tokens=row.get("middleware_tokens") or 0,
|
||||
message_count=row.get("message_count") or 0,
|
||||
last_ai_message=row.get("last_ai_message"),
|
||||
first_human_message=row.get("first_human_message"),
|
||||
)
|
||||
|
||||
async def update_run_completion(self, run_id: str, **kwargs) -> None:
|
||||
"""Persist token usage and completion data to the backing store."""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is not None:
|
||||
for key, value in kwargs.items():
|
||||
if key == "status":
|
||||
continue
|
||||
if hasattr(record, key) and value is not None:
|
||||
setattr(record, key, value)
|
||||
record.updated_at = _now_iso()
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.update_run_completion(run_id, **kwargs)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||
|
||||
async def update_run_progress(self, run_id: str, **kwargs) -> None:
|
||||
"""Persist a running token/message snapshot without changing status."""
|
||||
should_persist = True
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is not None:
|
||||
should_persist = record.status == RunStatus.running
|
||||
if record is not None and should_persist:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(record, key) and value is not None:
|
||||
setattr(record, key, value)
|
||||
record.updated_at = _now_iso()
|
||||
if should_persist and self._store is not None:
|
||||
try:
|
||||
await self._store.update_run_progress(run_id, **kwargs)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
||||
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def update_run_progress(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
total_input_tokens: int | None = None,
|
||||
total_output_tokens: int | None = None,
|
||||
total_tokens: int | None = None,
|
||||
llm_call_count: int | None = None,
|
||||
lead_agent_tokens: int | None = None,
|
||||
subagent_tokens: int | None = None,
|
||||
middleware_tokens: int | None = None,
|
||||
message_count: int | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
) -> None:
|
||||
"""Persist a best-effort running snapshot without changing run status."""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
"""Aggregate token usage for completed runs in a thread.
|
||||
|
||||
Returns a dict with keys: total_tokens, total_input_tokens,
|
||||
|
||||
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
|
||||
self._runs[run_id][key] = value
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
async def update_run_progress(self, run_id, **kwargs):
|
||||
if run_id in self._runs and self._runs[run_id].get("status") == "running":
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
self._runs[run_id][key] = value
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
async def list_pending(self, *, before=None):
|
||||
now = before or datetime.now(UTC).isoformat()
|
||||
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
|
||||
results.sort(key=lambda r: r["created_at"])
|
||||
return results
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
|
||||
by_model: dict[str, dict] = {}
|
||||
for r in completed:
|
||||
model = r.get("model_name") or "unknown"
|
||||
|
||||
@@ -153,8 +153,6 @@ async def run_agent(
|
||||
|
||||
journal = None
|
||||
|
||||
journal = None
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
logger.info(
|
||||
@@ -177,6 +175,7 @@ async def run_agent(
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
|
||||
)
|
||||
|
||||
# 1. Mark running
|
||||
@@ -219,6 +218,12 @@ async def run_agent(
|
||||
# manually here because we drive the graph through ``agent.astream(config=...)``
|
||||
# without passing the official ``context=`` parameter.
|
||||
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
|
||||
# Expose the run-scoped journal under a sentinel key so middleware can
|
||||
# write audit events (e.g. SafetyFinishReasonMiddleware recording
|
||||
# suppressed tool calls). Double-underscore prefix marks it as a
|
||||
# runtime-internal channel; user code must not depend on the key name.
|
||||
if journal is not None:
|
||||
runtime_ctx["__run_journal"] = journal
|
||||
_install_runtime_context(config, runtime_ctx)
|
||||
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
@@ -42,6 +42,7 @@ _DEFAULT_GLOB_MAX_RESULTS = 200
|
||||
_MAX_GLOB_MAX_RESULTS = 1000
|
||||
_DEFAULT_GREP_MAX_RESULTS = 100
|
||||
_MAX_GREP_MAX_RESULTS = 500
|
||||
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS = 2000
|
||||
_LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"}
|
||||
_LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"}
|
||||
_LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"}
|
||||
@@ -435,6 +436,42 @@ def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str:
|
||||
return msg
|
||||
|
||||
|
||||
def _truncate_write_file_error_detail(detail: str, max_chars: int) -> str:
|
||||
"""Middle-truncate write_file error details, preserving the head and tail."""
|
||||
if max_chars == 0:
|
||||
return detail
|
||||
if len(detail) <= max_chars:
|
||||
return detail
|
||||
total = len(detail)
|
||||
marker_max_len = len(f"\n... [write_file error truncated: {total} chars skipped] ...\n")
|
||||
kept = max(0, max_chars - marker_max_len)
|
||||
if kept == 0:
|
||||
return detail[:max_chars]
|
||||
head_len = kept // 2
|
||||
tail_len = kept - head_len
|
||||
skipped = total - kept
|
||||
marker = f"\n... [write_file error truncated: {skipped} chars skipped] ...\n"
|
||||
return f"{detail[:head_len]}{marker}{detail[-tail_len:] if tail_len > 0 else ''}"
|
||||
|
||||
|
||||
def _format_write_file_error(
|
||||
requested_path: str,
|
||||
error: Exception,
|
||||
runtime: Runtime | None = None,
|
||||
*,
|
||||
max_chars: int = _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
|
||||
) -> str:
|
||||
"""Return a bounded, sanitized error string for write_file failures."""
|
||||
header = f"Error: Failed to write file '{requested_path}'"
|
||||
detail = _sanitize_error(error, runtime)
|
||||
if max_chars == 0:
|
||||
return f"{header}: {detail}"
|
||||
detail_budget = max_chars - len(header) - 2
|
||||
if detail_budget <= 0:
|
||||
return _truncate_write_file_error_detail(f"{header}: {detail}", max_chars)
|
||||
return f"{header}: {_truncate_write_file_error_detail(detail, detail_budget)}"
|
||||
|
||||
|
||||
def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
|
||||
"""Replace virtual /mnt/user-data paths with actual thread data paths.
|
||||
|
||||
@@ -1651,9 +1688,9 @@ def write_file_tool(
|
||||
append: Whether to append content to the end of the file instead of overwriting it. Defaults to false.
|
||||
"""
|
||||
try:
|
||||
requested_path = path
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
@@ -1664,15 +1701,21 @@ def write_file_tool(
|
||||
sandbox.write_file(path, content, append)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
return _format_write_file_error(requested_path, e, runtime)
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied writing to file: {requested_path}"
|
||||
return _truncate_write_file_error_detail(
|
||||
f"Error: Permission denied writing to file: {requested_path}",
|
||||
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
|
||||
)
|
||||
except IsADirectoryError:
|
||||
return f"Error: Path is a directory, not a file: {requested_path}"
|
||||
return _truncate_write_file_error_detail(
|
||||
f"Error: Path is a directory, not a file: {requested_path}",
|
||||
_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS,
|
||||
)
|
||||
except OSError as e:
|
||||
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}"
|
||||
return _format_write_file_error(requested_path, e, runtime)
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
|
||||
return _format_write_file_error(requested_path, e, runtime)
|
||||
|
||||
|
||||
async def _write_file_tool_async(
|
||||
|
||||
@@ -7,6 +7,7 @@ from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Annotated, Any, cast
|
||||
|
||||
from langchain.tools import InjectedToolCallId, tool
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
@@ -99,15 +100,31 @@ def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls:
|
||||
|
||||
|
||||
def _find_usage_recorder(runtime: Any) -> Any | None:
|
||||
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
|
||||
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config.
|
||||
|
||||
LangChain may pass ``config["callbacks"]`` in three different shapes:
|
||||
|
||||
- ``None`` (no callbacks registered): no recorder.
|
||||
- A plain ``list[BaseCallbackHandler]``: iterate it directly.
|
||||
- A ``BaseCallbackManager`` instance (e.g. ``AsyncCallbackManager`` on async
|
||||
tool runs): managers are not iterable, so we unwrap ``.handlers`` first.
|
||||
|
||||
Any other shape (e.g. a single handler object accidentally passed without a
|
||||
list wrapper) cannot be iterated safely; treat it as "no recorder" rather
|
||||
than raise.
|
||||
"""
|
||||
if runtime is None:
|
||||
return None
|
||||
config = getattr(runtime, "config", None)
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
callbacks = config.get("callbacks", [])
|
||||
callbacks = config.get("callbacks")
|
||||
if isinstance(callbacks, BaseCallbackManager):
|
||||
callbacks = callbacks.handlers
|
||||
if not callbacks:
|
||||
return None
|
||||
if not isinstance(callbacks, list):
|
||||
return None
|
||||
for cb in callbacks:
|
||||
if hasattr(cb, "record_external_llm_usage_records"):
|
||||
return cb
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent.
|
||||
|
||||
What it proves
|
||||
--------------
|
||||
- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full
|
||||
18-middleware chain, sandbox, tools, etc.).
|
||||
- A model that returns ``finish_reason='content_filter'`` + ``tool_calls``
|
||||
triggers SafetyFinishReasonMiddleware.
|
||||
- LangChain's tool router never invokes ``write_file`` — the truncated
|
||||
arguments do **not** reach the sandbox.
|
||||
- A ``safety_termination`` custom event is emitted on the stream and the
|
||||
final AIMessage carries the observability stamp.
|
||||
|
||||
Run from backend/ directory:
|
||||
PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake provider that mimics Moonshot's content_filter behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ContentFilteredFakeModel(BaseChatModel):
|
||||
"""First call returns finish_reason=content_filter + truncated write_file
|
||||
tool_call. Subsequent calls return a normal stop response so the agent
|
||||
can terminate (the middleware should make a second call unnecessary by
|
||||
clearing tool_calls, but we keep this safety net in case loop-detection
|
||||
or anything else triggers another model invocation)."""
|
||||
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-content-filtered"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
msg = AIMessage(
|
||||
content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_truncated_write",
|
||||
"name": "write_file",
|
||||
"args": {
|
||||
"path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
|
||||
"content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
|
||||
},
|
||||
}
|
||||
],
|
||||
response_metadata={
|
||||
"finish_reason": "content_filter",
|
||||
"model_name": "kimi-k2.6",
|
||||
"model_provider": "openai",
|
||||
},
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(
|
||||
content="(secondary call, should not be needed)",
|
||||
response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Driver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> int:
|
||||
# Inject the fake model BEFORE constructing the client. Both the
|
||||
# client module and the lead-agent module bind ``create_chat_model``
|
||||
# at import time via ``from deerflow.models import create_chat_model``,
|
||||
# so we patch both attribute slots — the source-of-truth patch on
|
||||
# ``factory.create_chat_model`` doesn't propagate back into already-
|
||||
# imported names.
|
||||
import deerflow.agents.lead_agent.agent as lead_agent_module
|
||||
import deerflow.client as client_module
|
||||
|
||||
fake = _ContentFilteredFakeModel()
|
||||
originals = {
|
||||
"lead": lead_agent_module.create_chat_model,
|
||||
"client": client_module.create_chat_model,
|
||||
}
|
||||
|
||||
def fake_create_chat_model(*args, **kwargs):
|
||||
return fake
|
||||
|
||||
lead_agent_module.create_chat_model = fake_create_chat_model
|
||||
client_module.create_chat_model = fake_create_chat_model
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
try:
|
||||
client = DeerFlowClient()
|
||||
|
||||
print("\n=== Streaming a turn through the real lead-agent ===")
|
||||
events: list[dict[str, Any]] = []
|
||||
for event in client.stream(
|
||||
"帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
|
||||
thread_id="e2e-safety-1",
|
||||
):
|
||||
events.append({"type": event.type, "data": event.data})
|
||||
|
||||
# ---- Assertions ----
|
||||
safety_event = next(
|
||||
(e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"),
|
||||
None,
|
||||
)
|
||||
final_values = next(
|
||||
(e for e in reversed(events) if e["type"] == "values"),
|
||||
None,
|
||||
)
|
||||
tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"]
|
||||
ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")]
|
||||
|
||||
print(f"\n[stats] total stream events: {len(events)}")
|
||||
print(f"[stats] model call count: {fake.call_count}")
|
||||
print(f"[stats] tool messages on stream: {len(tool_messages)}")
|
||||
print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}")
|
||||
|
||||
print("\n[event] safety_termination custom event:")
|
||||
if safety_event is None:
|
||||
print(" *** NOT FOUND ***")
|
||||
return 1
|
||||
for k, v in safety_event["data"].items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
print("\n[state] final AIMessage from last values snapshot:")
|
||||
if final_values is None:
|
||||
print(" *** no values snapshot ***")
|
||||
return 1
|
||||
# `values` event carries `_serialize_message` dicts, not Message objects.
|
||||
final_messages = final_values["data"].get("messages") or []
|
||||
last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None)
|
||||
if last_ai is None:
|
||||
print(" *** no AIMessage in final state ***")
|
||||
print(f" message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}")
|
||||
return 1
|
||||
|
||||
tool_calls = last_ai.get("tool_calls") or []
|
||||
additional_kwargs = last_ai.get("additional_kwargs") or {}
|
||||
response_metadata = last_ai.get("response_metadata") or {}
|
||||
content = last_ai.get("content")
|
||||
|
||||
print(f" tool_calls (must be empty): {tool_calls}")
|
||||
print(f" additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}")
|
||||
content_preview = (content if isinstance(content, str) else str(content))[:200]
|
||||
print(f" content[:200]: {content_preview!r}")
|
||||
print(f" response_metadata.finish_reason: {response_metadata.get('finish_reason')}")
|
||||
|
||||
# NOTE: `client._serialize_message` does not include `response_metadata`
|
||||
# in the values-event payload (client-layer behaviour, unrelated to the
|
||||
# middleware). The middleware *does* preserve finish_reason on the
|
||||
# AIMessage object — see test_safety_finish_reason_middleware.py::
|
||||
# TestMessageRewrite::test_preserves_response_metadata_finish_reason.
|
||||
# Here we assert on the observability stamp, which carries the same
|
||||
# evidence and is in the serialized payload.
|
||||
stamp = additional_kwargs.get("safety_termination") or {}
|
||||
failures = []
|
||||
if tool_calls:
|
||||
failures.append("final AIMessage still has tool_calls — middleware did NOT clear them")
|
||||
if not stamp:
|
||||
failures.append("final AIMessage missing safety_termination observability stamp")
|
||||
if tool_messages:
|
||||
failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream")
|
||||
if stamp.get("reason_value") != "content_filter":
|
||||
failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'")
|
||||
if safety_event is None:
|
||||
failures.append("safety_termination custom event was not emitted on the stream")
|
||||
|
||||
if failures:
|
||||
print("\n=== FAIL ===")
|
||||
for f in failures:
|
||||
print(f" - {f}")
|
||||
return 1
|
||||
|
||||
print("\n=== PASS ===")
|
||||
print(" - tool_calls cleared on final AIMessage")
|
||||
print(" - tool node never invoked (no ToolMessage on stream)")
|
||||
print(" - safety_termination custom event emitted")
|
||||
print(" - observability stamp written to additional_kwargs")
|
||||
print(" - response_metadata.finish_reason preserved for downstream SSE")
|
||||
return 0
|
||||
finally:
|
||||
lead_agent_module.create_chat_model = originals["lead"]
|
||||
client_module.create_chat_model = originals["client"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
|
||||
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
|
||||
_ai_with_tool_calls(
|
||||
[
|
||||
_tc("web_search", "web_search:11"),
|
||||
_tc("web_search", "web_search:12"),
|
||||
_tc("web_search", "web_search:13"),
|
||||
]
|
||||
),
|
||||
_tool_msg("web_search:11", "web_search"),
|
||||
_tool_msg("web_search:12", "web_search"),
|
||||
_tool_msg("web_search:13", "web_search"),
|
||||
_ai_with_tool_calls(
|
||||
[
|
||||
_tc("web_search", "web_search:9"),
|
||||
_tc("web_search", "web_search:10"),
|
||||
_tc("web_search", "web_search:11"),
|
||||
]
|
||||
),
|
||||
_tool_msg("web_search:9", "web_search"),
|
||||
_tool_msg("web_search:10", "web_search"),
|
||||
_tool_msg("web_search:11", "web_search"),
|
||||
]
|
||||
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
_tool_msg("web_search:11", "web_search"),
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "web_search:11"
|
||||
assert patched[1].status == "success"
|
||||
assert isinstance(patched[3], ToolMessage)
|
||||
assert patched[3].tool_call_id == "web_search:11"
|
||||
assert patched[3].status == "error"
|
||||
|
||||
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
result = _tool_msg("web_search:11", "web_search")
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
||||
result,
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert patched[1] is result
|
||||
assert patched[1].status == "success"
|
||||
assert isinstance(patched[3], ToolMessage)
|
||||
assert patched[3].tool_call_id == "web_search:11"
|
||||
assert patched[3].status == "error"
|
||||
|
||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Regression tests for gateway config freshness on the request hot path.
|
||||
|
||||
Bytedance/deer-flow issue #3107 BUG-001: the worker and lead-agent path
|
||||
captured ``app.state.config`` at gateway startup. ``config.yaml`` edits during
|
||||
runtime were therefore ignored — ``get_app_config()``'s mtime-based reload
|
||||
existed but was bypassed because the snapshot object was passed through
|
||||
explicitly.
|
||||
|
||||
These tests pin the desired behaviour: a request-time ``get_config`` call must
|
||||
observe the most recent on-disk ``config.yaml`` (mtime reload), and the
|
||||
runtime ``ContextVar`` override must keep working for per-request injection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway import deps as gateway_deps
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import (
|
||||
AppConfig,
|
||||
pop_current_app_config,
|
||||
push_current_app_config,
|
||||
reset_app_config,
|
||||
set_app_config,
|
||||
)
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_app_config_singleton():
|
||||
"""Ensure each test starts with a clean module-level cache."""
|
||||
reset_app_config()
|
||||
yield
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _write_config_yaml(path: Path, *, log_level: str) -> None:
|
||||
path.write_text(
|
||||
f"""
|
||||
sandbox:
|
||||
use: deerflow.sandbox.local.provider:LocalSandboxProvider
|
||||
log_level: {log_level}
|
||||
""".strip()
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/probe")
|
||||
def probe(cfg: AppConfig = Depends(get_config)):
|
||||
return {"log_level": cfg.log_level}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_get_config_reflects_file_mtime_reload(tmp_path, monkeypatch):
|
||||
"""Editing config.yaml at runtime must be visible to /probe without restart.
|
||||
|
||||
This is the literal repro for the issue: the gateway must not freeze the
|
||||
config to whatever was on disk when the process started.
|
||||
"""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
_write_config_yaml(config_file, log_level="info")
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
|
||||
|
||||
app = _build_app()
|
||||
client = TestClient(app)
|
||||
assert client.get("/probe").json() == {"log_level": "info"}
|
||||
|
||||
# Edit the file and bump its mtime — simulating a maintainer changing
|
||||
# max_tokens / model settings in production while the gateway is live.
|
||||
_write_config_yaml(config_file, log_level="debug")
|
||||
future_mtime = config_file.stat().st_mtime + 5
|
||||
os.utime(config_file, (future_mtime, future_mtime))
|
||||
|
||||
assert client.get("/probe").json() == {"log_level": "debug"}
|
||||
|
||||
|
||||
def test_get_config_respects_runtime_context_override(tmp_path, monkeypatch):
|
||||
"""Per-request ``push_current_app_config`` injection must still win."""
|
||||
config_file = tmp_path / "config.yaml"
|
||||
_write_config_yaml(config_file, log_level="info")
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
|
||||
|
||||
override = AppConfig(sandbox=SandboxConfig(use="test"), log_level="trace")
|
||||
push_current_app_config(override)
|
||||
try:
|
||||
app = _build_app()
|
||||
client = TestClient(app)
|
||||
assert client.get("/probe").json() == {"log_level": "trace"}
|
||||
finally:
|
||||
pop_current_app_config()
|
||||
|
||||
|
||||
def test_get_config_respects_test_set_app_config():
|
||||
"""``set_app_config`` (used by upload/skills router tests) keeps working."""
|
||||
injected = AppConfig(sandbox=SandboxConfig(use="test"), log_level="warning")
|
||||
set_app_config(injected)
|
||||
|
||||
app = _build_app()
|
||||
client = TestClient(app)
|
||||
assert client.get("/probe").json() == {"log_level": "warning"}
|
||||
|
||||
|
||||
def test_run_context_app_config_reflects_yaml_edit(tmp_path, monkeypatch):
|
||||
"""``RunContext.app_config`` must follow live `config.yaml` edits.
|
||||
|
||||
BUG-001 review feedback: the run-context that feeds worker / lead-agent
|
||||
factories must observe the same mtime reload that `get_config()` does;
|
||||
otherwise stale config slips back in through the run path even after the
|
||||
request dependency is fixed.
|
||||
"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.gateway.deps import get_run_context
|
||||
|
||||
config_file = tmp_path / "config.yaml"
|
||||
_write_config_yaml(config_file, log_level="info")
|
||||
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
|
||||
|
||||
app = FastAPI()
|
||||
# Sentinel values for the rest of the RunContext wiring — we only care
|
||||
# about ``ctx.app_config`` for this assertion.
|
||||
app.state.checkpointer = MagicMock()
|
||||
app.state.store = MagicMock()
|
||||
app.state.run_event_store = MagicMock()
|
||||
app.state.run_events_config = {"frozen": "startup"}
|
||||
app.state.thread_store = MagicMock()
|
||||
|
||||
@app.get("/run-ctx-log-level")
|
||||
def probe(ctx=Depends(get_run_context)):
|
||||
return {
|
||||
"log_level": ctx.app_config.log_level,
|
||||
"run_events_config": ctx.run_events_config,
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
first = client.get("/run-ctx-log-level").json()
|
||||
assert first == {"log_level": "info", "run_events_config": {"frozen": "startup"}}
|
||||
|
||||
_write_config_yaml(config_file, log_level="debug")
|
||||
future_mtime = config_file.stat().st_mtime + 5
|
||||
os.utime(config_file, (future_mtime, future_mtime))
|
||||
|
||||
second = client.get("/run-ctx-log-level").json()
|
||||
# app_config follows the edit; run_events_config stays frozen to the
|
||||
# startup snapshot we wrote onto app.state above.
|
||||
assert second == {"log_level": "debug", "run_events_config": {"frozen": "startup"}}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception",
|
||||
[
|
||||
FileNotFoundError("config.yaml not found"),
|
||||
PermissionError("config.yaml not readable"),
|
||||
ValueError("invalid config"),
|
||||
RuntimeError("yaml parse error"),
|
||||
],
|
||||
)
|
||||
def test_get_config_returns_503_on_any_load_failure(monkeypatch, exception):
|
||||
"""Any failure to materialise the config must surface as 503, not 500.
|
||||
|
||||
Bytedance/deer-flow issue #3107 BUG-001 review: the original snapshot
|
||||
contract returned 503 when ``app.state.config is None``. The first cut of
|
||||
this fix only mapped ``FileNotFoundError`` to 503, which left
|
||||
``PermissionError`` / ``yaml.YAMLError`` / ``ValidationError`` etc. bubbling
|
||||
up as 500. Catch every load failure at the request boundary.
|
||||
"""
|
||||
|
||||
def _broken_get_app_config():
|
||||
raise exception
|
||||
|
||||
monkeypatch.setattr(gateway_deps, "get_app_config", _broken_get_app_config)
|
||||
|
||||
app = _build_app()
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/probe")
|
||||
|
||||
assert response.status_code == 503
|
||||
assert response.json() == {"detail": "Configuration not available"}
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def test_get_config_returns_app_state_config():
|
||||
"""get_config should return the exact AppConfig stored on app.state."""
|
||||
app = FastAPI()
|
||||
config = AppConfig(sandbox=SandboxConfig(use="test"))
|
||||
app.state.config = config
|
||||
|
||||
@app.get("/probe")
|
||||
def probe(cfg: AppConfig = Depends(get_config)):
|
||||
return {"same_identity": cfg is config, "log_level": cfg.log_level}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/probe")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"same_identity": True, "log_level": "info"}
|
||||
|
||||
|
||||
def test_get_config_reads_updated_app_state():
|
||||
"""Swapping app.state.config should be visible to the dependency."""
|
||||
app = FastAPI()
|
||||
app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
|
||||
|
||||
@app.get("/log-level")
|
||||
def log_level(cfg: AppConfig = Depends(get_config)):
|
||||
return {"level": cfg.log_level}
|
||||
|
||||
client = TestClient(app)
|
||||
assert client.get("/log-level").json() == {"level": "info"}
|
||||
|
||||
app.state.config = app.state.config.model_copy(update={"log_level": "debug"})
|
||||
assert client.get("/log-level").json() == {"level": "debug"}
|
||||
@@ -17,7 +17,7 @@ from fastapi import FastAPI
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _noop_langgraph_runtime(_app):
|
||||
async def _noop_langgraph_runtime(_app, _startup_config):
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -81,6 +81,94 @@ def test_normalize_input_passthrough():
|
||||
assert result == {"custom_key": "value"}
|
||||
|
||||
|
||||
def test_normalize_input_preserves_additional_kwargs_and_id():
|
||||
"""Regression: gh #3132 — frontend ships uploaded-file metadata in
|
||||
additional_kwargs.files (and a client-side message id). The gateway must
|
||||
not strip them before the graph runs, otherwise UploadsMiddleware reports
|
||||
"(empty)" for new uploads and the frontend message loses its file chip.
|
||||
"""
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
files = [{"filename": "a.csv", "size": 100, "path": "/mnt/user-data/uploads/a.csv", "status": "uploaded"}]
|
||||
result = normalize_input(
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"type": "human",
|
||||
"id": "client-msg-1",
|
||||
"name": "user-input",
|
||||
"content": [{"type": "text", "text": "clean it"}],
|
||||
"additional_kwargs": {"files": files, "custom": "keep-me"},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
assert len(result["messages"]) == 1
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, HumanMessage)
|
||||
assert msg.id == "client-msg-1"
|
||||
assert msg.name == "user-input"
|
||||
assert msg.content == [{"type": "text", "text": "clean it"}]
|
||||
assert msg.additional_kwargs == {"files": files, "custom": "keep-me"}
|
||||
|
||||
|
||||
def test_normalize_input_passes_through_basemessage_instances():
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
msg = HumanMessage(content="hello", id="m-1", additional_kwargs={"files": [{"filename": "x"}]})
|
||||
result = normalize_input({"messages": [msg]})
|
||||
assert result["messages"][0] is msg
|
||||
|
||||
|
||||
def test_normalize_input_rejects_malformed_message_with_400():
|
||||
"""Boundary validation: ``convert_to_messages`` raises ``ValueError`` when a
|
||||
message dict is missing ``role``/``type``/``content``. ``normalize_input``
|
||||
runs inside the gateway HTTP boundary, so a malformed payload should surface
|
||||
as a 400 referencing the offending entry — not bubble up as a 500.
|
||||
|
||||
Raised after the Copilot review on PR #3136.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
normalize_input({"messages": [{"role": "human", "content": "ok"}, {"oops": "no role here"}]})
|
||||
assert excinfo.value.status_code == 400
|
||||
assert "input.messages[1]" in excinfo.value.detail
|
||||
|
||||
|
||||
def test_normalize_input_handles_non_human_roles():
|
||||
"""The previous implementation collapsed every role to HumanMessage with a
|
||||
`# TODO: handle other message types` comment. Resuming a thread with prior
|
||||
AI/tool messages would silently rewrite them as human turns — corrupting
|
||||
the conversation. Use langchain's standard conversion so ai/system/tool
|
||||
roles round-trip correctly.
|
||||
"""
|
||||
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
|
||||
|
||||
from app.gateway.services import normalize_input
|
||||
|
||||
result = normalize_input(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "ai", "content": "hi", "id": "ai-1"},
|
||||
{"role": "tool", "content": "result", "tool_call_id": "call-1"},
|
||||
]
|
||||
}
|
||||
)
|
||||
types = [type(m) for m in result["messages"]]
|
||||
assert types == [SystemMessage, AIMessage, ToolMessage]
|
||||
assert result["messages"][1].id == "ai-1"
|
||||
assert result["messages"][2].tool_call_id == "call-1"
|
||||
|
||||
|
||||
def test_build_run_config_basic():
|
||||
from app.gateway.services import build_run_config
|
||||
|
||||
|
||||
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
)
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||
# verify the custom middleware is injected correctly.
|
||||
# Chain tail order after the custom middleware is:
|
||||
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
|
||||
# so the custom mock sits at index [-3].
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
|
||||
|
||||
|
||||
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
"""Tests for the MCP persistent-session pool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_pool():
|
||||
reset_session_pool()
|
||||
yield
|
||||
reset_session_pool()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPSessionPool unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_creates_new():
|
||||
"""First call for a key creates a new session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert session is mock_session
|
||||
mock_session.initialize.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_reuses_existing():
|
||||
"""Second call for the same key returns the cached session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert s1 is s2
|
||||
# Only one session should have been created.
|
||||
assert mock_cm.__aenter__.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_scope_creates_different_session():
|
||||
"""Different scope keys get different sessions."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
sessions = [AsyncMock(), AsyncMock()]
|
||||
idx = 0
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.enter_count = 0
|
||||
|
||||
async def __aenter__(self):
|
||||
nonlocal idx
|
||||
s = sessions[idx]
|
||||
idx += 1
|
||||
self.enter_count += 1
|
||||
return s
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return False
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()):
|
||||
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert s1 is not s2
|
||||
assert s1 is sessions[0]
|
||||
assert s2 is sessions[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_eviction():
|
||||
"""Oldest entries are evicted when the pool is full."""
|
||||
pool = MCPSessionPool()
|
||||
pool.MAX_SESSIONS = 2
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
# Pool is full (2). Adding t3 should evict t1.
|
||||
await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
assert cms[2].closed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_scope():
|
||||
"""close_scope shuts down sessions for a specific scope key."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
await pool.close_scope("t1")
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
|
||||
# t2 session still exists.
|
||||
assert ("s", "t2") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all():
|
||||
"""close_all shuts down every session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
await pool.close_all()
|
||||
|
||||
assert all(cm.closed for cm in cms)
|
||||
assert len(pool._entries) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_session_pool_singleton():
|
||||
"""get_session_pool returns the same instance."""
|
||||
p1 = get_session_pool()
|
||||
p2 = get_session_pool()
|
||||
assert p1 is p2
|
||||
|
||||
|
||||
def test_reset_session_pool():
|
||||
"""reset_session_pool clears the singleton."""
|
||||
p1 = get_session_pool()
|
||||
reset_session_pool()
|
||||
p2 = get_session_pool()
|
||||
assert p1 is not p2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: _make_session_pool_tool uses the pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_wrapping():
|
||||
"""The wrapper tool delegates to a pool-managed session."""
|
||||
# Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
url: str = Field(..., description="url")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="playwright_navigate",
|
||||
description="Navigate browser",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
connection = {"transport": "stdio", "command": "pw", "args": []}
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
|
||||
|
||||
# Simulate a tool call with a runtime context containing thread_id.
|
||||
mock_runtime = MagicMock()
|
||||
mock_runtime.context = {"thread_id": "thread-42"}
|
||||
mock_runtime.config = {}
|
||||
|
||||
await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")
|
||||
|
||||
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_extracts_thread_id():
|
||||
"""Thread ID is extracted from runtime.config when not in context."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
mock_runtime = MagicMock()
|
||||
mock_runtime.context = {}
|
||||
mock_runtime.config = {"configurable": {"thread_id": "from-config"}}
|
||||
|
||||
await wrapped.coroutine(runtime=mock_runtime, x=1)
|
||||
|
||||
# Verify the session was created with the correct scope key.
|
||||
pool = get_session_pool()
|
||||
assert ("server", "from-config") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_default_scope():
|
||||
"""When no thread_id is available, 'default' is used as scope key."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
# No thread_id in runtime at all.
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
pool = get_session_pool()
|
||||
assert ("server", "default") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_get_config_fallback():
|
||||
"""When runtime is None, get_config() provides thread_id as fallback."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
fake_config = {"configurable": {"thread_id": "from-langgraph-config"}}
|
||||
|
||||
with (
|
||||
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
|
||||
patch("deerflow.mcp.tools.get_config", return_value=fake_config),
|
||||
):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
# runtime=None — get_config() fallback should provide thread_id
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
pool = get_session_pool()
|
||||
assert ("server", "from-langgraph-config") in pool._entries
|
||||
|
||||
|
||||
def test_session_pool_tool_sync_wrapper_path_is_safe():
|
||||
"""Sync wrapper (tool.func) invocation doesn't crash on cross-loop access."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
class Args(BaseModel):
|
||||
url: str = Field(..., description="url")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="playwright_navigate",
|
||||
description="Navigate browser",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
connection = {"transport": "stdio", "command": "pw", "args": []}
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
|
||||
# Attach the sync wrapper exactly as get_mcp_tools() does.
|
||||
wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name)
|
||||
|
||||
# Call via the sync path (asyncio.run in a worker thread).
|
||||
# runtime is not supplied so _extract_thread_id falls back to "default".
|
||||
wrapped.func(url="https://example.com")
|
||||
|
||||
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
|
||||
@@ -714,6 +714,110 @@ class TestExternalUsageRecords:
|
||||
assert j._subagent_tokens == 0
|
||||
|
||||
|
||||
class TestProgressSnapshots:
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_reports_progress_snapshot(self):
|
||||
snapshots: list[dict] = []
|
||||
|
||||
async def reporter(snapshot: dict) -> None:
|
||||
snapshots.append(snapshot)
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal(
|
||||
"r1",
|
||||
"t1",
|
||||
store,
|
||||
flush_threshold=100,
|
||||
progress_reporter=reporter,
|
||||
progress_flush_interval=0,
|
||||
)
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
|
||||
assert snapshots
|
||||
assert snapshots[-1]["total_tokens"] == 15
|
||||
assert snapshots[-1]["llm_call_count"] == 1
|
||||
assert snapshots[-1]["message_count"] == 1
|
||||
assert snapshots[-1]["last_ai_message"] == "Answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
|
||||
snapshots: list[dict] = []
|
||||
trailing_seen = asyncio.Event()
|
||||
|
||||
async def reporter(snapshot: dict) -> None:
|
||||
snapshots.append(snapshot)
|
||||
if snapshot["total_tokens"] == 45:
|
||||
trailing_seen.set()
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal(
|
||||
"r1",
|
||||
"t1",
|
||||
store,
|
||||
flush_threshold=100,
|
||||
progress_reporter=reporter,
|
||||
progress_flush_interval=0.01,
|
||||
)
|
||||
j.on_llm_end(
|
||||
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
j.on_llm_end(
|
||||
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
|
||||
await j.flush()
|
||||
|
||||
assert len(snapshots) >= 2
|
||||
assert snapshots[-1]["total_tokens"] == 45
|
||||
assert snapshots[-1]["llm_call_count"] == 2
|
||||
assert snapshots[-1]["last_ai_message"] == "Second"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
|
||||
snapshots: list[dict] = []
|
||||
|
||||
async def reporter(snapshot: dict) -> None:
|
||||
snapshots.append(snapshot)
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal(
|
||||
"r1",
|
||||
"t1",
|
||||
store,
|
||||
flush_threshold=100,
|
||||
progress_reporter=reporter,
|
||||
progress_flush_interval=10.0,
|
||||
)
|
||||
j.on_llm_end(
|
||||
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
assert snapshots[-1]["total_tokens"] == 15
|
||||
j.on_llm_end(
|
||||
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
|
||||
await asyncio.wait_for(j.flush(), timeout=0.2)
|
||||
|
||||
assert snapshots[-1]["total_tokens"] == 15
|
||||
assert snapshots[-1]["llm_call_count"] == 1
|
||||
assert snapshots[-1]["last_ai_message"] == "First"
|
||||
|
||||
|
||||
class TestChatModelStartHumanMessage:
|
||||
"""Tests for on_chat_model_start extracting the first human message."""
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
@@ -26,6 +27,42 @@ async def _cleanup():
|
||||
await close_engine()
|
||||
|
||||
|
||||
class _CustomRunStoreWithoutProgress(RunStore):
|
||||
async def put(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def get(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def list_by_thread(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
async def update_status(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def delete(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def update_model_name(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def update_run_completion(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def list_pending(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
async def aggregate_tokens_by_thread(self, *args, **kwargs):
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_defaults_to_noop_for_custom_store():
|
||||
store = _CustomRunStoreWithoutProgress()
|
||||
|
||||
await store.update_run_progress("r1", total_tokens=1)
|
||||
|
||||
|
||||
class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_put_and_get(self, tmp_path):
|
||||
@@ -170,6 +207,69 @@ class TestRunRepository:
|
||||
assert row["total_tokens"] == 100
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_keeps_status_running(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_progress(
|
||||
"r1",
|
||||
total_input_tokens=40,
|
||||
total_output_tokens=10,
|
||||
total_tokens=50,
|
||||
llm_call_count=1,
|
||||
message_count=2,
|
||||
last_ai_message="partial answer",
|
||||
)
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "running"
|
||||
assert row["total_tokens"] == 50
|
||||
assert row["llm_call_count"] == 1
|
||||
assert row["message_count"] == 2
|
||||
assert row["last_ai_message"] == "partial answer"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_progress(
|
||||
"r1",
|
||||
total_input_tokens=40,
|
||||
total_output_tokens=10,
|
||||
total_tokens=50,
|
||||
llm_call_count=1,
|
||||
lead_agent_tokens=30,
|
||||
subagent_tokens=20,
|
||||
message_count=2,
|
||||
)
|
||||
|
||||
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
|
||||
|
||||
row = await repo.get("r1")
|
||||
assert row["total_input_tokens"] == 40
|
||||
assert row["total_output_tokens"] == 10
|
||||
assert row["total_tokens"] == 60
|
||||
assert row["llm_call_count"] == 1
|
||||
assert row["lead_agent_tokens"] == 30
|
||||
assert row["subagent_tokens"] == 20
|
||||
assert row["message_count"] == 2
|
||||
assert row["last_ai_message"] == "updated"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
|
||||
|
||||
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
|
||||
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "success"
|
||||
assert row["total_tokens"] == 100
|
||||
assert row["llm_call_count"] == 1
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
@@ -225,6 +325,28 @@ class TestRunRepository:
|
||||
}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("success-run", thread_id="t1", status="running")
|
||||
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
|
||||
await repo.put("running-run", thread_id="t1", status="running")
|
||||
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
|
||||
|
||||
without_active = await repo.aggregate_tokens_by_thread("t1")
|
||||
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
|
||||
|
||||
assert without_active["total_tokens"] == 100
|
||||
assert without_active["total_runs"] == 1
|
||||
assert with_active["total_tokens"] == 125
|
||||
assert with_active["total_runs"] == 2
|
||||
assert with_active["by_caller"] == {
|
||||
"lead_agent": 120,
|
||||
"subagent": 5,
|
||||
"middleware": 0,
|
||||
}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
||||
"""list_by_thread returns newest first."""
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
|
||||
|
||||
Unit tests prove ``_apply`` does the right thing on a synthetic state.
|
||||
This test does one level up: builds a real ``langchain.agents.create_agent``
|
||||
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
|
||||
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
|
||||
|
||||
1. The tool node is **not** invoked (the dangerous truncated tool call
|
||||
is suppressed).
|
||||
2. The final AIMessage in graph state has ``tool_calls == []``.
|
||||
3. The observability ``safety_termination`` record is attached.
|
||||
4. The user-facing explanation is appended to the message content.
|
||||
|
||||
This is the closest we can get to the issue's failure mode without a live
|
||||
Moonshot key, and it proves the middleware actually gates LangChain's
|
||||
tool router — not just rewrites state in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@tool
|
||||
def write_file(path: str, content: str) -> str:
|
||||
"""Pretend to write *content* to *path*. Records the call for assertion."""
|
||||
_TOOL_INVOCATIONS.append({"path": path, "content": content})
|
||||
return f"wrote {len(content)} bytes to {path}"
|
||||
|
||||
|
||||
class _ContentFilteredModel(BaseChatModel):
|
||||
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
|
||||
|
||||
First call returns finish_reason='content_filter' + a tool_call whose
|
||||
arguments are visibly truncated. Second call (if reached) returns a
|
||||
normal text completion so the agent can terminate cleanly.
|
||||
"""
|
||||
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-content-filtered"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
# create_agent binds tools onto the model; we don't actually need
|
||||
# to bind anything since responses are hard-coded, but the method
|
||||
# must not raise.
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
message = AIMessage(
|
||||
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_truncated_1",
|
||||
"name": "write_file",
|
||||
"args": {
|
||||
"path": "/mnt/user-data/outputs/report.md",
|
||||
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
|
||||
)
|
||||
else:
|
||||
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
class _InspectMiddleware(AgentMiddleware):
|
||||
"""Captures the messages list at every model entry so we can assert
|
||||
no synthetic tool result was injected back into the conversation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.observed: list[list[Any]] = []
|
||||
|
||||
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
|
||||
self.observed.append(list(request.messages))
|
||||
return handler(request)
|
||||
|
||||
|
||||
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
inspector = _InspectMiddleware()
|
||||
|
||||
agent = create_agent(
|
||||
model=_ContentFilteredModel(),
|
||||
tools=[write_file],
|
||||
# Inspector first so its after_model is registered; Safety last in
|
||||
# the list so it executes first under LIFO (matches production wiring).
|
||||
middleware=[inspector, SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
|
||||
|
||||
# Critical assertion: the dangerous truncated tool call must NOT have
|
||||
# been executed. This is the entire point of the middleware.
|
||||
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
|
||||
|
||||
# Final AIMessage has no tool calls left.
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
assert final_ai.tool_calls == []
|
||||
|
||||
# Observability stamp is present.
|
||||
record = final_ai.additional_kwargs.get("safety_termination")
|
||||
assert record is not None
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 1
|
||||
assert record["suppressed_tool_call_names"] == ["write_file"]
|
||||
|
||||
# User-facing explanation is appended.
|
||||
assert "safety-related signal" in final_ai.content
|
||||
# Original partial text preserved (we don't throw away what the user
|
||||
# already saw in the stream — see middleware docstring).
|
||||
assert "Weekly Politics" in final_ai.content
|
||||
|
||||
# finish_reason on response_metadata is preserved (so SSE / converters
|
||||
# downstream still see the real provider reason).
|
||||
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
|
||||
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through_unchanged():
|
||||
"""No tool calls => issue scope says don't intervene; the partial
|
||||
response should be delivered as-is so the user sees what they got."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _NoToolModel(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-no-tool"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
msg = AIMessage(
|
||||
content="Partial answer truncated by safety filter",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_NoToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
|
||||
# Content untouched.
|
||||
assert final_ai.content == "Partial answer truncated by safety filter"
|
||||
# No safety_termination stamp because we didn't intervene.
|
||||
assert "safety_termination" not in final_ai.additional_kwargs
|
||||
# tool node never ran (there were no tool calls in the first place).
|
||||
assert _TOOL_INVOCATIONS == []
|
||||
|
||||
|
||||
def test_normal_tool_call_round_trip_is_not_affected():
|
||||
"""Regression: a healthy finish_reason='tool_calls' response must still
|
||||
execute the tool. The middleware must not over-fire."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _HealthyToolModel(BaseChatModel):
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-healthy"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_ok",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/tmp/ok", "content": "complete content"},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_HealthyToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
agent.invoke({"messages": [HumanMessage(content="write")]})
|
||||
|
||||
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Unit tests for SafetyFinishReasonMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
)
|
||||
from deerflow.config.safety_finish_reason_config import (
|
||||
SafetyDetectorConfig,
|
||||
SafetyFinishReasonConfig,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id="t-1"):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
return runtime
|
||||
|
||||
|
||||
def _ai(
|
||||
*,
|
||||
content="",
|
||||
tool_calls=None,
|
||||
response_metadata=None,
|
||||
additional_kwargs=None,
|
||||
):
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
def _write_call(idx=1, content_text="半截"):
|
||||
return {
|
||||
"id": f"call_write_{idx}",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
|
||||
}
|
||||
|
||||
|
||||
class AlwaysHitDetector:
|
||||
"""Test fixture: always reports the given termination."""
|
||||
|
||||
name = "always_hit"
|
||||
|
||||
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
|
||||
self.reason_field = reason_field
|
||||
self.reason_value = reason_value
|
||||
self.extras = extras or {}
|
||||
|
||||
def detect(self, message):
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field=self.reason_field,
|
||||
reason_value=self.reason_value,
|
||||
extras=self.extras,
|
||||
)
|
||||
|
||||
|
||||
class NeverHitDetector:
|
||||
name = "never_hit"
|
||||
|
||||
def detect(self, message):
|
||||
return None
|
||||
|
||||
|
||||
class RaisingDetector:
|
||||
name = "raising"
|
||||
|
||||
def detect(self, message):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core trigger behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriggerCriteria:
|
||||
def test_content_filter_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through(self):
|
||||
"""issue scope: when there are no tool calls the partial text is a
|
||||
legitimate final response and should not be rewritten."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial response",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_tool_calls_pass_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_stop_with_tool_calls_pass_through(self):
|
||||
# Some providers report finish_reason='stop' for tool-call messages.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "stop"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_empty_message_list_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
assert mw._apply({"messages": []}, _runtime()) is None
|
||||
|
||||
def test_non_ai_last_message_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_anthropic_refusal_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"stop_reason": "refusal"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_gemini_safety_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "SAFETY"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message rewriting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageRewrite:
|
||||
def test_clears_structured_tool_calls(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_clears_raw_additional_kwargs_tool_calls(self):
|
||||
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
|
||||
tool calls from additional_kwargs.tool_calls if we forget them, which
|
||||
would re-emit a synthetic ToolMessage downstream and confuse the
|
||||
model. We must wipe both."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
raw_tool_calls = [
|
||||
{
|
||||
"id": "call_write_1",
|
||||
"type": "function",
|
||||
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
|
||||
}
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"tool_calls": raw_tool_calls,
|
||||
"function_call": {"name": "write_file", "arguments": "{}"},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert "tool_calls" not in patched.additional_kwargs
|
||||
assert "function_call" not in patched.additional_kwargs
|
||||
|
||||
def test_preserves_other_additional_kwargs(self):
|
||||
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
|
||||
# may carry other provider-specific keys. They must not be wiped.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"reasoning": "thinking text",
|
||||
"custom_provider_field": {"x": 1},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["reasoning"] == "thinking text"
|
||||
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
|
||||
|
||||
def test_writes_observability_field(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
record = patched.additional_kwargs["safety_termination"]
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 2
|
||||
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
|
||||
|
||||
def test_preserves_response_metadata_finish_reason(self):
|
||||
"""Downstream SSE converters read response_metadata.finish_reason —
|
||||
we want them to see the *real* provider reason, not 'stop'."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.response_metadata["finish_reason"] == "content_filter"
|
||||
assert patched.response_metadata["model_name"] == "kimi-k2"
|
||||
|
||||
def test_appends_user_facing_explanation_to_str_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="some partial text",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert patched.content.startswith("some partial text")
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_empty_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_list_content_thinking_blocks(self):
|
||||
"""Anthropic thinking / vLLM reasoning models emit content blocks.
|
||||
Naively concatenating a string would raise TypeError."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
thinking_blocks = [
|
||||
{"type": "thinking", "text": "let me consider..."},
|
||||
{"type": "text", "text": "partial answer"},
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content=thinking_blocks,
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, list)
|
||||
assert patched.content[:2] == thinking_blocks
|
||||
assert patched.content[-1]["type"] == "text"
|
||||
assert "safety-related signal" in patched.content[-1]["text"]
|
||||
|
||||
def test_idempotent_on_already_cleared_message(self):
|
||||
# Re-running the middleware on a message we already cleared must not
|
||||
# re-trigger (tool_calls is now empty → fast passthrough).
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
first = mw._apply(state, _runtime())
|
||||
state2 = {"messages": [first["messages"][0]]}
|
||||
second = mw._apply(state2, _runtime())
|
||||
assert second is None
|
||||
|
||||
def test_preserves_message_id_for_add_messages_replacement(self):
|
||||
"""LangGraph's add_messages reducer treats same-id messages as
|
||||
replacements. model_copy keeps id by default."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
original = _ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
# AIMessage auto-generates id; capture it
|
||||
original_id = original.id
|
||||
state = {"messages": [original]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.id == original_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detector wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectorWiring:
|
||||
def test_iterates_detectors_in_order(self):
|
||||
first = AlwaysHitDetector(reason_value="first")
|
||||
second = AlwaysHitDetector(reason_value="second")
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
|
||||
|
||||
def test_returns_none_when_no_detector_matches(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_buggy_detector_does_not_break_run(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
|
||||
|
||||
def test_constructor_copies_detectors(self):
|
||||
"""Caller mutation after construction must not leak into us."""
|
||||
detectors = [AlwaysHitDetector()]
|
||||
mw = SafetyFinishReasonMiddleware(detectors=detectors)
|
||||
detectors.clear()
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# from_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFromConfig:
|
||||
def test_default_config_uses_builtin_detectors(self):
|
||||
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
|
||||
assert len(mw._detectors) == 3
|
||||
names = {d.name for d in mw._detectors}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_custom_detectors_loaded_via_reflection(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[
|
||||
SafetyDetectorConfig(
|
||||
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
|
||||
config={"finish_reasons": ["custom_filter"]},
|
||||
),
|
||||
]
|
||||
)
|
||||
mw = SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
assert len(mw._detectors) == 1
|
||||
# Confirm the kwargs propagated.
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "custom_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
# Default token no longer matches.
|
||||
state2 = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state2, _runtime()) is None
|
||||
|
||||
def test_empty_detector_list_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(detectors=[])
|
||||
with pytest.raises(ValueError, match="enabled=false"):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
def test_non_detector_class_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[SafetyDetectorConfig(use="builtins:dict")],
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuditEvent:
|
||||
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
|
||||
audit event via RunJournal.record_middleware when the run-scoped journal is
|
||||
exposed under runtime.context["__run_journal"].
|
||||
|
||||
Background: review on PR #3035 — SSE custom event handles live consumers,
|
||||
but post-run audit needs a row in run_events that can be queried with one
|
||||
SQL statement (no JOIN against message body).
|
||||
"""
|
||||
|
||||
def _runtime_with_journal(self, journal):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
|
||||
return runtime
|
||||
|
||||
def test_records_audit_event_when_journal_present(self):
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
tc = _write_call(1)
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
|
||||
journal.record_middleware.assert_called_once()
|
||||
call = journal.record_middleware.call_args
|
||||
# tag is positional or kwarg depending on call style; we use kwargs.
|
||||
assert call.kwargs["tag"] == "safety_termination"
|
||||
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
|
||||
assert call.kwargs["hook"] == "after_model"
|
||||
assert call.kwargs["action"] == "suppress_tool_calls"
|
||||
|
||||
changes = call.kwargs["changes"]
|
||||
assert changes["detector"] == "openai_compatible_content_filter"
|
||||
assert changes["reason_field"] == "finish_reason"
|
||||
assert changes["reason_value"] == "content_filter"
|
||||
assert changes["suppressed_tool_call_count"] == 1
|
||||
assert changes["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
|
||||
assert "message_id" in changes
|
||||
assert isinstance(changes["extras"], dict)
|
||||
|
||||
def test_audit_event_never_carries_tool_arguments(self):
|
||||
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
|
||||
and must NOT be persisted to run_events under any circumstance."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
sensitive_tc = {
|
||||
"id": "call_x",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
|
||||
}
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[sensitive_tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, self._runtime_with_journal(journal))
|
||||
flat = repr(journal.record_middleware.call_args)
|
||||
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
|
||||
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
|
||||
|
||||
def test_no_journal_in_runtime_context_is_silently_skipped(self):
|
||||
"""Subagent runtime / unit tests / no-event-store paths have no journal.
|
||||
Middleware must still intervene and clear tool_calls — only the audit
|
||||
event is skipped."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-noj"} # no __run_journal
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise; should still clear tool_calls.
|
||||
result = mw._apply(state, runtime)
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_journal_record_exception_does_not_break_run(self):
|
||||
"""Buggy journal must never propagate an exception into the agent loop."""
|
||||
journal = MagicMock()
|
||||
journal.record_middleware.side_effect = RuntimeError("db down")
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Must not raise.
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_no_record_when_passthrough(self):
|
||||
"""When the middleware does NOT intervene, no audit event is written."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"}, # healthy
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, self._runtime_with_journal(journal)) is None
|
||||
journal.record_middleware.assert_not_called()
|
||||
|
||||
|
||||
class TestStreamEvent:
|
||||
def test_emits_event_when_writer_available(self, monkeypatch):
|
||||
captured: list = []
|
||||
|
||||
def fake_writer(payload):
|
||||
captured.append(payload)
|
||||
|
||||
# Patch get_stream_writer at the symbol-resolution site.
|
||||
import langgraph.config
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, _runtime("t-stream"))
|
||||
|
||||
assert len(captured) == 1
|
||||
payload = captured[0]
|
||||
assert payload["type"] == "safety_termination"
|
||||
assert payload["detector"] == "openai_compatible_content_filter"
|
||||
assert payload["reason_field"] == "finish_reason"
|
||||
assert payload["reason_value"] == "content_filter"
|
||||
assert payload["suppressed_tool_call_count"] == 1
|
||||
assert payload["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert payload["thread_id"] == "t-stream"
|
||||
|
||||
def test_writer_unavailable_does_not_break(self, monkeypatch):
|
||||
import langgraph.config
|
||||
|
||||
def boom():
|
||||
raise LookupError("not in a stream context")
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise.
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Unit tests for SafetyTerminationDetector built-ins."""
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
AnthropicRefusalDetector,
|
||||
GeminiSafetyDetector,
|
||||
OpenAICompatibleContentFilterDetector,
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
|
||||
|
||||
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAICompatibleContentFilterDetector:
|
||||
def test_default_matches_content_filter(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
assert hit.detector == "openai_compatible_content_filter"
|
||||
assert hit.reason_field == "finish_reason"
|
||||
assert hit.reason_value == "content_filter"
|
||||
|
||||
def test_case_insensitive_match(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
|
||||
|
||||
def test_other_finish_reasons_pass_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
|
||||
|
||||
def test_missing_metadata_passes_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai()) is None
|
||||
|
||||
def test_non_string_finish_reason_passes_through(self):
|
||||
# Some adapters may stash an enum or dict — must not raise.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
|
||||
|
||||
def test_falls_back_to_additional_kwargs(self):
|
||||
# Legacy adapters surface finish_reason via additional_kwargs.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
|
||||
def test_configurable_extra_values(self):
|
||||
# Chinese providers sometimes use bespoke tokens.
|
||||
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
|
||||
# Original token still matches.
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
|
||||
|
||||
def test_carries_azure_content_filter_results(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
filter_results = {"hate": {"filtered": True, "severity": "high"}}
|
||||
hit = d.detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "content_filter",
|
||||
"content_filter_results": filter_results,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["content_filter_results"] == filter_results
|
||||
|
||||
|
||||
class TestAnthropicRefusalDetector:
|
||||
def test_default_matches_refusal(self):
|
||||
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
|
||||
assert hit is not None
|
||||
assert hit.reason_field == "stop_reason"
|
||||
assert hit.reason_value == "refusal"
|
||||
|
||||
def test_other_stop_reasons_pass_through(self):
|
||||
d = AnthropicRefusalDetector()
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
|
||||
|
||||
def test_anthropic_does_not_steal_finish_reason(self):
|
||||
# An OpenAI message must not accidentally trip the Anthropic detector.
|
||||
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
|
||||
|
||||
|
||||
class TestGeminiSafetyDetector:
|
||||
def test_default_set_covers_documented_reasons(self):
|
||||
d = GeminiSafetyDetector()
|
||||
for reason in (
|
||||
# text safety
|
||||
"SAFETY",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"RECITATION",
|
||||
# image safety
|
||||
"IMAGE_SAFETY",
|
||||
"IMAGE_PROHIBITED_CONTENT",
|
||||
"IMAGE_RECITATION",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
|
||||
|
||||
def test_normal_termination_passes_through(self):
|
||||
d = GeminiSafetyDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
|
||||
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
|
||||
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
|
||||
# excluded from the default set — they are either normal termination,
|
||||
# capability mismatches, too broad (OTHER), or tool-call protocol
|
||||
# errors. See GeminiSafetyDetector docstring.
|
||||
for reason in (
|
||||
"MAX_TOKENS",
|
||||
"LANGUAGE",
|
||||
"NO_IMAGE",
|
||||
"OTHER",
|
||||
"IMAGE_OTHER",
|
||||
"MALFORMED_FUNCTION_CALL",
|
||||
"UNEXPECTED_TOOL_CALL",
|
||||
"FINISH_REASON_UNSPECIFIED",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
|
||||
|
||||
def test_carries_safety_ratings(self):
|
||||
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
|
||||
hit = GeminiSafetyDetector().detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "SAFETY",
|
||||
"safety_ratings": ratings,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["safety_ratings"] == ratings
|
||||
|
||||
|
||||
class TestDefaultDetectorSet:
|
||||
def test_default_set_returns_three_detectors(self):
|
||||
dets = default_detectors()
|
||||
names = {d.name for d in dets}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_default_set_returns_fresh_list(self):
|
||||
# Caller mutation must not affect later calls.
|
||||
first = default_detectors()
|
||||
first.clear()
|
||||
second = default_detectors()
|
||||
assert len(second) == 3
|
||||
|
||||
|
||||
class TestProtocolConformance:
|
||||
def test_builtins_satisfy_protocol(self):
|
||||
for d in default_detectors():
|
||||
assert isinstance(d, SafetyTerminationDetector)
|
||||
|
||||
def test_safety_termination_is_frozen(self):
|
||||
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
|
||||
try:
|
||||
t.detector = "y" # type: ignore[misc]
|
||||
except Exception:
|
||||
return
|
||||
raise AssertionError("SafetyTermination should be frozen")
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.sandbox.exceptions import SandboxError
|
||||
from deerflow.sandbox.tools import (
|
||||
VIRTUAL_PATH_PREFIX,
|
||||
_apply_cwd_prefix,
|
||||
@@ -1140,6 +1141,170 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
|
||||
assert sandbox.content == "ALPHA\ntail\n"
|
||||
|
||||
|
||||
def test_write_file_tool_bounds_large_oserror_and_masks_local_paths(monkeypatch) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-large-oserror"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
host_path = f"{_THREAD_DATA['workspace_path']}/nested/output.txt"
|
||||
raise OSError(f"write failed at {host_path}\n{'A' * 12000}\nremote tail marker")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: True)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.get_thread_data", lambda runtime: _THREAD_DATA)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.validate_local_tool_path", lambda path, thread_data: None)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools._resolve_and_validate_user_data_path",
|
||||
lambda path, thread_data: f"{_THREAD_DATA['workspace_path']}/output.txt",
|
||||
)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="写入大文件失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="report body",
|
||||
)
|
||||
|
||||
assert len(result) <= 2000
|
||||
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/workspace" not in result
|
||||
assert "/mnt/user-data/workspace/nested/output.txt" in result
|
||||
assert "remote tail marker" in result
|
||||
assert "[write_file error truncated:" in result
|
||||
|
||||
|
||||
def test_write_file_tool_preserves_short_oserror_without_truncation(monkeypatch) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-short-oserror"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
raise OSError("disk quota exceeded")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="写入失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert result == "Error: Failed to write file '/mnt/user-data/workspace/output.txt': OSError: disk quota exceeded"
|
||||
assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_write_file_tool_bounds_large_sandbox_error(monkeypatch) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-large-sandbox-error"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
raise SandboxError(f"remote write rejected {'B' * 12000} final detail")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="远端写入失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert len(result) <= 2000
|
||||
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
|
||||
assert "SandboxError: remote write rejected" in result
|
||||
assert "final detail" in result
|
||||
assert "[write_file error truncated:" in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raised_error", "expected_fragment"),
|
||||
[
|
||||
pytest.param(
|
||||
PermissionError("permission denied"),
|
||||
"Error: Permission denied writing to file: /mnt/user-data/workspace/output.txt",
|
||||
id="permission",
|
||||
),
|
||||
pytest.param(
|
||||
IsADirectoryError("target is a directory"),
|
||||
"Error: Path is a directory, not a file: /mnt/user-data/workspace/output.txt",
|
||||
id="directory",
|
||||
),
|
||||
pytest.param(
|
||||
Exception("remote sandbox timeout"),
|
||||
"Exception: remote sandbox timeout",
|
||||
id="generic",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_write_file_tool_formats_all_other_failure_branches(
|
||||
monkeypatch,
|
||||
raised_error: Exception,
|
||||
expected_fragment: str,
|
||||
) -> None:
|
||||
class FailingSandbox:
|
||||
id = "sandbox-write-other-failure"
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
raise raised_error
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
sandbox = FailingSandbox()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="验证错误分支格式化",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert "/mnt/user-data/workspace/output.txt" in result
|
||||
assert expected_fragment in result
|
||||
assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_write_file_tool_handles_sandbox_init_failure(monkeypatch) -> None:
|
||||
"""Regression for #3133 review: SandboxError raised during sandbox
|
||||
initialization (before the local `requested_path` assignment) must still
|
||||
surface as a bounded tool error rather than an UnboundLocalError.
|
||||
"""
|
||||
|
||||
def raise_sandbox_error(runtime):
|
||||
raise SandboxError("sandbox missing")
|
||||
|
||||
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", raise_sandbox_error)
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
|
||||
|
||||
result = write_file_tool.func(
|
||||
runtime=runtime,
|
||||
description="sandbox 初始化失败",
|
||||
path="/mnt/user-data/workspace/output.txt",
|
||||
content="tiny payload",
|
||||
)
|
||||
|
||||
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
|
||||
assert "SandboxError: sandbox missing" in result
|
||||
assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_file_operation_lock_memory_cleanup() -> None:
|
||||
"""Verify that released locks are eventually cleaned up by WeakValueDictionary.
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from types import SimpleNamespace
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.routers import skills as skills_router
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.types import Skill
|
||||
@@ -38,7 +39,8 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
|
||||
|
||||
def _make_test_app(config) -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.state.config = config
|
||||
app.state.config = config # kept for any startup-style reads
|
||||
app.dependency_overrides[get_config] = lambda: config
|
||||
app.include_router(skills_router.router)
|
||||
return app
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Regression tests for _find_usage_recorder callback shape handling.
|
||||
|
||||
Bytedance issue #3107 BUG-002: When LangChain passes ``config["callbacks"]`` as
|
||||
an ``AsyncCallbackManager`` (instead of a plain list), the previous
|
||||
``for cb in callbacks`` loop raised ``TypeError: 'AsyncCallbackManager' object
|
||||
is not iterable``. ToolErrorHandlingMiddleware then converted the entire ``task``
|
||||
tool call into an error ToolMessage, losing the subagent result.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackManager, CallbackManager
|
||||
|
||||
from deerflow.tools.builtins.task_tool import _find_usage_recorder
|
||||
|
||||
|
||||
class _RecorderHandler:
|
||||
def record_external_llm_usage_records(self, records):
|
||||
self.records = records
|
||||
|
||||
|
||||
class _OtherHandler:
|
||||
pass
|
||||
|
||||
|
||||
def _make_runtime(callbacks):
|
||||
return SimpleNamespace(config={"callbacks": callbacks})
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_plain_list():
|
||||
recorder = _RecorderHandler()
|
||||
runtime = _make_runtime([_OtherHandler(), recorder])
|
||||
assert _find_usage_recorder(runtime) is recorder
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_async_callback_manager():
|
||||
"""LangChain wraps callbacks in AsyncCallbackManager for async tool runs.
|
||||
|
||||
The old implementation raised TypeError here. The recorder lives on
|
||||
``manager.handlers``; we must look there too.
|
||||
"""
|
||||
recorder = _RecorderHandler()
|
||||
manager = AsyncCallbackManager(handlers=[_OtherHandler(), recorder])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is recorder
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_sync_callback_manager():
|
||||
"""Sync flavor of the same wrapper used by some langchain code paths."""
|
||||
recorder = _RecorderHandler()
|
||||
manager = CallbackManager(handlers=[recorder])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is recorder
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_no_recorder():
|
||||
manager = AsyncCallbackManager(handlers=[_OtherHandler()])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_handles_empty_manager():
|
||||
manager = AsyncCallbackManager(handlers=[])
|
||||
runtime = _make_runtime(manager)
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_for_none_runtime():
|
||||
assert _find_usage_recorder(None) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_callbacks_is_none():
|
||||
runtime = _make_runtime(None)
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_for_single_handler_object():
|
||||
"""A single handler instance (not wrapped in a list or manager) should not crash.
|
||||
|
||||
LangChain's contract is that ``config["callbacks"]`` is a list-or-manager,
|
||||
but we treat any other shape defensively rather than letting a ``for`` loop
|
||||
blow up at runtime.
|
||||
"""
|
||||
runtime = _make_runtime(_RecorderHandler())
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_config_not_dict():
|
||||
"""Defensive: a runtime without a dict-shaped config should not raise."""
|
||||
runtime = SimpleNamespace(config="not-a-dict")
|
||||
assert _find_usage_recorder(runtime) is None
|
||||
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
|
||||
},
|
||||
}
|
||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
|
||||
|
||||
|
||||
def test_thread_token_usage_can_include_active_runs():
|
||||
run_store = MagicMock()
|
||||
run_store.aggregate_tokens_by_thread = AsyncMock(
|
||||
return_value={
|
||||
"total_tokens": 175,
|
||||
"total_input_tokens": 120,
|
||||
"total_output_tokens": 55,
|
||||
"total_runs": 3,
|
||||
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
|
||||
"by_caller": {
|
||||
"lead_agent": 145,
|
||||
"subagent": 25,
|
||||
"middleware": 5,
|
||||
},
|
||||
},
|
||||
)
|
||||
app = _make_app(run_store)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["total_tokens"] == 175
|
||||
assert response.json()["total_runs"] == 3
|
||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
|
||||
|
||||
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(middlewares) == 6
|
||||
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
|
||||
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
|
||||
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
|
||||
# (enabled by default — see SafetyFinishReasonConfig).
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
assert len(middlewares) == 7
|
||||
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
|
||||
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_on_success():
|
||||
|
||||
@@ -11,6 +11,7 @@ from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.routers import uploads
|
||||
|
||||
|
||||
@@ -687,6 +688,7 @@ def test_upload_limits_endpoint_requires_thread_access():
|
||||
cfg.uploads = {}
|
||||
app = make_authed_test_app(owner_check_passes=False)
|
||||
app.state.config = cfg
|
||||
app.dependency_overrides[get_config] = lambda: cfg
|
||||
app.include_router(uploads.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
|
||||
Reference in New Issue
Block a user