Merge branch 'main' into fix-3127

This commit is contained in:
Willem Jiang
2026-05-22 21:56:04 +08:00
committed by GitHub
57 changed files with 4981 additions and 195 deletions
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_tool_msg("web_search:11", "web_search"),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "web_search:11"
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
result = _tool_msg("web_search:11", "web_search")
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
result,
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1] is result
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [
@@ -0,0 +1,189 @@
"""Regression tests for gateway config freshness on the request hot path.
Bytedance/deer-flow issue #3107 BUG-001: the worker and lead-agent path
captured ``app.state.config`` at gateway startup. ``config.yaml`` edits during
runtime were therefore ignored — ``get_app_config()``'s mtime-based reload
existed but was bypassed because the snapshot object was passed through
explicitly.
These tests pin the desired behaviour: a request-time ``get_config`` call must
observe the most recent on-disk ``config.yaml`` (mtime reload), and the
runtime ``ContextVar`` override must keep working for per-request injection.
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway import deps as gateway_deps
from app.gateway.deps import get_config
from deerflow.config.app_config import (
AppConfig,
pop_current_app_config,
push_current_app_config,
reset_app_config,
set_app_config,
)
from deerflow.config.sandbox_config import SandboxConfig
@pytest.fixture(autouse=True)
def _isolate_app_config_singleton():
"""Ensure each test starts with a clean module-level cache."""
reset_app_config()
yield
reset_app_config()
def _write_config_yaml(path: Path, *, log_level: str) -> None:
path.write_text(
f"""
sandbox:
use: deerflow.sandbox.local.provider:LocalSandboxProvider
log_level: {log_level}
""".strip()
+ "\n",
encoding="utf-8",
)
def _build_app() -> FastAPI:
app = FastAPI()
@app.get("/probe")
def probe(cfg: AppConfig = Depends(get_config)):
return {"log_level": cfg.log_level}
return app
def test_get_config_reflects_file_mtime_reload(tmp_path, monkeypatch):
"""Editing config.yaml at runtime must be visible to /probe without restart.
This is the literal repro for the issue: the gateway must not freeze the
config to whatever was on disk when the process started.
"""
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "info"}
# Edit the file and bump its mtime — simulating a maintainer changing
# max_tokens / model settings in production while the gateway is live.
_write_config_yaml(config_file, log_level="debug")
future_mtime = config_file.stat().st_mtime + 5
os.utime(config_file, (future_mtime, future_mtime))
assert client.get("/probe").json() == {"log_level": "debug"}
def test_get_config_respects_runtime_context_override(tmp_path, monkeypatch):
"""Per-request ``push_current_app_config`` injection must still win."""
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
override = AppConfig(sandbox=SandboxConfig(use="test"), log_level="trace")
push_current_app_config(override)
try:
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "trace"}
finally:
pop_current_app_config()
def test_get_config_respects_test_set_app_config():
"""``set_app_config`` (used by upload/skills router tests) keeps working."""
injected = AppConfig(sandbox=SandboxConfig(use="test"), log_level="warning")
set_app_config(injected)
app = _build_app()
client = TestClient(app)
assert client.get("/probe").json() == {"log_level": "warning"}
def test_run_context_app_config_reflects_yaml_edit(tmp_path, monkeypatch):
"""``RunContext.app_config`` must follow live `config.yaml` edits.
BUG-001 review feedback: the run-context that feeds worker / lead-agent
factories must observe the same mtime reload that `get_config()` does;
otherwise stale config slips back in through the run path even after the
request dependency is fixed.
"""
from unittest.mock import MagicMock
from app.gateway.deps import get_run_context
config_file = tmp_path / "config.yaml"
_write_config_yaml(config_file, log_level="info")
monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file))
app = FastAPI()
# Sentinel values for the rest of the RunContext wiring — we only care
# about ``ctx.app_config`` for this assertion.
app.state.checkpointer = MagicMock()
app.state.store = MagicMock()
app.state.run_event_store = MagicMock()
app.state.run_events_config = {"frozen": "startup"}
app.state.thread_store = MagicMock()
@app.get("/run-ctx-log-level")
def probe(ctx=Depends(get_run_context)):
return {
"log_level": ctx.app_config.log_level,
"run_events_config": ctx.run_events_config,
}
client = TestClient(app)
first = client.get("/run-ctx-log-level").json()
assert first == {"log_level": "info", "run_events_config": {"frozen": "startup"}}
_write_config_yaml(config_file, log_level="debug")
future_mtime = config_file.stat().st_mtime + 5
os.utime(config_file, (future_mtime, future_mtime))
second = client.get("/run-ctx-log-level").json()
# app_config follows the edit; run_events_config stays frozen to the
# startup snapshot we wrote onto app.state above.
assert second == {"log_level": "debug", "run_events_config": {"frozen": "startup"}}
@pytest.mark.parametrize(
"exception",
[
FileNotFoundError("config.yaml not found"),
PermissionError("config.yaml not readable"),
ValueError("invalid config"),
RuntimeError("yaml parse error"),
],
)
def test_get_config_returns_503_on_any_load_failure(monkeypatch, exception):
"""Any failure to materialise the config must surface as 503, not 500.
Bytedance/deer-flow issue #3107 BUG-001 review: the original snapshot
contract returned 503 when ``app.state.config is None``. The first cut of
this fix only mapped ``FileNotFoundError`` to 503, which left
``PermissionError`` / ``yaml.YAMLError`` / ``ValidationError`` etc. bubbling
up as 500. Catch every load failure at the request boundary.
"""
def _broken_get_app_config():
raise exception
monkeypatch.setattr(gateway_deps, "get_app_config", _broken_get_app_config)
app = _build_app()
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/probe")
assert response.status_code == 503
assert response.json() == {"detail": "Configuration not available"}
-41
View File
@@ -1,41 +0,0 @@
from __future__ import annotations
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.sandbox_config import SandboxConfig
def test_get_config_returns_app_state_config():
"""get_config should return the exact AppConfig stored on app.state."""
app = FastAPI()
config = AppConfig(sandbox=SandboxConfig(use="test"))
app.state.config = config
@app.get("/probe")
def probe(cfg: AppConfig = Depends(get_config)):
return {"same_identity": cfg is config, "log_level": cfg.log_level}
client = TestClient(app)
response = client.get("/probe")
assert response.status_code == 200
assert response.json() == {"same_identity": True, "log_level": "info"}
def test_get_config_reads_updated_app_state():
"""Swapping app.state.config should be visible to the dependency."""
app = FastAPI()
app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")
@app.get("/log-level")
def log_level(cfg: AppConfig = Depends(get_config)):
return {"level": cfg.log_level}
client = TestClient(app)
assert client.get("/log-level").json() == {"level": "info"}
app.state.config = app.state.config.model_copy(update={"log_level": "debug"})
assert client.get("/log-level").json() == {"level": "debug"}
@@ -17,7 +17,7 @@ from fastapi import FastAPI
@asynccontextmanager
async def _noop_langgraph_runtime(_app):
async def _noop_langgraph_runtime(_app, _startup_config):
yield
+88
View File
@@ -81,6 +81,94 @@ def test_normalize_input_passthrough():
assert result == {"custom_key": "value"}
def test_normalize_input_preserves_additional_kwargs_and_id():
"""Regression: gh #3132 — frontend ships uploaded-file metadata in
additional_kwargs.files (and a client-side message id). The gateway must
not strip them before the graph runs, otherwise UploadsMiddleware reports
"(empty)" for new uploads and the frontend message loses its file chip.
"""
from langchain_core.messages import HumanMessage
from app.gateway.services import normalize_input
files = [{"filename": "a.csv", "size": 100, "path": "/mnt/user-data/uploads/a.csv", "status": "uploaded"}]
result = normalize_input(
{
"messages": [
{
"type": "human",
"id": "client-msg-1",
"name": "user-input",
"content": [{"type": "text", "text": "clean it"}],
"additional_kwargs": {"files": files, "custom": "keep-me"},
}
]
}
)
assert len(result["messages"]) == 1
msg = result["messages"][0]
assert isinstance(msg, HumanMessage)
assert msg.id == "client-msg-1"
assert msg.name == "user-input"
assert msg.content == [{"type": "text", "text": "clean it"}]
assert msg.additional_kwargs == {"files": files, "custom": "keep-me"}
def test_normalize_input_passes_through_basemessage_instances():
from langchain_core.messages import HumanMessage
from app.gateway.services import normalize_input
msg = HumanMessage(content="hello", id="m-1", additional_kwargs={"files": [{"filename": "x"}]})
result = normalize_input({"messages": [msg]})
assert result["messages"][0] is msg
def test_normalize_input_rejects_malformed_message_with_400():
"""Boundary validation: ``convert_to_messages`` raises ``ValueError`` when a
message dict is missing ``role``/``type``/``content``. ``normalize_input``
runs inside the gateway HTTP boundary, so a malformed payload should surface
as a 400 referencing the offending entry — not bubble up as a 500.
Raised after the Copilot review on PR #3136.
"""
import pytest
from fastapi import HTTPException
from app.gateway.services import normalize_input
with pytest.raises(HTTPException) as excinfo:
normalize_input({"messages": [{"role": "human", "content": "ok"}, {"oops": "no role here"}]})
assert excinfo.value.status_code == 400
assert "input.messages[1]" in excinfo.value.detail
def test_normalize_input_handles_non_human_roles():
"""The previous implementation collapsed every role to HumanMessage with a
`# TODO: handle other message types` comment. Resuming a thread with prior
AI/tool messages would silently rewrite them as human turns — corrupting
the conversation. Use langchain's standard conversion so ai/system/tool
roles round-trip correctly.
"""
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from app.gateway.services import normalize_input
result = normalize_input(
{
"messages": [
{"role": "system", "content": "sys"},
{"role": "ai", "content": "hi", "id": "ai-1"},
{"role": "tool", "content": "result", "tool_call_id": "call-1"},
]
}
)
types = [type(m) for m in result["messages"]]
assert types == [SystemMessage, AIMessage, ToolMessage]
assert result["messages"][1].id == "ai-1"
assert result["messages"][2].tool_call_id == "call-1"
def test_build_run_config_basic():
from app.gateway.services import build_run_config
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
# verify the custom middleware is injected correctly.
# Chain tail order after the custom middleware is:
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
# so the custom mock sits at index [-3].
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
+409
View File
@@ -0,0 +1,409 @@
"""Tests for the MCP persistent-session pool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
@pytest.fixture(autouse=True)
def _reset_pool():
reset_session_pool()
yield
reset_session_pool()
# ---------------------------------------------------------------------------
# MCPSessionPool unit tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_session_creates_new():
"""First call for a key creates a new session."""
pool = MCPSessionPool()
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
assert session is mock_session
mock_session.initialize.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_session_reuses_existing():
"""Second call for the same key returns the cached session."""
pool = MCPSessionPool()
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
assert s1 is s2
# Only one session should have been created.
assert mock_cm.__aenter__.await_count == 1
@pytest.mark.asyncio
async def test_different_scope_creates_different_session():
"""Different scope keys get different sessions."""
pool = MCPSessionPool()
sessions = [AsyncMock(), AsyncMock()]
idx = 0
class CmFactory:
def __init__(self):
self.enter_count = 0
async def __aenter__(self):
nonlocal idx
s = sessions[idx]
idx += 1
self.enter_count += 1
return s
async def __aexit__(self, *args):
return False
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()):
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []})
assert s1 is not s2
assert s1 is sessions[0]
assert s2 is sessions[1]
@pytest.mark.asyncio
async def test_lru_eviction():
"""Oldest entries are evicted when the pool is full."""
pool = MCPSessionPool()
pool.MAX_SESSIONS = 2
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
# Pool is full (2). Adding t3 should evict t1.
await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []})
assert cms[0].closed is True
assert cms[1].closed is False
assert cms[2].closed is False
@pytest.mark.asyncio
async def test_close_scope():
"""close_scope shuts down sessions for a specific scope key."""
pool = MCPSessionPool()
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
await pool.close_scope("t1")
assert cms[0].closed is True
assert cms[1].closed is False
# t2 session still exists.
assert ("s", "t2") in pool._entries
@pytest.mark.asyncio
async def test_close_all():
"""close_all shuts down every session."""
pool = MCPSessionPool()
class CmFactory:
def __init__(self):
self.closed = False
async def __aenter__(self):
return AsyncMock()
async def __aexit__(self, *args):
self.closed = True
return False
cms: list[CmFactory] = []
def make_cm(*a, **kw):
cm = CmFactory()
cms.append(cm)
return cm
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []})
await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []})
await pool.close_all()
assert all(cm.closed for cm in cms)
assert len(pool._entries) == 0
# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------
def test_get_session_pool_singleton():
"""get_session_pool returns the same instance."""
p1 = get_session_pool()
p2 = get_session_pool()
assert p1 is p2
def test_reset_session_pool():
"""reset_session_pool clears the singleton."""
p1 = get_session_pool()
reset_session_pool()
p2 = get_session_pool()
assert p1 is not p2
# ---------------------------------------------------------------------------
# Integration: _make_session_pool_tool uses the pool
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_pool_tool_wrapping():
"""The wrapper tool delegates to a pool-managed session."""
# Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
url: str = Field(..., description="url")
original_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
connection = {"transport": "stdio", "command": "pw", "args": []}
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
# Simulate a tool call with a runtime context containing thread_id.
mock_runtime = MagicMock()
mock_runtime.context = {"thread_id": "thread-42"}
mock_runtime.config = {}
await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
"""Thread ID is extracted from runtime.config when not in context."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
mock_runtime = MagicMock()
mock_runtime.context = {}
mock_runtime.config = {"configurable": {"thread_id": "from-config"}}
await wrapped.coroutine(runtime=mock_runtime, x=1)
# Verify the session was created with the correct scope key.
pool = get_session_pool()
assert ("server", "from-config") in pool._entries
@pytest.mark.asyncio
async def test_session_pool_tool_default_scope():
"""When no thread_id is available, 'default' is used as scope key."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
# No thread_id in runtime at all.
await wrapped.coroutine(runtime=None, x=1)
pool = get_session_pool()
assert ("server", "default") in pool._entries
@pytest.mark.asyncio
async def test_session_pool_tool_get_config_fallback():
"""When runtime is None, get_config() provides thread_id as fallback."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="server_tool",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
fake_config = {"configurable": {"thread_id": "from-langgraph-config"}}
with (
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
patch("deerflow.mcp.tools.get_config", return_value=fake_config),
):
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
# runtime=None — get_config() fallback should provide thread_id
await wrapped.coroutine(runtime=None, x=1)
pool = get_session_pool()
assert ("server", "from-langgraph-config") in pool._entries
def test_session_pool_tool_sync_wrapper_path_is_safe():
"""Sync wrapper (tool.func) invocation doesn't crash on cross-loop access."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
from deerflow.tools.sync import make_sync_tool_wrapper
class Args(BaseModel):
url: str = Field(..., description="url")
original_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
connection = {"transport": "stdio", "command": "pw", "args": []}
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
# Attach the sync wrapper exactly as get_mcp_tools() does.
wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name)
# Call via the sync path (asyncio.run in a worker thread).
# runtime is not supplied so _extract_thread_id falls back to "default".
wrapped.func(url="https://example.com")
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
+104
View File
@@ -714,6 +714,110 @@ class TestExternalUsageRecords:
assert j._subagent_tokens == 0
class TestProgressSnapshots:
@pytest.mark.anyio
async def test_on_llm_end_reports_progress_snapshot(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0,
)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
await j.flush()
assert snapshots
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["message_count"] == 1
assert snapshots[-1]["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
snapshots: list[dict] = []
trailing_seen = asyncio.Event()
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
if snapshot["total_tokens"] == 45:
trailing_seen.set()
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0.01,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
await j.flush()
assert len(snapshots) >= 2
assert snapshots[-1]["total_tokens"] == 45
assert snapshots[-1]["llm_call_count"] == 2
assert snapshots[-1]["last_ai_message"] == "Second"
@pytest.mark.anyio
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=10.0,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.sleep(0)
assert snapshots[-1]["total_tokens"] == 15
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(j.flush(), timeout=0.2)
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["last_ai_message"] == "First"
class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message."""
+122
View File
@@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository
from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.base import RunStore
async def _make_repo(tmp_path):
@@ -26,6 +27,42 @@ async def _cleanup():
await close_engine()
class _CustomRunStoreWithoutProgress(RunStore):
async def put(self, *args, **kwargs):
return None
async def get(self, *args, **kwargs):
return None
async def list_by_thread(self, *args, **kwargs):
return []
async def update_status(self, *args, **kwargs):
return None
async def delete(self, *args, **kwargs):
return None
async def update_model_name(self, *args, **kwargs):
return None
async def update_run_completion(self, *args, **kwargs):
return None
async def list_pending(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
store = _CustomRunStoreWithoutProgress()
await store.update_run_progress("r1", total_tokens=1)
class TestRunRepository:
@pytest.mark.anyio
async def test_put_and_get(self, tmp_path):
@@ -170,6 +207,69 @@ class TestRunRepository:
assert row["total_tokens"] == 100
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
message_count=2,
last_ai_message="partial answer",
)
row = await repo.get("r1")
assert row["status"] == "running"
assert row["total_tokens"] == 50
assert row["llm_call_count"] == 1
assert row["message_count"] == 2
assert row["last_ai_message"] == "partial answer"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
lead_agent_tokens=30,
subagent_tokens=20,
message_count=2,
)
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
row = await repo.get("r1")
assert row["total_input_tokens"] == 40
assert row["total_output_tokens"] == 10
assert row["total_tokens"] == 60
assert row["llm_call_count"] == 1
assert row["lead_agent_tokens"] == 30
assert row["subagent_tokens"] == 20
assert row["message_count"] == 2
assert row["last_ai_message"] == "updated"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
row = await repo.get("r1")
assert row["status"] == "success"
assert row["total_tokens"] == 100
assert row["llm_call_count"] == 1
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -225,6 +325,28 @@ class TestRunRepository:
}
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
without_active = await repo.aggregate_tokens_by_thread("t1")
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
assert without_active["total_tokens"] == 100
assert without_active["total_runs"] == 1
assert with_active["total_tokens"] == 125
assert with_active["total_runs"] == 2
assert with_active["by_caller"] == {
"lead_agent": 120,
"subagent": 5,
"middleware": 0,
}
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first."""
@@ -0,0 +1,225 @@
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
Unit tests prove ``_apply`` does the right thing on a synthetic state.
This test does one level up: builds a real ``langchain.agents.create_agent``
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
1. The tool node is **not** invoked (the dangerous truncated tool call
is suppressed).
2. The final AIMessage in graph state has ``tool_calls == []``.
3. The observability ``safety_termination`` record is attached.
4. The user-facing explanation is appended to the message content.
This is the closest we can get to the issue's failure mode without a live
Moonshot key, and it proves the middleware actually gates LangChain's
tool router — not just rewrites state in isolation.
"""
from __future__ import annotations
from typing import Any
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import tool
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
@tool
def write_file(path: str, content: str) -> str:
"""Pretend to write *content* to *path*. Records the call for assertion."""
_TOOL_INVOCATIONS.append({"path": path, "content": content})
return f"wrote {len(content)} bytes to {path}"
class _ContentFilteredModel(BaseChatModel):
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
First call returns finish_reason='content_filter' + a tool_call whose
arguments are visibly truncated. Second call (if reached) returns a
normal text completion so the agent can terminate cleanly.
"""
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-content-filtered"
def bind_tools(self, tools, **kwargs):
# create_agent binds tools onto the model; we don't actually need
# to bind anything since responses are hard-coded, but the method
# must not raise.
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
message = AIMessage(
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
tool_calls=[
{
"id": "call_truncated_1",
"name": "write_file",
"args": {
"path": "/mnt/user-data/outputs/report.md",
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
},
}
],
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
)
else:
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
class _InspectMiddleware(AgentMiddleware):
"""Captures the messages list at every model entry so we can assert
no synthetic tool result was injected back into the conversation."""
def __init__(self) -> None:
super().__init__()
self.observed: list[list[Any]] = []
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
self.observed.append(list(request.messages))
return handler(request)
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
_TOOL_INVOCATIONS.clear()
inspector = _InspectMiddleware()
agent = create_agent(
model=_ContentFilteredModel(),
tools=[write_file],
# Inspector first so its after_model is registered; Safety last in
# the list so it executes first under LIFO (matches production wiring).
middleware=[inspector, SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
# Critical assertion: the dangerous truncated tool call must NOT have
# been executed. This is the entire point of the middleware.
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
# Final AIMessage has no tool calls left.
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
assert final_ai.tool_calls == []
# Observability stamp is present.
record = final_ai.additional_kwargs.get("safety_termination")
assert record is not None
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 1
assert record["suppressed_tool_call_names"] == ["write_file"]
# User-facing explanation is appended.
assert "safety-related signal" in final_ai.content
# Original partial text preserved (we don't throw away what the user
# already saw in the stream — see middleware docstring).
assert "Weekly Politics" in final_ai.content
# finish_reason on response_metadata is preserved (so SSE / converters
# downstream still see the real provider reason).
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
def test_content_filter_without_tool_calls_passes_through_unchanged():
"""No tool calls => issue scope says don't intervene; the partial
response should be delivered as-is so the user sees what they got."""
_TOOL_INVOCATIONS.clear()
class _NoToolModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-no-tool"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
msg = AIMessage(
content="Partial answer truncated by safety filter",
response_metadata={"finish_reason": "content_filter"},
)
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_NoToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
# Content untouched.
assert final_ai.content == "Partial answer truncated by safety filter"
# No safety_termination stamp because we didn't intervene.
assert "safety_termination" not in final_ai.additional_kwargs
# tool node never ran (there were no tool calls in the first place).
assert _TOOL_INVOCATIONS == []
def test_normal_tool_call_round_trip_is_not_affected():
"""Regression: a healthy finish_reason='tool_calls' response must still
execute the tool. The middleware must not over-fire."""
_TOOL_INVOCATIONS.clear()
class _HealthyToolModel(BaseChatModel):
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-healthy"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
msg = AIMessage(
content="",
tool_calls=[
{
"id": "call_ok",
"name": "write_file",
"args": {"path": "/tmp/ok", "content": "complete content"},
}
],
response_metadata={"finish_reason": "tool_calls"},
)
else:
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_HealthyToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
agent.invoke({"messages": [HumanMessage(content="write")]})
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
@@ -0,0 +1,651 @@
"""Unit tests for SafetyFinishReasonMiddleware."""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
)
from deerflow.config.safety_finish_reason_config import (
SafetyDetectorConfig,
SafetyFinishReasonConfig,
)
def _runtime(thread_id="t-1"):
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
return runtime
def _ai(
*,
content="",
tool_calls=None,
response_metadata=None,
additional_kwargs=None,
):
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
def _write_call(idx=1, content_text="半截"):
return {
"id": f"call_write_{idx}",
"name": "write_file",
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
}
class AlwaysHitDetector:
"""Test fixture: always reports the given termination."""
name = "always_hit"
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
self.reason_field = reason_field
self.reason_value = reason_value
self.extras = extras or {}
def detect(self, message):
return SafetyTermination(
detector=self.name,
reason_field=self.reason_field,
reason_value=self.reason_value,
extras=self.extras,
)
class NeverHitDetector:
name = "never_hit"
def detect(self, message):
return None
class RaisingDetector:
name = "raising"
def detect(self, message):
raise RuntimeError("boom")
# ---------------------------------------------------------------------------
# Core trigger behaviour
# ---------------------------------------------------------------------------
class TestTriggerCriteria:
def test_content_filter_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
patched = result["messages"][0]
assert patched.tool_calls == []
def test_content_filter_without_tool_calls_passes_through(self):
"""issue scope: when there are no tool calls the partial text is a
legitimate final response and should not be rewritten."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial response",
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_tool_calls_pass_through(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_stop_with_tool_calls_pass_through(self):
# Some providers report finish_reason='stop' for tool-call messages.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "stop"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_empty_message_list_passes_through(self):
mw = SafetyFinishReasonMiddleware()
assert mw._apply({"messages": []}, _runtime()) is None
def test_non_ai_last_message_passes_through(self):
mw = SafetyFinishReasonMiddleware()
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
assert mw._apply(state, _runtime()) is None
def test_anthropic_refusal_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"stop_reason": "refusal"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
def test_gemini_safety_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "SAFETY"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
# ---------------------------------------------------------------------------
# Message rewriting
# ---------------------------------------------------------------------------
class TestMessageRewrite:
def test_clears_structured_tool_calls(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert patched.tool_calls == []
def test_clears_raw_additional_kwargs_tool_calls(self):
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
tool calls from additional_kwargs.tool_calls if we forget them, which
would re-emit a synthetic ToolMessage downstream and confuse the
model. We must wipe both."""
mw = SafetyFinishReasonMiddleware()
raw_tool_calls = [
{
"id": "call_write_1",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
}
]
state = {
"messages": [
_ai(
tool_calls=[_write_call(1)],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"tool_calls": raw_tool_calls,
"function_call": {"name": "write_file", "arguments": "{}"},
},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert "tool_calls" not in patched.additional_kwargs
assert "function_call" not in patched.additional_kwargs
def test_preserves_other_additional_kwargs(self):
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
# may carry other provider-specific keys. They must not be wiped.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"reasoning": "thinking text",
"custom_provider_field": {"x": 1},
},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["reasoning"] == "thinking text"
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
def test_writes_observability_field(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
record = patched.additional_kwargs["safety_termination"]
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 2
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
def test_preserves_response_metadata_finish_reason(self):
"""Downstream SSE converters read response_metadata.finish_reason —
we want them to see the *real* provider reason, not 'stop'."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.response_metadata["finish_reason"] == "content_filter"
assert patched.response_metadata["model_name"] == "kimi-k2"
def test_appends_user_facing_explanation_to_str_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="some partial text",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert patched.content.startswith("some partial text")
assert "safety-related signal" in patched.content
def test_handles_empty_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert "safety-related signal" in patched.content
def test_handles_list_content_thinking_blocks(self):
"""Anthropic thinking / vLLM reasoning models emit content blocks.
Naively concatenating a string would raise TypeError."""
mw = SafetyFinishReasonMiddleware()
thinking_blocks = [
{"type": "thinking", "text": "let me consider..."},
{"type": "text", "text": "partial answer"},
]
state = {
"messages": [
_ai(
content=thinking_blocks,
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, list)
assert patched.content[:2] == thinking_blocks
assert patched.content[-1]["type"] == "text"
assert "safety-related signal" in patched.content[-1]["text"]
def test_idempotent_on_already_cleared_message(self):
# Re-running the middleware on a message we already cleared must not
# re-trigger (tool_calls is now empty → fast passthrough).
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
first = mw._apply(state, _runtime())
state2 = {"messages": [first["messages"][0]]}
second = mw._apply(state2, _runtime())
assert second is None
def test_preserves_message_id_for_add_messages_replacement(self):
"""LangGraph's add_messages reducer treats same-id messages as
replacements. model_copy keeps id by default."""
mw = SafetyFinishReasonMiddleware()
original = _ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
# AIMessage auto-generates id; capture it
original_id = original.id
state = {"messages": [original]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.id == original_id
# ---------------------------------------------------------------------------
# Detector wiring
# ---------------------------------------------------------------------------
class TestDetectorWiring:
def test_iterates_detectors_in_order(self):
first = AlwaysHitDetector(reason_value="first")
second = AlwaysHitDetector(reason_value="second")
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
def test_returns_none_when_no_detector_matches(self):
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_buggy_detector_does_not_break_run(self):
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
def test_constructor_copies_detectors(self):
"""Caller mutation after construction must not leak into us."""
detectors = [AlwaysHitDetector()]
mw = SafetyFinishReasonMiddleware(detectors=detectors)
detectors.clear()
state = {"messages": [_ai(tool_calls=[_write_call()])]}
assert mw._apply(state, _runtime()) is not None
# ---------------------------------------------------------------------------
# from_config
# ---------------------------------------------------------------------------
class TestFromConfig:
def test_default_config_uses_builtin_detectors(self):
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
assert len(mw._detectors) == 3
names = {d.name for d in mw._detectors}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_custom_detectors_loaded_via_reflection(self):
cfg = SafetyFinishReasonConfig(
detectors=[
SafetyDetectorConfig(
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
config={"finish_reasons": ["custom_filter"]},
),
]
)
mw = SafetyFinishReasonMiddleware.from_config(cfg)
assert len(mw._detectors) == 1
# Confirm the kwargs propagated.
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "custom_filter"},
)
]
}
assert mw._apply(state, _runtime()) is not None
# Default token no longer matches.
state2 = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state2, _runtime()) is None
def test_empty_detector_list_rejected(self):
cfg = SafetyFinishReasonConfig(detectors=[])
with pytest.raises(ValueError, match="enabled=false"):
SafetyFinishReasonMiddleware.from_config(cfg)
def test_non_detector_class_rejected(self):
cfg = SafetyFinishReasonConfig(
detectors=[SafetyDetectorConfig(use="builtins:dict")],
)
with pytest.raises(TypeError):
SafetyFinishReasonMiddleware.from_config(cfg)
# ---------------------------------------------------------------------------
# Stream event
# ---------------------------------------------------------------------------
class TestAuditEvent:
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
audit event via RunJournal.record_middleware when the run-scoped journal is
exposed under runtime.context["__run_journal"].
Background: review on PR #3035 — SSE custom event handles live consumers,
but post-run audit needs a row in run_events that can be queried with one
SQL statement (no JOIN against message body).
"""
def _runtime_with_journal(self, journal):
runtime = MagicMock()
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
return runtime
def test_records_audit_event_when_journal_present(self):
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
tc = _write_call(1)
state = {
"messages": [
_ai(
content="partial",
tool_calls=[tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
journal.record_middleware.assert_called_once()
call = journal.record_middleware.call_args
# tag is positional or kwarg depending on call style; we use kwargs.
assert call.kwargs["tag"] == "safety_termination"
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
assert call.kwargs["hook"] == "after_model"
assert call.kwargs["action"] == "suppress_tool_calls"
changes = call.kwargs["changes"]
assert changes["detector"] == "openai_compatible_content_filter"
assert changes["reason_field"] == "finish_reason"
assert changes["reason_value"] == "content_filter"
assert changes["suppressed_tool_call_count"] == 1
assert changes["suppressed_tool_call_names"] == ["write_file"]
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
assert "message_id" in changes
assert isinstance(changes["extras"], dict)
def test_audit_event_never_carries_tool_arguments(self):
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
and must NOT be persisted to run_events under any circumstance."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
sensitive_tc = {
"id": "call_x",
"name": "write_file",
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
}
state = {
"messages": [
_ai(
tool_calls=[sensitive_tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, self._runtime_with_journal(journal))
flat = repr(journal.record_middleware.call_args)
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
def test_no_journal_in_runtime_context_is_silently_skipped(self):
"""Subagent runtime / unit tests / no-event-store paths have no journal.
Middleware must still intervene and clear tool_calls — only the audit
event is skipped."""
mw = SafetyFinishReasonMiddleware()
runtime = MagicMock()
runtime.context = {"thread_id": "t-noj"} # no __run_journal
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise; should still clear tool_calls.
result = mw._apply(state, runtime)
assert result is not None
assert result["messages"][0].tool_calls == []
def test_journal_record_exception_does_not_break_run(self):
"""Buggy journal must never propagate an exception into the agent loop."""
journal = MagicMock()
journal.record_middleware.side_effect = RuntimeError("db down")
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Must not raise.
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
assert result["messages"][0].tool_calls == []
def test_no_record_when_passthrough(self):
"""When the middleware does NOT intervene, no audit event is written."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"}, # healthy
)
]
}
assert mw._apply(state, self._runtime_with_journal(journal)) is None
journal.record_middleware.assert_not_called()
class TestStreamEvent:
def test_emits_event_when_writer_available(self, monkeypatch):
captured: list = []
def fake_writer(payload):
captured.append(payload)
# Patch get_stream_writer at the symbol-resolution site.
import langgraph.config
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, _runtime("t-stream"))
assert len(captured) == 1
payload = captured[0]
assert payload["type"] == "safety_termination"
assert payload["detector"] == "openai_compatible_content_filter"
assert payload["reason_field"] == "finish_reason"
assert payload["reason_value"] == "content_filter"
assert payload["suppressed_tool_call_count"] == 1
assert payload["suppressed_tool_call_names"] == ["write_file"]
assert payload["thread_id"] == "t-stream"
def test_writer_unavailable_does_not_break(self, monkeypatch):
import langgraph.config
def boom():
raise LookupError("not in a stream context")
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise.
result = mw._apply(state, _runtime())
assert result is not None
@@ -0,0 +1,176 @@
"""Unit tests for SafetyTerminationDetector built-ins."""
from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.safety_termination_detectors import (
AnthropicRefusalDetector,
GeminiSafetyDetector,
OpenAICompatibleContentFilterDetector,
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
class TestOpenAICompatibleContentFilterDetector:
def test_default_matches_content_filter(self):
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
assert hit is not None
assert hit.detector == "openai_compatible_content_filter"
assert hit.reason_field == "finish_reason"
assert hit.reason_value == "content_filter"
def test_case_insensitive_match(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
def test_other_finish_reasons_pass_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
def test_missing_metadata_passes_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai()) is None
def test_non_string_finish_reason_passes_through(self):
# Some adapters may stash an enum or dict — must not raise.
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
def test_falls_back_to_additional_kwargs(self):
# Legacy adapters surface finish_reason via additional_kwargs.
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
assert hit is not None
def test_configurable_extra_values(self):
# Chinese providers sometimes use bespoke tokens.
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
# Original token still matches.
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
def test_carries_azure_content_filter_results(self):
d = OpenAICompatibleContentFilterDetector()
filter_results = {"hate": {"filtered": True, "severity": "high"}}
hit = d.detect(
_ai(
response_metadata={
"finish_reason": "content_filter",
"content_filter_results": filter_results,
},
)
)
assert hit is not None
assert hit.extras["content_filter_results"] == filter_results
class TestAnthropicRefusalDetector:
def test_default_matches_refusal(self):
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
assert hit is not None
assert hit.reason_field == "stop_reason"
assert hit.reason_value == "refusal"
def test_other_stop_reasons_pass_through(self):
d = AnthropicRefusalDetector()
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
def test_anthropic_does_not_steal_finish_reason(self):
# An OpenAI message must not accidentally trip the Anthropic detector.
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
class TestGeminiSafetyDetector:
def test_default_set_covers_documented_reasons(self):
d = GeminiSafetyDetector()
for reason in (
# text safety
"SAFETY",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"RECITATION",
# image safety
"IMAGE_SAFETY",
"IMAGE_PROHIBITED_CONTENT",
"IMAGE_RECITATION",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
def test_normal_termination_passes_through(self):
d = GeminiSafetyDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
# excluded from the default set — they are either normal termination,
# capability mismatches, too broad (OTHER), or tool-call protocol
# errors. See GeminiSafetyDetector docstring.
for reason in (
"MAX_TOKENS",
"LANGUAGE",
"NO_IMAGE",
"OTHER",
"IMAGE_OTHER",
"MALFORMED_FUNCTION_CALL",
"UNEXPECTED_TOOL_CALL",
"FINISH_REASON_UNSPECIFIED",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
def test_carries_safety_ratings(self):
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
hit = GeminiSafetyDetector().detect(
_ai(
response_metadata={
"finish_reason": "SAFETY",
"safety_ratings": ratings,
},
)
)
assert hit is not None
assert hit.extras["safety_ratings"] == ratings
class TestDefaultDetectorSet:
def test_default_set_returns_three_detectors(self):
dets = default_detectors()
names = {d.name for d in dets}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_default_set_returns_fresh_list(self):
# Caller mutation must not affect later calls.
first = default_detectors()
first.clear()
second = default_detectors()
assert len(second) == 3
class TestProtocolConformance:
def test_builtins_satisfy_protocol(self):
for d in default_detectors():
assert isinstance(d, SafetyTerminationDetector)
def test_safety_termination_is_frozen(self):
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
try:
t.detector = "y" # type: ignore[misc]
except Exception:
return
raise AssertionError("SafetyTermination should be frozen")
@@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from deerflow.sandbox.exceptions import SandboxError
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
@@ -1140,6 +1141,170 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
assert sandbox.content == "ALPHA\ntail\n"
def test_write_file_tool_bounds_large_oserror_and_masks_local_paths(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-large-oserror"
def write_file(self, path: str, content: str, append: bool = False) -> None:
host_path = f"{_THREAD_DATA['workspace_path']}/nested/output.txt"
raise OSError(f"write failed at {host_path}\n{'A' * 12000}\nremote tail marker")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: True)
monkeypatch.setattr("deerflow.sandbox.tools.get_thread_data", lambda runtime: _THREAD_DATA)
monkeypatch.setattr("deerflow.sandbox.tools.validate_local_tool_path", lambda path, thread_data: None)
monkeypatch.setattr(
"deerflow.sandbox.tools._resolve_and_validate_user_data_path",
lambda path, thread_data: f"{_THREAD_DATA['workspace_path']}/output.txt",
)
result = write_file_tool.func(
runtime=runtime,
description="写入大文件失败",
path="/mnt/user-data/workspace/output.txt",
content="report body",
)
assert len(result) <= 2000
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "/tmp/deer-flow/threads/t1/user-data/workspace" not in result
assert "/mnt/user-data/workspace/nested/output.txt" in result
assert "remote tail marker" in result
assert "[write_file error truncated:" in result
def test_write_file_tool_preserves_short_oserror_without_truncation(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-short-oserror"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise OSError("disk quota exceeded")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="写入失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert result == "Error: Failed to write file '/mnt/user-data/workspace/output.txt': OSError: disk quota exceeded"
assert "[write_file error truncated:" not in result
def test_write_file_tool_bounds_large_sandbox_error(monkeypatch) -> None:
class FailingSandbox:
id = "sandbox-write-large-sandbox-error"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise SandboxError(f"remote write rejected {'B' * 12000} final detail")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="远端写入失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert len(result) <= 2000
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "SandboxError: remote write rejected" in result
assert "final detail" in result
assert "[write_file error truncated:" in result
@pytest.mark.parametrize(
("raised_error", "expected_fragment"),
[
pytest.param(
PermissionError("permission denied"),
"Error: Permission denied writing to file: /mnt/user-data/workspace/output.txt",
id="permission",
),
pytest.param(
IsADirectoryError("target is a directory"),
"Error: Path is a directory, not a file: /mnt/user-data/workspace/output.txt",
id="directory",
),
pytest.param(
Exception("remote sandbox timeout"),
"Exception: remote sandbox timeout",
id="generic",
),
],
)
def test_write_file_tool_formats_all_other_failure_branches(
monkeypatch,
raised_error: Exception,
expected_fragment: str,
) -> None:
class FailingSandbox:
id = "sandbox-write-other-failure"
def write_file(self, path: str, content: str, append: bool = False) -> None:
raise raised_error
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
sandbox = FailingSandbox()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="验证错误分支格式化",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert "/mnt/user-data/workspace/output.txt" in result
assert expected_fragment in result
assert "[write_file error truncated:" not in result
def test_write_file_tool_handles_sandbox_init_failure(monkeypatch) -> None:
"""Regression for #3133 review: SandboxError raised during sandbox
initialization (before the local `requested_path` assignment) must still
surface as a bounded tool error rather than an UnboundLocalError.
"""
def raise_sandbox_error(runtime):
raise SandboxError("sandbox missing")
runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", raise_sandbox_error)
monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)
result = write_file_tool.func(
runtime=runtime,
description="sandbox 初始化失败",
path="/mnt/user-data/workspace/output.txt",
content="tiny payload",
)
assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
assert "SandboxError: sandbox missing" in result
assert "[write_file error truncated:" not in result
def test_file_operation_lock_memory_cleanup() -> None:
"""Verify that released locks are eventually cleaned up by WeakValueDictionary.
+3 -1
View File
@@ -7,6 +7,7 @@ from types import SimpleNamespace
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from app.gateway.routers import skills as skills_router
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import Skill
@@ -38,7 +39,8 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
def _make_test_app(config) -> FastAPI:
app = FastAPI()
app.state.config = config
app.state.config = config # kept for any startup-style reads
app.dependency_overrides[get_config] = lambda: config
app.include_router(skills_router.router)
return app
@@ -0,0 +1,91 @@
"""Regression tests for _find_usage_recorder callback shape handling.
Bytedance issue #3107 BUG-002: When LangChain passes ``config["callbacks"]`` as
an ``AsyncCallbackManager`` (instead of a plain list), the previous
``for cb in callbacks`` loop raised ``TypeError: 'AsyncCallbackManager' object
is not iterable``. ToolErrorHandlingMiddleware then converted the entire ``task``
tool call into an error ToolMessage, losing the subagent result.
"""
from types import SimpleNamespace
from langchain_core.callbacks import AsyncCallbackManager, CallbackManager
from deerflow.tools.builtins.task_tool import _find_usage_recorder
class _RecorderHandler:
def record_external_llm_usage_records(self, records):
self.records = records
class _OtherHandler:
pass
def _make_runtime(callbacks):
return SimpleNamespace(config={"callbacks": callbacks})
def test_find_usage_recorder_with_plain_list():
recorder = _RecorderHandler()
runtime = _make_runtime([_OtherHandler(), recorder])
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_with_async_callback_manager():
"""LangChain wraps callbacks in AsyncCallbackManager for async tool runs.
The old implementation raised TypeError here. The recorder lives on
``manager.handlers``; we must look there too.
"""
recorder = _RecorderHandler()
manager = AsyncCallbackManager(handlers=[_OtherHandler(), recorder])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_with_sync_callback_manager():
"""Sync flavor of the same wrapper used by some langchain code paths."""
recorder = _RecorderHandler()
manager = CallbackManager(handlers=[recorder])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is recorder
def test_find_usage_recorder_returns_none_when_no_recorder():
manager = AsyncCallbackManager(handlers=[_OtherHandler()])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_handles_empty_manager():
manager = AsyncCallbackManager(handlers=[])
runtime = _make_runtime(manager)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_for_none_runtime():
assert _find_usage_recorder(None) is None
def test_find_usage_recorder_returns_none_when_callbacks_is_none():
runtime = _make_runtime(None)
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_for_single_handler_object():
"""A single handler instance (not wrapped in a list or manager) should not crash.
LangChain's contract is that ``config["callbacks"]`` is a list-or-manager,
but we treat any other shape defensively rather than letting a ``for`` loop
blow up at runtime.
"""
runtime = _make_runtime(_RecorderHandler())
assert _find_usage_recorder(runtime) is None
def test_find_usage_recorder_returns_none_when_config_not_dict():
"""Defensive: a runtime without a dict-shaped config should not raise."""
runtime = SimpleNamespace(config="not-a-dict")
assert _find_usage_recorder(runtime) is None
+27
View File
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
},
}
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
def test_thread_token_usage_can_include_active_runs():
run_store = MagicMock()
run_store.aggregate_tokens_by_thread = AsyncMock(
return_value={
"total_tokens": 175,
"total_input_tokens": 120,
"total_output_tokens": 55,
"total_runs": 3,
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
"by_caller": {
"lead_agent": 145,
"subagent": 25,
"middleware": 5,
},
},
)
app = _make_app(run_store)
with TestClient(app) as client:
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
assert response.status_code == 200
assert response.json()["total_tokens"] == 175
assert response.json()["total_runs"] == 3
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
assert captured["app_config"] is app_config
assert len(middlewares) == 6
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
# (enabled by default — see SafetyFinishReasonConfig).
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
assert len(middlewares) == 7
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
def test_wrap_tool_call_passthrough_on_success():
+2
View File
@@ -11,6 +11,7 @@ from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi import HTTPException, UploadFile
from fastapi.testclient import TestClient
from app.gateway.deps import get_config
from app.gateway.routers import uploads
@@ -687,6 +688,7 @@ def test_upload_limits_endpoint_requires_thread_access():
cfg.uploads = {}
app = make_authed_test_app(owner_check_passes=False)
app.state.config = cfg
app.dependency_overrides[get_config] = lambda: cfg
app.include_router(uploads.router)
with TestClient(app) as client: