mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
Merge branch 'main' into fix-3127
This commit is contained in:
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
|
||||
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
    """Two fully answered AI turns that share ``web_search:11`` need no patching.

    Each AI message is followed by one tool message per call id, so
    ``_build_patched_messages`` is expected to return ``None``.
    """
    earlier_turn = ("web_search:11", "web_search:12", "web_search:13")
    later_turn = ("web_search:9", "web_search:10", "web_search:11")

    history = [HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True})]
    for call_ids in (earlier_turn, later_turn):
        # One AI message carrying the calls, then its results in the same order.
        history.append(_ai_with_tool_calls([_tc("web_search", call_id) for call_id in call_ids]))
        history.extend(_tool_msg(call_id, "web_search") for call_id in call_ids)

    middleware = DanglingToolCallMiddleware()

    assert middleware._build_patched_messages(history) is None
|
||||
|
||||
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
    """A reused call id with no result after its second use gets a patched entry.

    The first occurrence already has a tool message; the expected output has
    the real result at index 1 (status ``success``) and a ``ToolMessage`` with
    status ``error`` at index 3.
    """
    reused_id = "web_search:11"
    history = [
        _ai_with_tool_calls([_tc("web_search", reused_id)]),
        _tool_msg(reused_id, "web_search"),
        _ai_with_tool_calls([_tc("web_search", reused_id)]),
    ]

    patched = DanglingToolCallMiddleware()._build_patched_messages(history)

    assert patched is not None
    for position, expected_status in ((1, "success"), (3, "error")):
        message = patched[position]
        assert isinstance(message, ToolMessage)
        assert message.tool_call_id == reused_id
        assert message.status == expected_status
|
||||
|
||||
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
    """A single late result is attached to the first of two calls sharing an id.

    Input order is AI, AI, result. The expected output places the original
    result object at index 1 and a ``ToolMessage`` with status ``error`` at
    index 3 for the second call.
    """
    reused_id = "web_search:11"
    late_result = _tool_msg(reused_id, "web_search")
    history = [
        _ai_with_tool_calls([_tc("web_search", reused_id)]),
        _ai_with_tool_calls([_tc("web_search", reused_id)]),
        late_result,
    ]

    patched = DanglingToolCallMiddleware()._build_patched_messages(history)

    assert patched is not None

    # The real result is reused by identity, not copied.
    assert patched[1] is late_result
    assert patched[1].status == "success"

    placeholder = patched[3]
    assert isinstance(placeholder, ToolMessage)
    assert placeholder.tool_call_id == reused_id
    assert placeholder.status == "error"
|
||||
|
||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Regression tests for gateway config freshness on the request hot path.
|
||||
|
||||
Bytedance/deer-flow issue #3107 BUG-001: the worker and lead-agent path
|
||||
captured ``app.state.config`` at gateway startup. ``config.yaml`` edits during
|
||||
runtime were therefore ignored — ``get_app_config()``'s mtime-based reload
|
||||
existed but was bypassed because the snapshot object was passed through
|
||||
explicitly.
|
||||
|
||||
These tests pin the desired behaviour: a request-time ``get_config`` call must
|
||||
observe the most recent on-disk ``config.yaml`` (mtime reload), and the
|
||||
runtime ``ContextVar`` override must keep working for per-request injection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway import deps as gateway_deps
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import (
|
||||
AppConfig,
|
||||
pop_current_app_config,
|
||||
push_current_app_config,
|
||||
reset_app_config,
|
||||
set_app_config,
|
||||
)
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate_app_config_singleton():
    """Call ``reset_app_config()`` before and after every test in this module."""
    reset_app_config()  # setup: drop whatever an earlier test left behind
    yield
    reset_app_config()  # teardown: do not leak this test's config onward
|
||||
|
||||
|
||||
def _write_config_yaml(path: Path, *, log_level: str) -> None:
    """Write a minimal ``config.yaml`` to *path*.

    The document contains a ``sandbox`` mapping with a single ``use`` key and
    a top-level ``log_level``.

    Args:
        path: Destination file; overwritten if it already exists.
        log_level: Value written for the top-level ``log_level`` key.
    """
    # The lines are spelled out explicitly so the nesting of ``use`` under
    # ``sandbox`` does not depend on whitespace inside a multi-line literal.
    document_lines = [
        "sandbox:",
        "  use: deerflow.sandbox.local.provider:LocalSandboxProvider",
        f"log_level: {log_level}",
    ]
    path.write_text("\n".join(document_lines) + "\n", encoding="utf-8")
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
    """Return a FastAPI app exposing ``GET /probe``.

    The route resolves its config through the ``get_config`` dependency and
    returns that config's ``log_level`` as JSON.
    """
    probe_app = FastAPI()

    def probe(cfg: AppConfig = Depends(get_config)):
        return {"log_level": cfg.log_level}

    # Same registration the decorator form performs.
    probe_app.get("/probe")(probe)
    return probe_app
|
||||
|
||||
|
||||
def test_get_config_reflects_file_mtime_reload(tmp_path, monkeypatch):
    """A ``config.yaml`` edit made while the app is running must reach ``/probe``.

    Reproduces the reported issue: the value served must not stay frozen at
    whatever was on disk when the first request was handled.
    """
    yaml_path = tmp_path / "config.yaml"
    _write_config_yaml(yaml_path, log_level="info")
    monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(yaml_path))

    client = TestClient(_build_app())
    assert client.get("/probe").json() == {"log_level": "info"}

    # Rewrite the file, then move its mtime five seconds forward so the edit
    # registers as a newer file than the one already loaded.
    _write_config_yaml(yaml_path, log_level="debug")
    bumped_mtime = yaml_path.stat().st_mtime + 5
    os.utime(yaml_path, (bumped_mtime, bumped_mtime))

    assert client.get("/probe").json() == {"log_level": "debug"}
|
||||
|
||||
|
||||
def test_get_config_respects_runtime_context_override(tmp_path, monkeypatch):
    """A config pushed with ``push_current_app_config`` wins over the file on disk."""
    yaml_path = tmp_path / "config.yaml"
    _write_config_yaml(yaml_path, log_level="info")
    monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(yaml_path))

    pushed = AppConfig(sandbox=SandboxConfig(use="test"), log_level="trace")
    push_current_app_config(pushed)
    try:
        client = TestClient(_build_app())
        # The file says "info"; the pushed override says "trace".
        assert client.get("/probe").json() == {"log_level": "trace"}
    finally:
        # Always undo the push so later tests see no override.
        pop_current_app_config()
|
||||
|
||||
|
||||
def test_get_config_respects_test_set_app_config():
    """A config installed via ``set_app_config`` is what ``/probe`` reports."""
    set_app_config(AppConfig(sandbox=SandboxConfig(use="test"), log_level="warning"))

    response_body = TestClient(_build_app()).get("/probe").json()

    assert response_body == {"log_level": "warning"}
|
||||
|
||||
|
||||
def test_run_context_app_config_reflects_yaml_edit(tmp_path, monkeypatch):
    """``RunContext.app_config`` must track live edits to ``config.yaml``.

    The run context feeds the worker / lead-agent factories, so it has to see
    the same reload as ``get_config()``. ``run_events_config`` is expected to
    keep the value placed on ``app.state`` before the first request.
    """
    from unittest.mock import MagicMock

    from app.gateway.deps import get_run_context

    yaml_path = tmp_path / "config.yaml"
    _write_config_yaml(yaml_path, log_level="info")
    monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(yaml_path))

    startup_run_events = {"frozen": "startup"}

    app = FastAPI()
    # Placeholder wiring: only ``ctx.app_config`` and ``ctx.run_events_config``
    # are inspected below.
    app.state.checkpointer = MagicMock()
    app.state.store = MagicMock()
    app.state.run_event_store = MagicMock()
    app.state.run_events_config = {"frozen": "startup"}
    app.state.thread_store = MagicMock()

    @app.get("/run-ctx-log-level")
    def probe(ctx=Depends(get_run_context)):
        return {
            "log_level": ctx.app_config.log_level,
            "run_events_config": ctx.run_events_config,
        }

    client = TestClient(app)

    before_edit = client.get("/run-ctx-log-level").json()
    assert before_edit == {"log_level": "info", "run_events_config": startup_run_events}

    _write_config_yaml(yaml_path, log_level="debug")
    bumped_mtime = yaml_path.stat().st_mtime + 5
    os.utime(yaml_path, (bumped_mtime, bumped_mtime))

    after_edit = client.get("/run-ctx-log-level").json()
    # log_level follows the file; run_events_config is unchanged.
    assert after_edit == {"log_level": "debug", "run_events_config": startup_run_events}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "exception",
    [
        FileNotFoundError("config.yaml not found"),
        PermissionError("config.yaml not readable"),
        ValueError("invalid config"),
        RuntimeError("yaml parse error"),
    ],
)
def test_get_config_returns_503_on_any_load_failure(monkeypatch, exception):
    """Every kind of config-load failure is reported as HTTP 503, never 500.

    The earlier snapshot-based contract answered 503 when no config was
    available. A first version of the fix translated only
    ``FileNotFoundError``, letting other load errors escape as 500; this test
    pins the broader mapping at the request boundary.
    """

    def _always_fails():
        raise exception

    # Replace the loader as seen from the gateway deps module.
    monkeypatch.setattr(gateway_deps, "get_app_config", _always_fails)

    # Do not re-raise server errors into the test; inspect the response instead.
    client = TestClient(_build_app(), raise_server_exceptions=False)
    response = client.get("/probe")

    assert response.status_code == 503
    assert response.json() == {"detail": "Configuration not available"}
|
||||
@@ -1,41 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
|
||||
|
||||
def test_get_config_returns_app_state_config():
    """``get_config`` hands back the very object stored in ``app.state.config``."""
    app = FastAPI()
    stored = AppConfig(sandbox=SandboxConfig(use="test"))
    app.state.config = stored

    @app.get("/probe")
    def probe(cfg: AppConfig = Depends(get_config)):
        # Identity, not equality: the dependency must not copy the config.
        return {"same_identity": cfg is stored, "log_level": cfg.log_level}

    response = TestClient(app).get("/probe")

    assert response.status_code == 200
    assert response.json() == {"same_identity": True, "log_level": "info"}
|
||||
|
||||
|
||||
def test_get_config_reads_updated_app_state():
    """Replacing ``app.state.config`` between requests changes what the route sees."""
    app = FastAPI()
    app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info")

    @app.get("/log-level")
    def log_level(cfg: AppConfig = Depends(get_config)):
        return {"level": cfg.log_level}

    client = TestClient(app)
    assert client.get("/log-level").json() == {"level": "info"}

    # Swap in a copy that differs only in log_level.
    replacement = app.state.config.model_copy(update={"log_level": "debug"})
    app.state.config = replacement
    assert client.get("/log-level").json() == {"level": "debug"}
|
||||
@@ -17,7 +17,7 @@ from fastapi import FastAPI
|
||||
|
||||
|
||||
@asynccontextmanager
async def _noop_langgraph_runtime(_app, _startup_config=None):
    """Async context manager that does nothing on enter or exit.

    Args:
        _app: Ignored.
        _startup_config: Ignored. Optional so the manager can be entered with
            either the one-argument or the two-argument call shape.

    Yields:
        ``None``, exactly once.
    """
    yield
|
||||
|
||||
|
||||
|
||||
@@ -81,6 +81,94 @@ def test_normalize_input_passthrough():
|
||||
assert result == {"custom_key": "value"}
|
||||
|
||||
|
||||
def test_normalize_input_preserves_additional_kwargs_and_id():
    """Regression for gh #3132: message id, name and ``additional_kwargs`` survive.

    The frontend sends uploaded-file metadata in ``additional_kwargs.files``
    together with a client-side message id. If the gateway dropped them before
    the graph ran, UploadsMiddleware would report "(empty)" for new uploads
    and the frontend message would lose its file chip.
    """
    from langchain_core.messages import HumanMessage

    from app.gateway.services import normalize_input

    files = [{"filename": "a.csv", "size": 100, "path": "/mnt/user-data/uploads/a.csv", "status": "uploaded"}]
    incoming = {
        "type": "human",
        "id": "client-msg-1",
        "name": "user-input",
        "content": [{"type": "text", "text": "clean it"}],
        "additional_kwargs": {"files": files, "custom": "keep-me"},
    }

    normalized = normalize_input({"messages": [incoming]})

    assert len(normalized["messages"]) == 1
    message = normalized["messages"][0]
    assert isinstance(message, HumanMessage)
    assert message.id == "client-msg-1"
    assert message.name == "user-input"
    assert message.content == [{"type": "text", "text": "clean it"}]
    assert message.additional_kwargs == {"files": files, "custom": "keep-me"}
|
||||
|
||||
|
||||
def test_normalize_input_passes_through_basemessage_instances():
    """An already-built ``HumanMessage`` comes back as the same object."""
    from langchain_core.messages import HumanMessage

    from app.gateway.services import normalize_input

    original = HumanMessage(content="hello", id="m-1", additional_kwargs={"files": [{"filename": "x"}]})

    normalized = normalize_input({"messages": [original]})

    # Identity check: no copy or re-construction is allowed.
    assert normalized["messages"][0] is original
|
||||
|
||||
|
||||
def test_normalize_input_rejects_malformed_message_with_400():
    """A malformed message dict is rejected with HTTP 400 naming its position.

    ``convert_to_messages`` raises ``ValueError`` for a dict lacking
    ``role``/``type``/``content``. Because ``normalize_input`` sits at the
    gateway's HTTP boundary, that must become a 400 pointing at the offending
    entry rather than surfacing as a 500.

    Added after the Copilot review on PR #3136.
    """
    import pytest
    from fastapi import HTTPException

    from app.gateway.services import normalize_input

    well_formed = {"role": "human", "content": "ok"}
    malformed = {"oops": "no role here"}

    with pytest.raises(HTTPException) as excinfo:
        normalize_input({"messages": [well_formed, malformed]})

    raised = excinfo.value
    assert raised.status_code == 400
    # Index 1 is the malformed entry.
    assert "input.messages[1]" in raised.detail
|
||||
|
||||
|
||||
def test_normalize_input_handles_non_human_roles():
    """System, AI and tool roles map to their own message classes.

    An earlier implementation turned every role into ``HumanMessage``, so
    resuming a thread with prior AI/tool messages rewrote them as human turns.
    The roles must round-trip, along with ``id`` and ``tool_call_id``.
    """
    from langchain_core.messages import AIMessage, SystemMessage, ToolMessage

    from app.gateway.services import normalize_input

    payload = {
        "messages": [
            {"role": "system", "content": "sys"},
            {"role": "ai", "content": "hi", "id": "ai-1"},
            {"role": "tool", "content": "result", "tool_call_id": "call-1"},
        ]
    }

    converted = normalize_input(payload)["messages"]

    assert [type(message) for message in converted] == [SystemMessage, AIMessage, ToolMessage]
    assert converted[1].id == "ai-1"
    assert converted[2].tool_call_id == "call-1"
|
||||
|
||||
|
||||
def test_build_run_config_basic():
|
||||
from app.gateway.services import build_run_config
|
||||
|
||||
|
||||
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
)
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||
# verify the custom middleware is injected correctly.
|
||||
# Chain tail order after the custom middleware is:
|
||||
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
|
||||
# so the custom mock sits at index [-3].
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
|
||||
|
||||
|
||||
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
"""Tests for the MCP persistent-session pool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_pool():
    """Call ``reset_session_pool()`` around every test in this module."""
    reset_session_pool()  # setup
    yield
    reset_session_pool()  # teardown
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPSessionPool unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_session_creates_new():
    """The first ``get_session`` for a key enters the context manager and initializes."""
    fake_session = AsyncMock()
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=fake_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    pool = MCPSessionPool()
    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        obtained = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})

    assert obtained is fake_session
    fake_session.initialize.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_session_reuses_existing():
    """Asking twice for the same (server, scope) pair yields one shared session."""
    fake_session = AsyncMock()
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=fake_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    def _connection():
        # A fresh dict per call, as two independent callers would pass.
        return {"transport": "stdio", "command": "x", "args": []}

    pool = MCPSessionPool()
    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        first = await pool.get_session("server", "thread-1", _connection())
        second = await pool.get_session("server", "thread-1", _connection())

    assert first is second
    # The context manager was entered exactly once.
    assert fake_cm.__aenter__.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_different_scope_creates_different_session():
    """Two scope keys on the same server each get their own session."""
    prepared = [AsyncMock(), AsyncMock()]
    handed_out = 0

    class _SequencedCm:
        """Context manager whose ``__aenter__`` returns the next prepared session."""

        def __init__(self):
            self.enter_count = 0

        async def __aenter__(self):
            nonlocal handed_out
            session = prepared[handed_out]
            handed_out += 1
            self.enter_count += 1
            return session

        async def __aexit__(self, *exc_info):
            return False

    def _connection():
        return {"transport": "stdio", "command": "x", "args": []}

    pool = MCPSessionPool()
    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *_a, **_kw: _SequencedCm()):
        for_thread_1 = await pool.get_session("server", "thread-1", _connection())
        for_thread_2 = await pool.get_session("server", "thread-2", _connection())

    assert for_thread_1 is not for_thread_2
    assert for_thread_1 is prepared[0]
    assert for_thread_2 is prepared[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_lru_eviction():
    """With ``MAX_SESSIONS`` set to 2, adding a third session closes the first."""

    class _TrackedCm:
        """Context manager that records whether ``__aexit__`` has run."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *exc_info):
            self.closed = True
            return False

    created: list[_TrackedCm] = []

    def _new_cm(*_args, **_kwargs):
        tracked = _TrackedCm()
        created.append(tracked)
        return tracked

    def _connection():
        return {"transport": "stdio", "command": "x", "args": []}

    pool = MCPSessionPool()
    pool.MAX_SESSIONS = 2

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=_new_cm):
        await pool.get_session("s", "t1", _connection())
        await pool.get_session("s", "t2", _connection())
        # The pool now holds two entries; t3 is expected to push out t1.
        await pool.get_session("s", "t3", _connection())

    assert created[0].closed is True
    assert created[1].closed is False
    assert created[2].closed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_scope():
    """``close_scope("t1")`` closes only the session stored under that scope."""

    class _TrackedCm:
        """Context manager that records whether ``__aexit__`` has run."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *exc_info):
            self.closed = True
            return False

    created: list[_TrackedCm] = []

    def _new_cm(*_args, **_kwargs):
        tracked = _TrackedCm()
        created.append(tracked)
        return tracked

    def _connection():
        return {"transport": "stdio", "command": "x", "args": []}

    pool = MCPSessionPool()
    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=_new_cm):
        await pool.get_session("s", "t1", _connection())
        await pool.get_session("s", "t2", _connection())

        await pool.close_scope("t1")

    assert created[0].closed is True
    assert created[1].closed is False

    # The t2 entry is still registered in the pool.
    assert ("s", "t2") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close_all():
    """``close_all`` exits every session's context manager and empties the pool."""

    class _TrackedCm:
        """Context manager that records whether ``__aexit__`` has run."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *exc_info):
            self.closed = True
            return False

    created: list[_TrackedCm] = []

    def _new_cm(*_args, **_kwargs):
        tracked = _TrackedCm()
        created.append(tracked)
        return tracked

    def _connection():
        return {"transport": "stdio", "command": "x", "args": []}

    pool = MCPSessionPool()
    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=_new_cm):
        await pool.get_session("s1", "t1", _connection())
        await pool.get_session("s2", "t2", _connection())

        await pool.close_all()

    assert all(tracked.closed for tracked in created)
    assert len(pool._entries) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_session_pool_singleton():
    """Consecutive ``get_session_pool()`` calls return one and the same object."""
    first_call = get_session_pool()
    second_call = get_session_pool()

    assert first_call is second_call
|
||||
|
||||
|
||||
def test_reset_session_pool():
    """After ``reset_session_pool()``, ``get_session_pool()`` returns a new object."""
    before_reset = get_session_pool()

    reset_session_pool()

    after_reset = get_session_pool()
    assert before_reset is not after_reset
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: _make_session_pool_tool uses the pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_pool_tool_wrapping():
    """The wrapped tool forwards its call to the pooled session's ``call_tool``."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        url: str = Field(..., description="url")

    # Stand-in for a StructuredTool as produced by langchain-mcp-adapters.
    adapter_tool = StructuredTool(
        name="playwright_navigate",
        description="Navigate browser",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=fake_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    connection = {"transport": "stdio", "command": "pw", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        wrapped = _make_session_pool_tool(adapter_tool, "playwright", connection)

        # Runtime whose context carries the thread id.
        runtime = MagicMock()
        runtime.context = {"thread_id": "thread-42"}
        runtime.config = {}

        # NOTE(review): the call is kept inside the patch so session creation
        # stays mocked — confirm against the upstream indentation.
        await wrapped.coroutine(runtime=runtime, url="https://example.com")

        fake_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
    """With an empty context, the scope key comes from ``runtime.config``."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    adapter_tool = StructuredTool(
        name="server_tool",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=fake_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        wrapped = _make_session_pool_tool(adapter_tool, "server", {"transport": "stdio", "command": "x", "args": []})

        runtime = MagicMock()
        runtime.context = {}
        runtime.config = {"configurable": {"thread_id": "from-config"}}

        await wrapped.coroutine(runtime=runtime, x=1)

        # The pool entry is keyed by (server name, thread id from config).
        assert ("server", "from-config") in get_session_pool()._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_pool_tool_default_scope():
    """Without any thread id, the session is stored under the ``'default'`` scope."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    adapter_tool = StructuredTool(
        name="server_tool",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=fake_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        wrapped = _make_session_pool_tool(adapter_tool, "server", {"transport": "stdio", "command": "x", "args": []})

        # runtime=None: there is nowhere to read a thread id from.
        await wrapped.coroutine(runtime=None, x=1)

        assert ("server", "default") in get_session_pool()._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_session_pool_tool_get_config_fallback():
    """With ``runtime=None``, the thread id is taken from the patched ``get_config()``."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    adapter_tool = StructuredTool(
        name="server_tool",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=fake_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    langgraph_config = {"configurable": {"thread_id": "from-langgraph-config"}}

    with (
        patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm),
        patch("deerflow.mcp.tools.get_config", return_value=langgraph_config),
    ):
        wrapped = _make_session_pool_tool(adapter_tool, "server", {"transport": "stdio", "command": "x", "args": []})

        # No runtime is supplied, so only the patched get_config() can
        # provide the thread id.
        await wrapped.coroutine(runtime=None, x=1)

        assert ("server", "from-langgraph-config") in get_session_pool()._entries
|
||||
|
||||
|
||||
def test_session_pool_tool_sync_wrapper_path_is_safe():
    """Calling the wrapped tool through its sync ``func`` reaches ``call_tool`` once."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool
    from deerflow.tools.sync import make_sync_tool_wrapper

    class Args(BaseModel):
        url: str = Field(..., description="url")

    adapter_tool = StructuredTool(
        name="playwright_navigate",
        description="Navigate browser",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=fake_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    connection = {"transport": "stdio", "command": "pw", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        wrapped = _make_session_pool_tool(adapter_tool, "playwright", connection)
        # Install the sync entry point the same way get_mcp_tools() does.
        wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name)

        # Sync path; no runtime is passed, so the "default" scope applies.
        # NOTE(review): kept inside the patch so session creation stays
        # mocked — confirm against the upstream indentation.
        wrapped.func(url="https://example.com")

        fake_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
|
||||
@@ -714,6 +714,110 @@ class TestExternalUsageRecords:
|
||||
assert j._subagent_tokens == 0
|
||||
|
||||
|
||||
class TestProgressSnapshots:
    """Progress snapshots that ``RunJournal`` hands to its ``progress_reporter``."""

    @staticmethod
    def _journal(reporter, interval):
        """Build the journal under test; only the progress flush interval varies."""
        return RunJournal(
            "r1",
            "t1",
            MemoryRunEventStore(),
            flush_threshold=100,
            progress_reporter=reporter,
            progress_flush_interval=interval,
        )

    @staticmethod
    def _llm_end(journal, text, usage):
        """Deliver one lead-agent ``on_llm_end`` callback to *journal*."""
        journal.on_llm_end(
            _make_llm_response(text, usage=usage),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )

    @pytest.mark.anyio
    async def test_on_llm_end_reports_progress_snapshot(self):
        """With a zero flush interval, one LLM end yields a matching snapshot."""
        seen: list[dict] = []

        async def reporter(snapshot: dict) -> None:
            seen.append(snapshot)

        journal = self._journal(reporter, 0)
        self._llm_end(journal, "Answer", {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15})
        await journal.flush()

        assert seen
        latest = seen[-1]
        assert latest["total_tokens"] == 15
        assert latest["llm_call_count"] == 1
        assert latest["message_count"] == 1
        assert latest["last_ai_message"] == "Answer"

    @pytest.mark.anyio
    async def test_throttled_progress_flush_emits_trailing_snapshot(self):
        """Two quick LLM ends under a short interval still report the combined totals."""
        seen: list[dict] = []
        trailing_seen = asyncio.Event()

        async def reporter(snapshot: dict) -> None:
            seen.append(snapshot)
            # 45 = 15 + 30, i.e. the snapshot that covers both calls.
            if snapshot["total_tokens"] == 45:
                trailing_seen.set()

        journal = self._journal(reporter, 0.01)
        self._llm_end(journal, "First", {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15})
        self._llm_end(journal, "Second", {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30})
        await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
        await journal.flush()

        assert len(seen) >= 2
        latest = seen[-1]
        assert latest["total_tokens"] == 45
        assert latest["llm_call_count"] == 2
        assert latest["last_ai_message"] == "Second"

    @pytest.mark.anyio
    async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
        """``flush()`` returns promptly and does not report the pending second snapshot."""
        seen: list[dict] = []

        async def reporter(snapshot: dict) -> None:
            seen.append(snapshot)

        journal = self._journal(reporter, 10.0)
        self._llm_end(journal, "First", {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15})
        # Yield once so the first report can be delivered before checking it.
        await asyncio.sleep(0)
        assert seen[-1]["total_tokens"] == 15

        self._llm_end(journal, "Second", {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30})

        # The 10 s interval must not be waited out: flush has 0.2 s to finish.
        await asyncio.wait_for(journal.flush(), timeout=0.2)

        latest = seen[-1]
        assert latest["total_tokens"] == 15
        assert latest["llm_call_count"] == 1
        assert latest["last_ai_message"] == "First"
|
||||
|
||||
|
||||
class TestChatModelStartHumanMessage:
|
||||
"""Tests for on_chat_model_start extracting the first human message."""
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql
|
||||
|
||||
from deerflow.persistence.run import RunRepository
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
@@ -26,6 +27,42 @@ async def _cleanup():
|
||||
await close_engine()
|
||||
|
||||
|
||||
class _CustomRunStoreWithoutProgress(RunStore):
    """``RunStore`` subclass that defines no ``update_run_progress`` of its own.

    Each method accepts arbitrary arguments and returns an empty placeholder
    (``None``, an empty list, or an empty dict).
    """

    # --- methods returning nothing -------------------------------------
    async def put(self, *args, **kwargs):
        return None

    async def get(self, *args, **kwargs):
        return None

    async def update_status(self, *args, **kwargs):
        return None

    async def delete(self, *args, **kwargs):
        return None

    async def update_model_name(self, *args, **kwargs):
        return None

    async def update_run_completion(self, *args, **kwargs):
        return None

    # --- methods returning empty collections ---------------------------
    async def list_by_thread(self, *args, **kwargs):
        return list()

    async def list_pending(self, *args, **kwargs):
        return list()

    async def aggregate_tokens_by_thread(self, *args, **kwargs):
        return dict()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
    """A store that never defines ``update_run_progress`` can still be awaited on it."""
    custom_store = _CustomRunStoreWithoutProgress()

    # Passing means the inherited method neither raises nor needs overriding.
    await custom_store.update_run_progress("r1", total_tokens=1)
|
||||
|
||||
|
||||
class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_put_and_get(self, tmp_path):
|
||||
@@ -170,6 +207,69 @@ class TestRunRepository:
|
||||
assert row["total_tokens"] == 100
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
    """A progress update stores the counters and leaves the status ``running``."""
    repo = await _make_repo(tmp_path)
    try:
        await repo.put("r1", thread_id="t1", status="running")
        await repo.update_run_progress(
            "r1",
            total_input_tokens=40,
            total_output_tokens=10,
            total_tokens=50,
            llm_call_count=1,
            message_count=2,
            last_ai_message="partial answer",
        )
        row = await repo.get("r1")
        assert row["status"] == "running"
        assert row["total_tokens"] == 50
        assert row["llm_call_count"] == 1
        assert row["message_count"] == 2
        assert row["last_ai_message"] == "partial answer"
    finally:
        # Close the engine even when an assertion fails, so an open engine
        # does not leak into the tests that run afterwards.
        await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
    """A second, partial progress update only overwrites the fields it passes."""
    repo = await _make_repo(tmp_path)
    try:
        await repo.put("r1", thread_id="t1", status="running")
        await repo.update_run_progress(
            "r1",
            total_input_tokens=40,
            total_output_tokens=10,
            total_tokens=50,
            llm_call_count=1,
            lead_agent_tokens=30,
            subagent_tokens=20,
            message_count=2,
        )

        # Only total_tokens and last_ai_message are supplied this time.
        await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")

        row = await repo.get("r1")
        assert row["total_input_tokens"] == 40
        assert row["total_output_tokens"] == 10
        assert row["total_tokens"] == 60
        assert row["llm_call_count"] == 1
        assert row["lead_agent_tokens"] == 30
        assert row["subagent_tokens"] == 20
        assert row["message_count"] == 2
        assert row["last_ai_message"] == "updated"
    finally:
        # Close the engine even when an assertion fails, so an open engine
        # does not leak into the tests that run afterwards.
        await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
    """A progress update issued after completion leaves the completed row untouched."""
    repo = await _make_repo(tmp_path)
    try:
        await repo.put("r1", thread_id="t1", status="running")
        await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)

        # Late progress write against a run that already finished.
        await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)

        row = await repo.get("r1")
        assert row["status"] == "success"
        assert row["total_tokens"] == 100
        assert row["llm_call_count"] == 1
    finally:
        # Close the engine even when an assertion fails, so an open engine
        # does not leak into the tests that run afterwards.
        await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
@@ -225,6 +325,28 @@ class TestRunRepository:
|
||||
}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
    """``include_active=True`` adds the still-running run to the thread totals."""
    repo = await _make_repo(tmp_path)
    try:
        # One completed run (100 tokens) and one running run (25 tokens).
        await repo.put("success-run", thread_id="t1", status="running")
        await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
        await repo.put("running-run", thread_id="t1", status="running")
        await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)

        without_active = await repo.aggregate_tokens_by_thread("t1")
        with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)

        assert without_active["total_tokens"] == 100
        assert without_active["total_runs"] == 1
        assert with_active["total_tokens"] == 125
        assert with_active["total_runs"] == 2
        assert with_active["by_caller"] == {
            "lead_agent": 120,
            "subagent": 5,
            "middleware": 0,
        }
    finally:
        # Close the engine even when an assertion fails, so an open engine
        # does not leak into the tests that run afterwards.
        await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
||||
"""list_by_thread returns newest first."""
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
|
||||
|
||||
Unit tests prove ``_apply`` does the right thing on a synthetic state.
|
||||
This test goes one level up: it builds a real ``langchain.agents.create_agent``
|
||||
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
|
||||
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
|
||||
|
||||
1. The tool node is **not** invoked (the dangerous truncated tool call
|
||||
is suppressed).
|
||||
2. The final AIMessage in graph state has ``tool_calls == []``.
|
||||
3. The observability ``safety_termination`` record is attached.
|
||||
4. The user-facing explanation is appended to the message content.
|
||||
|
||||
This is the closest we can get to the issue's failure mode without a live
|
||||
Moonshot key, and it proves the middleware actually gates LangChain's
|
||||
tool router — not just rewrites state in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@tool
def write_file(path: str, content: str) -> str:
    """Pretend to write *content* to *path*. Records the call for assertion."""
    # Nothing touches the filesystem; the call is only logged for the tests.
    invocation = {"path": path, "content": content}
    _TOOL_INVOCATIONS.append(invocation)
    return f"wrote {len(content)} bytes to {path}"
|
||||
|
||||
|
||||
class _ContentFilteredModel(BaseChatModel):
    """Scripted chat model whose first reply carries ``finish_reason='content_filter'``.

    Call 1 returns an ``AIMessage`` with partial text, one ``write_file`` tool
    call and ``finish_reason='content_filter'``.  Any later call returns a
    plain ``"ack"`` message with ``finish_reason='stop'``.
    """

    # Number of times ``_generate`` has run on this instance.
    call_count: int = 0

    @property
    def _llm_type(self) -> str:
        return "fake-content-filtered"

    def bind_tools(self, tools, **kwargs):
        # Responses are hard-coded, so there is nothing to bind; the method
        # only has to exist and not raise when create_agent calls it.
        return self

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        self.call_count += 1

        if self.call_count != 1:
            reply = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
            return ChatResult(generations=[ChatGeneration(message=reply)])

        truncated_call = {
            "id": "call_truncated_1",
            "name": "write_file",
            "args": {
                "path": "/mnt/user-data/outputs/report.md",
                "content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
            },
        }
        reply = AIMessage(
            content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
            tool_calls=[truncated_call],
            response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
        )
        return ChatResult(generations=[ChatGeneration(message=reply)])

    async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
        # The async path simply reuses the synchronous script.
        return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
class _InspectMiddleware(AgentMiddleware):
    """Record a copy of ``request.messages`` on every wrapped model call.

    The copies accumulate in ``observed`` so a test can inspect what the model
    was given, e.g. to check that no synthetic tool result was fed back.
    """

    def __init__(self) -> None:
        super().__init__()
        # One entry per model call, each a snapshot of that call's messages.
        self.observed: list[list[Any]] = []

    def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
        snapshot = list(request.messages)
        self.observed.append(snapshot)
        return handler(request)
|
||||
|
||||
|
||||
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
    """A content-filtered reply with tool calls must end the run without running the tool.

    Checks, on a real ``create_agent`` graph: the tool was never invoked, the
    final ``AIMessage`` has no tool calls, the ``safety_termination`` record is
    attached, the explanation is appended while the partial text is kept, the
    provider ``finish_reason`` is preserved, and the model never received a
    tool-result message.
    """
    _TOOL_INVOCATIONS.clear()
    inspector = _InspectMiddleware()

    agent = create_agent(
        model=_ContentFilteredModel(),
        tools=[write_file],
        # The inspector only wraps the model call to record its input.
        # Safety is last in the list so it executes first under LIFO
        # (matches production wiring).
        middleware=[inspector, SafetyFinishReasonMiddleware()],
    )

    result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})

    # Critical assertion: the dangerous truncated tool call must NOT have
    # been executed. This is the entire point of the middleware.
    assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"

    # The model was entered at least once and was never handed a tool-result
    # message (``type == "tool"``), i.e. nothing synthetic was injected back.
    assert inspector.observed
    assert not any(getattr(msg, "type", None) == "tool" for seen in inspector.observed for msg in seen)

    # Final AIMessage has no tool calls left.
    final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
    assert final_ai.tool_calls == []

    # Observability stamp is present.
    record = final_ai.additional_kwargs.get("safety_termination")
    assert record is not None
    assert record["detector"] == "openai_compatible_content_filter"
    assert record["reason_field"] == "finish_reason"
    assert record["reason_value"] == "content_filter"
    assert record["suppressed_tool_call_count"] == 1
    assert record["suppressed_tool_call_names"] == ["write_file"]

    # User-facing explanation is appended.
    assert "safety-related signal" in final_ai.content
    # Original partial text preserved (we don't throw away what the user
    # already saw in the stream — see middleware docstring).
    assert "Weekly Politics" in final_ai.content

    # finish_reason on response_metadata is preserved (so SSE / converters
    # downstream still see the real provider reason).
    assert final_ai.response_metadata.get("finish_reason") == "content_filter"
|
||||
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through_unchanged():
    """A content-filtered reply that carries no tool calls is delivered untouched.

    With nothing to suppress, the middleware must leave the partial text as the
    final response and must not stamp ``safety_termination``.
    """
    _TOOL_INVOCATIONS.clear()

    class _NoToolModel(BaseChatModel):
        @property
        def _llm_type(self) -> str:
            return "fake-no-tool"

        def bind_tools(self, tools, **kwargs):
            return self

        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            filtered = AIMessage(
                content="Partial answer truncated by safety filter",
                response_metadata={"finish_reason": "content_filter"},
            )
            return ChatResult(generations=[ChatGeneration(message=filtered)])

        async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
            return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)

    agent = create_agent(
        model=_NoToolModel(),
        tools=[write_file],
        middleware=[SafetyFinishReasonMiddleware()],
    )
    outcome = agent.invoke({"messages": [HumanMessage(content="hi")]})
    ai_messages = (m for m in reversed(outcome["messages"]) if isinstance(m, AIMessage))
    final_ai = next(ai_messages)

    # Content untouched.
    assert final_ai.content == "Partial answer truncated by safety filter"
    # No safety_termination stamp because we didn't intervene.
    assert "safety_termination" not in final_ai.additional_kwargs
    # tool node never ran (there were no tool calls in the first place).
    assert _TOOL_INVOCATIONS == []
|
||||
|
||||
|
||||
def test_normal_tool_call_round_trip_is_not_affected():
    """Regression guard: a reply with ``finish_reason='tool_calls'`` still runs the tool.

    The scripted model asks for ``write_file`` once, then answers ``"done"``;
    the recorded invocation list must contain exactly that one call.
    """
    _TOOL_INVOCATIONS.clear()

    class _HealthyToolModel(BaseChatModel):
        call_count: int = 0

        @property
        def _llm_type(self) -> str:
            return "fake-healthy"

        def bind_tools(self, tools, **kwargs):
            return self

        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            self.call_count += 1
            if self.call_count != 1:
                reply = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
                return ChatResult(generations=[ChatGeneration(message=reply)])

            healthy_call = {
                "id": "call_ok",
                "name": "write_file",
                "args": {"path": "/tmp/ok", "content": "complete content"},
            }
            reply = AIMessage(
                content="",
                tool_calls=[healthy_call],
                response_metadata={"finish_reason": "tool_calls"},
            )
            return ChatResult(generations=[ChatGeneration(message=reply)])

        async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
            return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)

    agent = create_agent(
        model=_HealthyToolModel(),
        tools=[write_file],
        middleware=[SafetyFinishReasonMiddleware()],
    )
    agent.invoke({"messages": [HumanMessage(content="write")]})

    expected_invocations = [{"path": "/tmp/ok", "content": "complete content"}]
    assert _TOOL_INVOCATIONS == expected_invocations
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Unit tests for SafetyFinishReasonMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
)
|
||||
from deerflow.config.safety_finish_reason_config import (
|
||||
SafetyDetectorConfig,
|
||||
SafetyFinishReasonConfig,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id="t-1"):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
return runtime
|
||||
|
||||
|
||||
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None):
    """Build an ``AIMessage`` for tests, replacing falsy optional arguments with empty containers."""
    fields = {
        "content": content,
        "tool_calls": tool_calls or [],
        "response_metadata": response_metadata or {},
        "additional_kwargs": additional_kwargs or {},
    }
    return AIMessage(**fields)
|
||||
|
||||
|
||||
def _write_call(idx=1, content_text="半截"):
|
||||
return {
|
||||
"id": f"call_write_{idx}",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
|
||||
}
|
||||
|
||||
|
||||
class AlwaysHitDetector:
    """Detector stub whose ``detect`` reports a termination for every message."""

    name = "always_hit"

    def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
        self.reason_field = reason_field
        self.reason_value = reason_value
        # A falsy ``extras`` (None, empty mapping) becomes a fresh empty dict.
        self.extras = extras if extras else {}

    def detect(self, message):
        """Return a ``SafetyTermination`` built from the configured fields; *message* is ignored."""
        termination_fields = {
            "detector": self.name,
            "reason_field": self.reason_field,
            "reason_value": self.reason_value,
            "extras": self.extras,
        }
        return SafetyTermination(**termination_fields)
|
||||
|
||||
|
||||
class NeverHitDetector:
    """Detector stub whose ``detect`` never reports a termination."""

    name: str = "never_hit"

    def detect(self, message):
        """Ignore *message* and report no termination."""
        no_termination = None
        return no_termination
|
||||
|
||||
|
||||
class RaisingDetector:
    """Detector stub whose ``detect`` always raises ``RuntimeError("boom")``."""

    name: str = "raising"

    def detect(self, message):
        """Ignore *message* and fail unconditionally."""
        failure = RuntimeError("boom")
        raise failure
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core trigger behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriggerCriteria:
    """Which last messages make ``SafetyFinishReasonMiddleware._apply`` intervene."""

    def test_content_filter_with_tool_calls_triggers(self):
        middleware = SafetyFinishReasonMiddleware()
        message = _ai(
            content="partial",
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        result = middleware._apply({"messages": [message]}, _runtime())
        assert result is not None
        patched = result["messages"][0]
        assert patched.tool_calls == []

    def test_content_filter_without_tool_calls_passes_through(self):
        """Without tool calls the partial text is a legitimate final response
        and must not be rewritten."""
        middleware = SafetyFinishReasonMiddleware()
        message = _ai(
            content="partial response",
            response_metadata={"finish_reason": "content_filter"},
        )
        assert middleware._apply({"messages": [message]}, _runtime()) is None

    def test_normal_tool_calls_pass_through(self):
        middleware = SafetyFinishReasonMiddleware()
        message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "tool_calls"},
        )
        assert middleware._apply({"messages": [message]}, _runtime()) is None

    def test_normal_stop_with_tool_calls_pass_through(self):
        # Some providers report finish_reason='stop' for tool-call messages.
        middleware = SafetyFinishReasonMiddleware()
        message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "stop"},
        )
        assert middleware._apply({"messages": [message]}, _runtime()) is None

    def test_empty_message_list_passes_through(self):
        middleware = SafetyFinishReasonMiddleware()
        empty_state = {"messages": []}
        assert middleware._apply(empty_state, _runtime()) is None

    def test_non_ai_last_message_passes_through(self):
        middleware = SafetyFinishReasonMiddleware()
        history = [HumanMessage(content="hi"), SystemMessage(content="sys")]
        assert middleware._apply({"messages": history}, _runtime()) is None

    def test_anthropic_refusal_with_tool_calls_triggers(self):
        middleware = SafetyFinishReasonMiddleware()
        message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"stop_reason": "refusal"},
        )
        result = middleware._apply({"messages": [message]}, _runtime())
        assert result is not None
        assert result["messages"][0].tool_calls == []

    def test_gemini_safety_with_tool_calls_triggers(self):
        middleware = SafetyFinishReasonMiddleware()
        message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "SAFETY"},
        )
        result = middleware._apply({"messages": [message]}, _runtime())
        assert result is not None
        assert result["messages"][0].tool_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message rewriting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageRewrite:
    """How ``SafetyFinishReasonMiddleware._apply`` rewrites a triggering message."""

    def test_clears_structured_tool_calls(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call(1), _write_call(2)],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        result = mw._apply(state, _runtime())
        patched = result["messages"][0]
        assert patched.tool_calls == []

    def test_clears_raw_additional_kwargs_tool_calls(self):
        """Critical defence-in-depth: DanglingToolCallMiddleware will recover
        tool calls from additional_kwargs.tool_calls if we forget them, which
        would re-emit a synthetic ToolMessage downstream and confuse the
        model. We must wipe both."""
        mw = SafetyFinishReasonMiddleware()
        raw_tool_calls = [
            {
                "id": "call_write_1",
                "type": "function",
                "function": {"name": "write_file", "arguments": '{"path": "/x"}'},
            }
        ]
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call(1)],
                    response_metadata={"finish_reason": "content_filter"},
                    additional_kwargs={
                        "tool_calls": raw_tool_calls,
                        "function_call": {"name": "write_file", "arguments": "{}"},
                    },
                )
            ]
        }
        result = mw._apply(state, _runtime())
        patched = result["messages"][0]
        assert "tool_calls" not in patched.additional_kwargs
        assert "function_call" not in patched.additional_kwargs

    def test_preserves_other_additional_kwargs(self):
        # vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
        # may carry other provider-specific keys. They must not be wiped.
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                    additional_kwargs={
                        "reasoning": "thinking text",
                        "custom_provider_field": {"x": 1},
                    },
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert patched.additional_kwargs["reasoning"] == "thinking text"
        assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}

    def test_writes_observability_field(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call(1), _write_call(2)],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        record = patched.additional_kwargs["safety_termination"]
        assert record["detector"] == "openai_compatible_content_filter"
        assert record["reason_field"] == "finish_reason"
        assert record["reason_value"] == "content_filter"
        assert record["suppressed_tool_call_count"] == 2
        assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]

    def test_preserves_response_metadata_finish_reason(self):
        """Downstream SSE converters read response_metadata.finish_reason —
        we want them to see the *real* provider reason, not 'stop'."""
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert patched.response_metadata["finish_reason"] == "content_filter"
        assert patched.response_metadata["model_name"] == "kimi-k2"

    def test_appends_user_facing_explanation_to_str_content(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    content="some partial text",
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert isinstance(patched.content, str)
        assert patched.content.startswith("some partial text")
        assert "safety-related signal" in patched.content

    def test_handles_empty_content(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    content="",
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert isinstance(patched.content, str)
        assert "safety-related signal" in patched.content

    def test_handles_list_content_thinking_blocks(self):
        """Anthropic thinking / vLLM reasoning models emit content blocks.
        Naively concatenating a string would raise TypeError."""
        mw = SafetyFinishReasonMiddleware()
        thinking_blocks = [
            {"type": "thinking", "text": "let me consider..."},
            {"type": "text", "text": "partial answer"},
        ]
        state = {
            "messages": [
                _ai(
                    content=thinking_blocks,
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert isinstance(patched.content, list)
        assert patched.content[:2] == thinking_blocks
        assert patched.content[-1]["type"] == "text"
        assert "safety-related signal" in patched.content[-1]["text"]

    def test_idempotent_on_already_cleared_message(self):
        # Re-running the middleware on a message we already cleared must not
        # re-trigger (tool_calls is now empty → fast passthrough).
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        first = mw._apply(state, _runtime())
        state2 = {"messages": [first["messages"][0]]}
        second = mw._apply(state2, _runtime())
        assert second is None

    def test_preserves_message_id_for_add_messages_replacement(self):
        """LangGraph's add_messages reducer treats same-id messages as
        replacements, so the patched message must keep the original id."""
        mw = SafetyFinishReasonMiddleware()
        original = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        # A freshly built AIMessage has id=None, which would make the check
        # below pass even if the id were dropped. Pin an explicit id so the
        # assertion actually proves it survives the rewrite.
        original.id = "ai-msg-1"
        state = {"messages": [original]}
        patched = mw._apply(state, _runtime())["messages"][0]
        assert patched.id == "ai-msg-1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detector wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectorWiring:
    """How the middleware consults the detectors it was constructed with."""

    def test_iterates_detectors_in_order(self):
        """When two detectors both hit, the first one's reason is recorded."""
        leading = AlwaysHitDetector(reason_value="first")
        trailing = AlwaysHitDetector(reason_value="second")
        middleware = SafetyFinishReasonMiddleware(detectors=[leading, trailing])
        message = _ai(tool_calls=[_write_call()])
        patched = middleware._apply({"messages": [message]}, _runtime())["messages"][0]
        stamp = patched.additional_kwargs["safety_termination"]
        assert stamp["reason_value"] == "first"

    def test_returns_none_when_no_detector_matches(self):
        middleware = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
        message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        assert middleware._apply({"messages": [message]}, _runtime()) is None

    def test_buggy_detector_does_not_break_run(self):
        """A detector that raises is tolerated and the next detector still applies."""
        middleware = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
        message = _ai(tool_calls=[_write_call()])
        result = middleware._apply({"messages": [message]}, _runtime())
        assert result is not None
        stamp = result["messages"][0].additional_kwargs["safety_termination"]
        assert stamp["detector"] == "always_hit"

    def test_constructor_copies_detectors(self):
        """Caller mutation after construction must not leak into us."""
        supplied = [AlwaysHitDetector()]
        middleware = SafetyFinishReasonMiddleware(detectors=supplied)
        # Emptying the caller's list must leave the middleware's detectors intact.
        supplied.clear()
        message = _ai(tool_calls=[_write_call()])
        assert middleware._apply({"messages": [message]}, _runtime()) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# from_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFromConfig:
    """Building the middleware from a ``SafetyFinishReasonConfig``."""

    def test_default_config_uses_builtin_detectors(self):
        """A default config yields exactly the three built-in detectors."""
        middleware = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
        assert len(middleware._detectors) == 3
        detector_names = {detector.name for detector in middleware._detectors}
        assert detector_names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}

    def test_custom_detectors_loaded_via_reflection(self):
        """A detector named by import path is built with the kwargs from ``config``."""
        detector_cfg = SafetyDetectorConfig(
            use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
            config={"finish_reasons": ["custom_filter"]},
        )
        middleware = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig(detectors=[detector_cfg]))
        assert len(middleware._detectors) == 1

        # The custom token triggers an intervention, so the kwargs propagated.
        custom_message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "custom_filter"},
        )
        assert middleware._apply({"messages": [custom_message]}, _runtime()) is not None

        # The default token is not in the configured set any more.
        default_message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        assert middleware._apply({"messages": [default_message]}, _runtime()) is None

    def test_empty_detector_list_rejected(self):
        """An explicit empty detector list raises ValueError mentioning ``enabled=false``."""
        empty_config = SafetyFinishReasonConfig(detectors=[])
        with pytest.raises(ValueError, match="enabled=false"):
            SafetyFinishReasonMiddleware.from_config(empty_config)

    def test_non_detector_class_rejected(self):
        """Pointing ``use`` at a non-detector class (``builtins:dict``) raises TypeError."""
        bad_config = SafetyFinishReasonConfig(
            detectors=[SafetyDetectorConfig(use="builtins:dict")],
        )
        with pytest.raises(TypeError):
            SafetyFinishReasonMiddleware.from_config(bad_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuditEvent:
    """Audit-trail behaviour of SafetyFinishReasonMiddleware.

    When the run-scoped journal is exposed under
    ``runtime.context["__run_journal"]``, the middleware should record a
    ``middleware:safety_termination`` audit event through
    ``RunJournal.record_middleware``.

    Background (review on PR #3035): the SSE custom event serves live
    consumers, but post-run audit needs a row in ``run_events`` that can be
    queried with a single SQL statement, with no JOIN against the message body.
    """

    def _runtime_with_journal(self, journal):
        """Return a mock runtime whose context exposes *journal*."""
        fake_runtime = MagicMock()
        fake_runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
        return fake_runtime

    def test_records_audit_event_when_journal_present(self):
        """An intervention writes exactly one audit record with the expected fields."""
        journal = MagicMock()
        middleware = SafetyFinishReasonMiddleware()
        flagged = _ai(
            content="partial",
            tool_calls=[_write_call(1)],
            response_metadata={"finish_reason": "content_filter"},
        )
        outcome = middleware._apply({"messages": [flagged]}, self._runtime_with_journal(journal))
        assert outcome is not None

        journal.record_middleware.assert_called_once()
        # The test reads everything from kwargs, i.e. it expects keyword-style calls.
        recorded = journal.record_middleware.call_args.kwargs
        expected_kwargs = (
            ("tag", "safety_termination"),
            ("name", "SafetyFinishReasonMiddleware"),
            ("hook", "after_model"),
            ("action", "suppress_tool_calls"),
        )
        for key, expected in expected_kwargs:
            assert recorded[key] == expected

        changes = recorded["changes"]
        expected_changes = (
            ("detector", "openai_compatible_content_filter"),
            ("reason_field", "finish_reason"),
            ("reason_value", "content_filter"),
            ("suppressed_tool_call_count", 1),
            ("suppressed_tool_call_names", ["write_file"]),
            ("suppressed_tool_call_ids", ["call_write_1"]),
        )
        for key, expected in expected_changes:
            assert changes[key] == expected
        assert "message_id" in changes
        assert isinstance(changes["extras"], dict)

    def test_audit_event_never_carries_tool_arguments(self):
        """Tool arguments must never appear in the recorded audit event.

        PR #3035 review (IMPORTANT): tool args are the filtered content itself
        and must not be persisted to ``run_events`` under any circumstance.
        """
        journal = MagicMock()
        middleware = SafetyFinishReasonMiddleware()
        sensitive_call = {
            "id": "call_x",
            "name": "write_file",
            "args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
        }
        flagged = _ai(
            tool_calls=[sensitive_call],
            response_metadata={"finish_reason": "content_filter"},
        )
        middleware._apply({"messages": [flagged]}, self._runtime_with_journal(journal))

        # Search the repr of the whole recorded call so nested values are covered too.
        flattened = repr(journal.record_middleware.call_args)
        assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flattened, "tool arguments must not leak into audit event"
        assert "args" not in journal.record_middleware.call_args.kwargs["changes"]

    def test_no_journal_in_runtime_context_is_silently_skipped(self):
        """Without a journal the middleware still intervenes; only the audit event is skipped.

        Subagent runtimes, unit tests and no-event-store paths have no journal.
        """
        middleware = SafetyFinishReasonMiddleware()
        bare_runtime = MagicMock()
        bare_runtime.context = {"thread_id": "t-noj"}  # deliberately no __run_journal
        flagged = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        # Must not raise, and must still clear the tool calls.
        outcome = middleware._apply({"messages": [flagged]}, bare_runtime)
        assert outcome is not None
        assert outcome["messages"][0].tool_calls == []

    def test_journal_record_exception_does_not_break_run(self):
        """A journal that raises must never propagate into the agent loop."""
        journal = MagicMock()
        journal.record_middleware.side_effect = RuntimeError("db down")
        middleware = SafetyFinishReasonMiddleware()
        flagged = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        # Must not raise despite the failing journal.
        outcome = middleware._apply({"messages": [flagged]}, self._runtime_with_journal(journal))
        assert outcome is not None
        assert outcome["messages"][0].tool_calls == []

    def test_no_record_when_passthrough(self):
        """When the middleware does not intervene, no audit event is written."""
        journal = MagicMock()
        middleware = SafetyFinishReasonMiddleware()
        healthy = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "tool_calls"},  # healthy finish reason
        )
        outcome = middleware._apply({"messages": [healthy]}, self._runtime_with_journal(journal))
        assert outcome is None
        journal.record_middleware.assert_not_called()
|
||||
|
||||
|
||||
class TestStreamEvent:
    """Custom stream event emitted when the middleware intervenes."""

    def test_emits_event_when_writer_available(self, monkeypatch):
        """One ``safety_termination`` payload is sent to the stream writer."""
        import langgraph.config

        captured: list = []
        # Patch get_stream_writer at the symbol-resolution site; the writer it
        # hands out simply appends each payload to ``captured``.
        monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: captured.append)

        middleware = SafetyFinishReasonMiddleware()
        flagged = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        middleware._apply({"messages": [flagged]}, _runtime("t-stream"))

        assert len(captured) == 1
        payload = captured[0]
        expected_fields = (
            ("type", "safety_termination"),
            ("detector", "openai_compatible_content_filter"),
            ("reason_field", "finish_reason"),
            ("reason_value", "content_filter"),
            ("suppressed_tool_call_count", 1),
            ("suppressed_tool_call_names", ["write_file"]),
            ("thread_id", "t-stream"),
        )
        for key, expected in expected_fields:
            assert payload[key] == expected

    def test_writer_unavailable_does_not_break(self, monkeypatch):
        """A ``get_stream_writer`` that raises LookupError must not stop the intervention."""
        import langgraph.config

        def unavailable_writer():
            raise LookupError("not in a stream context")

        monkeypatch.setattr(langgraph.config, "get_stream_writer", unavailable_writer)

        middleware = SafetyFinishReasonMiddleware()
        flagged = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        # Should not raise.
        outcome = middleware._apply({"messages": [flagged]}, _runtime())
        assert outcome is not None
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Unit tests for SafetyTerminationDetector built-ins."""
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
AnthropicRefusalDetector,
|
||||
GeminiSafetyDetector,
|
||||
OpenAICompatibleContentFilterDetector,
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
|
||||
|
||||
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
    """Build an ``AIMessage`` for tests, substituting empty containers for falsy arguments."""
    fields = {
        "content": content,
        "tool_calls": tool_calls or [],
        "response_metadata": response_metadata or {},
        "additional_kwargs": additional_kwargs or {},
    }
    return AIMessage(**fields)
|
||||
|
||||
|
||||
class TestOpenAICompatibleContentFilterDetector:
    """Behaviour of the OpenAI-compatible ``finish_reason`` detector."""

    def test_default_matches_content_filter(self):
        """The default detector reports a ``content_filter`` finish reason."""
        detector = OpenAICompatibleContentFilterDetector()
        hit = detector.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
        assert hit is not None
        assert hit.detector == "openai_compatible_content_filter"
        assert (hit.reason_field, hit.reason_value) == ("finish_reason", "content_filter")

    def test_case_insensitive_match(self):
        """An upper-case token still matches."""
        detector = OpenAICompatibleContentFilterDetector()
        shouted = _ai(response_metadata={"finish_reason": "CONTENT_FILTER"})
        assert detector.detect(shouted) is not None

    def test_other_finish_reasons_pass_through(self):
        """Ordinary finish reasons are not flagged."""
        detector = OpenAICompatibleContentFilterDetector()
        for reason in ("stop", "tool_calls", "length"):
            assert detector.detect(_ai(response_metadata={"finish_reason": reason})) is None

    def test_missing_metadata_passes_through(self):
        """A message without any metadata is not flagged."""
        detector = OpenAICompatibleContentFilterDetector()
        assert detector.detect(_ai()) is None

    def test_non_string_finish_reason_passes_through(self):
        """Non-string finish reasons are ignored rather than raising."""
        # Some adapters may stash an enum or dict here.
        detector = OpenAICompatibleContentFilterDetector()
        for odd_value in (42, {"value": "content_filter"}):
            assert detector.detect(_ai(response_metadata={"finish_reason": odd_value})) is None

    def test_falls_back_to_additional_kwargs(self):
        """``finish_reason`` is also read from ``additional_kwargs``."""
        # Legacy adapters surface finish_reason via additional_kwargs.
        detector = OpenAICompatibleContentFilterDetector()
        legacy = _ai(additional_kwargs={"finish_reason": "content_filter"})
        assert detector.detect(legacy) is not None

    def test_configurable_extra_values(self):
        """Extra tokens passed via ``finish_reasons`` match alongside the original one."""
        # Some providers use bespoke tokens (the original note cites Chinese providers).
        detector = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
        # The two bespoke tokens first, then the original token.
        for reason in ("sensitive", "violation", "content_filter"):
            assert detector.detect(_ai(response_metadata={"finish_reason": reason})) is not None

    def test_carries_azure_content_filter_results(self):
        """``content_filter_results`` from the metadata is copied into ``extras``."""
        detector = OpenAICompatibleContentFilterDetector()
        filter_results = {"hate": {"filtered": True, "severity": "high"}}
        metadata = {
            "finish_reason": "content_filter",
            "content_filter_results": filter_results,
        }
        hit = detector.detect(_ai(response_metadata=metadata))
        assert hit is not None
        assert hit.extras["content_filter_results"] == filter_results
|
||||
|
||||
|
||||
class TestAnthropicRefusalDetector:
    """Behaviour of the Anthropic ``stop_reason`` detector."""

    def test_default_matches_refusal(self):
        """A ``refusal`` stop reason is reported with its field and value."""
        refused = _ai(response_metadata={"stop_reason": "refusal"})
        hit = AnthropicRefusalDetector().detect(refused)
        assert hit is not None
        assert (hit.reason_field, hit.reason_value) == ("stop_reason", "refusal")

    def test_other_stop_reasons_pass_through(self):
        """Ordinary stop reasons are not flagged."""
        detector = AnthropicRefusalDetector()
        for reason in ("end_turn", "tool_use", "max_tokens"):
            assert detector.detect(_ai(response_metadata={"stop_reason": reason})) is None

    def test_anthropic_does_not_steal_finish_reason(self):
        """An OpenAI-style ``finish_reason`` must not trip the Anthropic detector."""
        openai_style = _ai(response_metadata={"finish_reason": "content_filter"})
        assert AnthropicRefusalDetector().detect(openai_style) is None
|
||||
|
||||
|
||||
class TestGeminiSafetyDetector:
    """Behaviour of the Gemini ``finish_reason`` detector."""

    def test_default_set_covers_documented_reasons(self):
        """Every text- and image-safety reason in the default set is flagged."""
        detector = GeminiSafetyDetector()
        text_safety = ("SAFETY", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII", "RECITATION")
        image_safety = ("IMAGE_SAFETY", "IMAGE_PROHIBITED_CONTENT", "IMAGE_RECITATION")
        for reason in text_safety + image_safety:
            assert detector.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason

    def test_normal_termination_passes_through(self):
        """Reasons outside the default set are not flagged."""
        detector = GeminiSafetyDetector()
        assert detector.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
        # MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
        # MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally left
        # out of the default set: they are either normal termination,
        # capability mismatches, too broad (OTHER), or tool-call protocol
        # errors. See the GeminiSafetyDetector docstring.
        excluded_reasons = (
            "MAX_TOKENS",
            "LANGUAGE",
            "NO_IMAGE",
            "OTHER",
            "IMAGE_OTHER",
            "MALFORMED_FUNCTION_CALL",
            "UNEXPECTED_TOOL_CALL",
            "FINISH_REASON_UNSPECIFIED",
        )
        for reason in excluded_reasons:
            assert detector.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason

    def test_carries_safety_ratings(self):
        """``safety_ratings`` from the metadata is copied into ``extras``."""
        ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
        metadata = {
            "finish_reason": "SAFETY",
            "safety_ratings": ratings,
        }
        hit = GeminiSafetyDetector().detect(_ai(response_metadata=metadata))
        assert hit is not None
        assert hit.extras["safety_ratings"] == ratings
|
||||
|
||||
|
||||
class TestDefaultDetectorSet:
    """Contents and freshness of ``default_detectors()``."""

    def test_default_set_returns_three_detectors(self):
        """The default set consists of the three built-in detector names."""
        detector_names = {detector.name for detector in default_detectors()}
        assert detector_names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}

    def test_default_set_returns_fresh_list(self):
        """Mutating one returned list must not affect later calls."""
        # Empty the first list in place, then ask for the defaults again.
        default_detectors().clear()
        assert len(default_detectors()) == 3
|
||||
|
||||
|
||||
class TestProtocolConformance:
    """Structural guarantees of the detector types."""

    def test_builtins_satisfy_protocol(self):
        """Every default detector passes an isinstance check against the protocol."""
        assert all(isinstance(detector, SafetyTerminationDetector) for detector in default_detectors())

    def test_safety_termination_is_frozen(self):
        """Assigning to a field of ``SafetyTermination`` must raise."""
        termination = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
        try:
            termination.detector = "y"  # type: ignore[misc]
        except Exception:
            # Any exception counts as the assignment being rejected.
            pass
        else:
            raise AssertionError("SafetyTermination should be frozen")
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.sandbox.exceptions import SandboxError
|
||||
from deerflow.sandbox.tools import (
|
||||
VIRTUAL_PATH_PREFIX,
|
||||
_apply_cwd_prefix,
|
||||
@@ -1140,6 +1141,170 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey
|
||||
assert sandbox.content == "ALPHA\ntail\n"
|
||||
|
||||
|
||||
def test_write_file_tool_bounds_large_oserror_and_masks_local_paths(monkeypatch) -> None:
    """A very large OSError from a local sandbox is bounded and its host path is masked."""

    class FailingSandbox:
        id = "sandbox-write-large-oserror"

        def write_file(self, path: str, content: str, append: bool = False) -> None:
            # Error text holds a host-side path, 12000 filler characters and a tail marker.
            host_path = f"{_THREAD_DATA['workspace_path']}/nested/output.txt"
            raise OSError(f"write failed at {host_path}\n{'A' * 12000}\nremote tail marker")

    sandbox = FailingSandbox()
    fake_runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})

    stubs = [
        ("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox),
        ("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None),
        ("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: True),
        ("deerflow.sandbox.tools.get_thread_data", lambda runtime: _THREAD_DATA),
        ("deerflow.sandbox.tools.validate_local_tool_path", lambda path, thread_data: None),
        (
            "deerflow.sandbox.tools._resolve_and_validate_user_data_path",
            lambda path, thread_data: f"{_THREAD_DATA['workspace_path']}/output.txt",
        ),
    ]
    for target, stub in stubs:
        monkeypatch.setattr(target, stub)

    result = write_file_tool.func(
        runtime=fake_runtime,
        description="写入大文件失败",
        path="/mnt/user-data/workspace/output.txt",
        content="report body",
    )

    assert len(result) <= 2000
    assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result
    # The host path must be gone, replaced by its virtual equivalent.
    assert "/tmp/deer-flow/threads/t1/user-data/workspace" not in result
    for fragment in (
        "/mnt/user-data/workspace/nested/output.txt",
        "remote tail marker",
        "[write_file error truncated:",
    ):
        assert fragment in result
|
||||
|
||||
|
||||
def test_write_file_tool_preserves_short_oserror_without_truncation(monkeypatch) -> None:
    """A short OSError message is reported verbatim, with no truncation marker."""

    class FailingSandbox:
        id = "sandbox-write-short-oserror"

        def write_file(self, path: str, content: str, append: bool = False) -> None:
            raise OSError("disk quota exceeded")

    sandbox = FailingSandbox()
    fake_runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})

    stubs = [
        ("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox),
        ("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None),
        ("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False),
    ]
    for target, stub in stubs:
        monkeypatch.setattr(target, stub)

    result = write_file_tool.func(
        runtime=fake_runtime,
        description="写入失败",
        path="/mnt/user-data/workspace/output.txt",
        content="tiny payload",
    )

    expected = "Error: Failed to write file '/mnt/user-data/workspace/output.txt': OSError: disk quota exceeded"
    assert result == expected
    assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_write_file_tool_bounds_large_sandbox_error(monkeypatch) -> None:
    """A very large SandboxError is bounded while its head and tail stay visible."""

    class FailingSandbox:
        id = "sandbox-write-large-sandbox-error"

        def write_file(self, path: str, content: str, append: bool = False) -> None:
            raise SandboxError(f"remote write rejected {'B' * 12000} final detail")

    sandbox = FailingSandbox()
    fake_runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})

    stubs = [
        ("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox),
        ("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None),
        ("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False),
    ]
    for target, stub in stubs:
        monkeypatch.setattr(target, stub)

    result = write_file_tool.func(
        runtime=fake_runtime,
        description="远端写入失败",
        path="/mnt/user-data/workspace/output.txt",
        content="tiny payload",
    )

    assert len(result) <= 2000
    for fragment in (
        "Error: Failed to write file '/mnt/user-data/workspace/output.txt':",
        "SandboxError: remote write rejected",
        "final detail",
        "[write_file error truncated:",
    ):
        assert fragment in result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("raised_error", "expected_fragment"),
    [
        pytest.param(
            PermissionError("permission denied"),
            "Error: Permission denied writing to file: /mnt/user-data/workspace/output.txt",
            id="permission",
        ),
        pytest.param(
            IsADirectoryError("target is a directory"),
            "Error: Path is a directory, not a file: /mnt/user-data/workspace/output.txt",
            id="directory",
        ),
        pytest.param(
            Exception("remote sandbox timeout"),
            "Exception: remote sandbox timeout",
            id="generic",
        ),
    ],
)
def test_write_file_tool_formats_all_other_failure_branches(
    monkeypatch,
    raised_error: Exception,
    expected_fragment: str,
) -> None:
    """Each remaining failure branch names the virtual path and is not truncated."""

    class FailingSandbox:
        id = "sandbox-write-other-failure"

        def write_file(self, path: str, content: str, append: bool = False) -> None:
            # Re-raise whichever exception this parametrized case supplies.
            raise raised_error

    sandbox = FailingSandbox()
    fake_runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})

    stubs = [
        ("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox),
        ("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None),
        ("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False),
    ]
    for target, stub in stubs:
        monkeypatch.setattr(target, stub)

    result = write_file_tool.func(
        runtime=fake_runtime,
        description="验证错误分支格式化",
        path="/mnt/user-data/workspace/output.txt",
        content="tiny payload",
    )

    for fragment in ("/mnt/user-data/workspace/output.txt", expected_fragment):
        assert fragment in result
    assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_write_file_tool_handles_sandbox_init_failure(monkeypatch) -> None:
    """A SandboxError during sandbox initialization becomes a bounded tool error.

    Regression for the #3133 review: the error is raised before the local
    ``requested_path`` assignment, and must not surface as an
    UnboundLocalError.
    """

    def failing_init(runtime):
        raise SandboxError("sandbox missing")

    fake_runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={})
    stubs = [
        ("deerflow.sandbox.tools.ensure_sandbox_initialized", failing_init),
        ("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False),
    ]
    for target, stub in stubs:
        monkeypatch.setattr(target, stub)

    result = write_file_tool.func(
        runtime=fake_runtime,
        description="sandbox 初始化失败",
        path="/mnt/user-data/workspace/output.txt",
        content="tiny payload",
    )

    for fragment in (
        "Error: Failed to write file '/mnt/user-data/workspace/output.txt':",
        "SandboxError: sandbox missing",
    ):
        assert fragment in result
    assert "[write_file error truncated:" not in result
|
||||
|
||||
|
||||
def test_file_operation_lock_memory_cleanup() -> None:
|
||||
"""Verify that released locks are eventually cleaned up by WeakValueDictionary.
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from types import SimpleNamespace
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.routers import skills as skills_router
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.types import Skill
|
||||
@@ -38,7 +39,8 @@ def _make_skill(name: str, *, enabled: bool) -> Skill:
|
||||
|
||||
def _make_test_app(config) -> FastAPI:
    """Create a FastAPI app that serves the skills router with *config* injected.

    The config is made available in two ways: on ``app.state`` and as an
    override of the ``get_config`` dependency.
    """
    app = FastAPI()
    # Single assignment: the diff residue had left the superseded line next to
    # this one, assigning the same attribute twice.
    app.state.config = config  # kept for any startup-style reads
    app.dependency_overrides[get_config] = lambda: config
    app.include_router(skills_router.router)
    return app
|
||||
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Regression tests for _find_usage_recorder callback shape handling.
|
||||
|
||||
Bytedance issue #3107 BUG-002: When LangChain passes ``config["callbacks"]`` as
|
||||
an ``AsyncCallbackManager`` (instead of a plain list), the previous
|
||||
``for cb in callbacks`` loop raised ``TypeError: 'AsyncCallbackManager' object
|
||||
is not iterable``. ToolErrorHandlingMiddleware then converted the entire ``task``
|
||||
tool call into an error ToolMessage, losing the subagent result.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackManager, CallbackManager
|
||||
|
||||
from deerflow.tools.builtins.task_tool import _find_usage_recorder
|
||||
|
||||
|
||||
class _RecorderHandler:
|
||||
def record_external_llm_usage_records(self, records):
|
||||
self.records = records
|
||||
|
||||
|
||||
class _OtherHandler:
|
||||
pass
|
||||
|
||||
|
||||
def _make_runtime(callbacks):
|
||||
return SimpleNamespace(config={"callbacks": callbacks})
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_plain_list():
    """The recorder is found when callbacks are a plain list of handlers."""
    target = _RecorderHandler()
    handlers = [_OtherHandler(), target]
    found = _find_usage_recorder(_make_runtime(handlers))
    assert found is target
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_async_callback_manager():
    """The recorder is found on an ``AsyncCallbackManager``'s handlers.

    LangChain wraps callbacks in ``AsyncCallbackManager`` for async tool runs.
    The old implementation raised TypeError here; the recorder lives on
    ``manager.handlers``, so that is where the lookup must also search.
    """
    target = _RecorderHandler()
    wrapped = AsyncCallbackManager(handlers=[_OtherHandler(), target])
    found = _find_usage_recorder(_make_runtime(wrapped))
    assert found is target
|
||||
|
||||
|
||||
def test_find_usage_recorder_with_sync_callback_manager():
    """Same lookup through the sync ``CallbackManager`` wrapper used by some langchain code paths."""
    target = _RecorderHandler()
    wrapped = CallbackManager(handlers=[target])
    found = _find_usage_recorder(_make_runtime(wrapped))
    assert found is target
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_no_recorder():
    """A manager holding only non-recorder handlers yields None."""
    wrapped = AsyncCallbackManager(handlers=[_OtherHandler()])
    found = _find_usage_recorder(_make_runtime(wrapped))
    assert found is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_handles_empty_manager():
    """A manager with no handlers at all yields None."""
    wrapped = AsyncCallbackManager(handlers=[])
    found = _find_usage_recorder(_make_runtime(wrapped))
    assert found is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_for_none_runtime():
    """Passing ``None`` as the runtime yields None."""
    found = _find_usage_recorder(None)
    assert found is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_callbacks_is_none():
    """A config whose ``callbacks`` entry is None yields None."""
    found = _find_usage_recorder(_make_runtime(None))
    assert found is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_for_single_handler_object():
    """A bare handler (not in a list or manager) yields None instead of crashing.

    LangChain's contract is that ``config["callbacks"]`` is a list or a
    manager; any other shape is handled defensively rather than letting a
    ``for`` loop blow up at runtime.
    """
    bare_handler = _RecorderHandler()
    found = _find_usage_recorder(_make_runtime(bare_handler))
    assert found is None
|
||||
|
||||
|
||||
def test_find_usage_recorder_returns_none_when_config_not_dict():
    """Defensive case: a runtime whose ``config`` is not a dict must not raise."""
    odd_runtime = SimpleNamespace(config="not-a-dict")
    found = _find_usage_recorder(odd_runtime)
    assert found is None
|
||||
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
|
||||
},
|
||||
}
|
||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
|
||||
|
||||
|
||||
def test_thread_token_usage_can_include_active_runs():
    """``include_active=true`` is forwarded to the run store and its totals are returned."""
    aggregated = {
        "total_tokens": 175,
        "total_input_tokens": 120,
        "total_output_tokens": 55,
        "total_runs": 3,
        "by_model": {"unknown": {"tokens": 175, "runs": 3}},
        "by_caller": {
            "lead_agent": 145,
            "subagent": 25,
            "middleware": 5,
        },
    }
    run_store = MagicMock()
    run_store.aggregate_tokens_by_thread = AsyncMock(return_value=aggregated)
    app = _make_app(run_store)

    with TestClient(app) as client:
        response = client.get("/api/threads/thread-1/token-usage?include_active=true")

    assert response.status_code == 200
    body = response.json()
    assert body["total_tokens"] == 175
    assert body["total_runs"] == 3
    # The query flag must reach the store as a keyword argument.
    run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
|
||||
|
||||
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(middlewares) == 6
|
||||
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
|
||||
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
|
||||
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
|
||||
# (enabled by default — see SafetyFinishReasonConfig).
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
assert len(middlewares) == 7
|
||||
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
|
||||
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_on_success():
|
||||
|
||||
@@ -11,6 +11,7 @@ from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.deps import get_config
|
||||
from app.gateway.routers import uploads
|
||||
|
||||
|
||||
@@ -687,6 +688,7 @@ def test_upload_limits_endpoint_requires_thread_access():
|
||||
cfg.uploads = {}
|
||||
app = make_authed_test_app(owner_check_passes=False)
|
||||
app.state.config = cfg
|
||||
app.dependency_overrides[get_config] = lambda: cfg
|
||||
app.include_router(uploads.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
|
||||
Reference in New Issue
Block a user