mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
be0eae9825
* fix(runtime): suppress tool execution when provider safety-terminates with tool_calls

  When a provider stops generation for safety reasons (OpenAI/Moonshot finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/IMAGE_SAFETY/...), the response may still carry truncated tool_calls. LangChain's tool router treats any non-empty tool_calls as executable, so partial arguments (e.g. write_file with a half-finished markdown) get dispatched and the agent loops on retry.

  Add SafetyFinishReasonMiddleware at after_model: detect safety termination via a pluggable detector registry, clear both structured tool_calls and raw additional_kwargs.tool_calls / function_call, preserve response_metadata.finish_reason for downstream observers, stamp additional_kwargs.safety_termination for traces, append a user-facing explanation to message content (list-aware for thinking blocks), and emit a safety_termination custom stream event so SSE consumers can reconcile any "tool starting..." UI.

  Default detectors cover OpenAI-compatible content_filter, Anthropic refusal, and Gemini safety enums (text + image). Custom providers are added via reflection (same pattern as guardrails). Wired into both lead-agent and subagent runtimes.

  Closes #3028

* fix(runtime): persist safety_termination as a middleware audit event

  Address review on #3035: the SSE custom event is great for live consumers but invisible to post-run audit. RunEventStore should carry its own row so operators can answer "which runs were safety-suppressed today?" from a single SQL query without joining the message body.

  Worker now exposes the run-scoped RunJournal via runtime.context["__run_journal"] (sentinel key, internal channel). SafetyFinishReasonMiddleware calls the previously-unused RunJournal.record_middleware, which emits

      event_type = "middleware:safety_termination"
      category = "middleware"
      content = {name, hook, action, changes={ detector, reason_field, reason_value, suppressed_tool_call_count, suppressed_tool_call_names, suppressed_tool_call_ids, message_id, extras}}

  Tool *arguments* are deliberately excluded — those are the very content the provider filtered and persisting them would defeat the purpose of the safety filter (per review note in #3035).

  Graceful skips when journal is absent (subagent runtime, unit tests, no-event-store local dev). Journal exceptions never propagate into the agent loop.

  Refs #3028

* fix(runtime): satisfy ruff format + address Copilot review

  - ruff format on safety_finish_reason_config.py and e2e demo (CI lint failed on ruff format --check; backend Makefile lint target runs ruff check AND ruff format --check).
  - Docstring on SafetyFinishReasonConfig now says resolve_variable to match the actual loader used in from_config (the wording was resolve_class previously; behavior is unchanged — resolve_variable mirrors how guardrails.provider is loaded).
  - Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply from getattr(last, "type") == "ai" to isinstance(last, AIMessage), matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware / SummarizationMiddleware which are the dominant pattern.

  Refs #3028
254 lines
8.9 KiB
Python
254 lines
8.9 KiB
Python
import sys
|
|
from types import ModuleType, SimpleNamespace
|
|
|
|
import pytest
|
|
from langchain_core.messages import ToolMessage
|
|
from langgraph.errors import GraphInterrupt
|
|
|
|
from deerflow.agents.middlewares.tool_error_handling_middleware import (
|
|
ToolErrorHandlingMiddleware,
|
|
build_subagent_runtime_middlewares,
|
|
)
|
|
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
|
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
|
|
from deerflow.config.guardrails_config import GuardrailsConfig
|
|
from deerflow.config.model_config import ModelConfig
|
|
from deerflow.config.sandbox_config import SandboxConfig
|
|
|
|
|
|
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
|
tool_call = {"name": name}
|
|
if tool_call_id is not None:
|
|
tool_call["id"] = tool_call_id
|
|
return SimpleNamespace(tool_call=tool_call)
|
|
|
|
|
|
def _module(name: str, **attrs: object) -> ModuleType:
    """Create an in-memory module object and populate it with attributes.

    Args:
        name: Name given to the new ``ModuleType`` instance.
        **attrs: Attribute names and values to set on the module.

    Returns:
        The new module. It is not registered in ``sys.modules`` here.
    """
    module = ModuleType(name)
    for key, value in attrs.items():
        setattr(module, key, value)
    return module
|
|
|
|
|
|
def _make_app_config(*, supports_vision: bool = False) -> AppConfig:
    """Build an ``AppConfig`` with a single model named ``"test-model"``.

    Args:
        supports_vision: Forwarded to the model's ``supports_vision`` field.

    Returns:
        An ``AppConfig`` with one model, a ``"test"`` sandbox, guardrails
        disabled, and a circuit breaker set to ``failure_threshold=7`` /
        ``recovery_timeout_sec=11``.
    """
    return AppConfig(
        models=[
            ModelConfig(
                name="test-model",
                display_name="test-model",
                description=None,
                use="langchain_openai:ChatOpenAI",
                model="test-model",
                supports_vision=supports_vision,
            )
        ],
        sandbox=SandboxConfig(use="test"),
        guardrails=GuardrailsConfig(enabled=False),
        circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
    )
|
|
|
|
|
|
def _stub_runtime_middleware_imports(monkeypatch: pytest.MonkeyPatch) -> None:
    """Register stub modules in ``sys.modules`` for the runtime middlewares.

    Each listed module path is replaced, via ``monkeypatch.setitem``, with an
    in-memory module exposing a fake class under the expected name.
    NOTE(review): presumably the builder under test imports these modules
    lazily, so the stubs are what it resolves — confirm against the builder.

    Args:
        monkeypatch: pytest's monkeypatch fixture used to patch ``sys.modules``.
    """

    class FakeMiddleware:
        """Accepts any constructor arguments and records them."""

        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class FakeLLMErrorHandlingMiddleware:
        """Records the keyword-only ``app_config`` it is constructed with."""

        def __init__(self, *, app_config):
            self.app_config = app_config

    # Module path -> attributes exposed by the stub module. Insertion order is
    # the order in which the entries are patched.
    stubs = {
        "deerflow.agents.middlewares.llm_error_handling_middleware": {
            "LLMErrorHandlingMiddleware": FakeLLMErrorHandlingMiddleware,
        },
        "deerflow.agents.middlewares.thread_data_middleware": {
            "ThreadDataMiddleware": FakeMiddleware,
        },
        "deerflow.sandbox.middleware": {
            "SandboxMiddleware": FakeMiddleware,
        },
        "deerflow.agents.middlewares.dangling_tool_call_middleware": {
            "DanglingToolCallMiddleware": FakeMiddleware,
        },
        "deerflow.agents.middlewares.sandbox_audit_middleware": {
            "SandboxAuditMiddleware": FakeMiddleware,
        },
    }
    for module_name, attrs in stubs.items():
        monkeypatch.setitem(sys.modules, module_name, _module(module_name, **attrs))
|
|
|
|
|
|
def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
    """The builder passes its ``app_config`` to ``LLMErrorHandlingMiddleware``.

    The middleware modules are replaced with stubs; the LLM error-handling stub
    captures the ``app_config`` it is constructed with so identity can be
    asserted afterwards. The test also pins the total middleware count and that
    the last entry is a ``SafetyFinishReasonMiddleware``.
    """
    captured: dict[str, object] = {}

    class FakeMiddleware:
        """Accepts any constructor arguments and records them."""

        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class FakeLLMErrorHandlingMiddleware:
        """Captures the ``app_config`` keyword it receives."""

        def __init__(self, *, app_config):
            captured["app_config"] = app_config

    app_config = _make_app_config()

    # Module path -> attributes exposed by the stub module.
    stubs = {
        "deerflow.agents.middlewares.llm_error_handling_middleware": {
            "LLMErrorHandlingMiddleware": FakeLLMErrorHandlingMiddleware,
        },
        "deerflow.agents.middlewares.thread_data_middleware": {
            "ThreadDataMiddleware": FakeMiddleware,
        },
        "deerflow.sandbox.middleware": {
            "SandboxMiddleware": FakeMiddleware,
        },
        "deerflow.agents.middlewares.dangling_tool_call_middleware": {
            "DanglingToolCallMiddleware": FakeMiddleware,
        },
        "deerflow.agents.middlewares.sandbox_audit_middleware": {
            "SandboxAuditMiddleware": FakeMiddleware,
        },
    }
    for module_name, attrs in stubs.items():
        monkeypatch.setitem(sys.modules, module_name, _module(module_name, **attrs))

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)

    assert captured["app_config"] is app_config
    # 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
    # SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
    # (enabled by default — see SafetyFinishReasonConfig).
    from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware

    assert len(middlewares) == 7
    assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
    assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
|
|
|
|
|
|
def test_wrap_tool_call_passthrough_on_success():
    """``wrap_tool_call`` returns the handler's message object unchanged on success."""
    middleware = ToolErrorHandlingMiddleware()
    req = _request()
    expected = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")

    result = middleware.wrap_tool_call(req, lambda _req: expected)

    # Identity, not equality: the very same object must come back.
    assert result is expected
|
|
|
|
|
|
def test_wrap_tool_call_returns_error_tool_message_on_exception():
    """A handler exception becomes an error ``ToolMessage`` instead of propagating.

    The returned message must keep the request's tool-call id and tool name,
    have ``status == "error"``, and mention both the failing tool and the
    original exception text.
    """
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="web_search", tool_call_id="tc-42")

    def _boom(_req):
        raise RuntimeError("network down")

    result = middleware.wrap_tool_call(req, _boom)

    assert isinstance(result, ToolMessage)
    assert result.tool_call_id == "tc-42"
    assert result.name == "web_search"
    assert result.status == "error"
    assert "Tool 'web_search' failed" in result.text
    assert "network down" in result.text
|
|
|
|
|
|
def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
    """With no ``"id"`` on the tool call, the error message uses ``"missing_tool_call_id"``."""
    middleware = ToolErrorHandlingMiddleware()
    # tool_call_id=None makes _request omit the "id" key from the dict.
    req = _request(name="mcp_tool", tool_call_id=None)

    def _boom(_req):
        raise ValueError("bad request")

    result = middleware.wrap_tool_call(req, _boom)

    assert isinstance(result, ToolMessage)
    assert result.tool_call_id == "missing_tool_call_id"
    assert result.name == "mcp_tool"
    assert result.status == "error"
|
|
|
|
|
|
def test_wrap_tool_call_reraises_graph_interrupt():
    """``GraphInterrupt`` raised by the handler propagates out of ``wrap_tool_call``."""
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="ask_clarification", tool_call_id="tc-int")

    def _interrupt(_req):
        raise GraphInterrupt(())

    with pytest.raises(GraphInterrupt):
        middleware.wrap_tool_call(req, _interrupt)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_tool_call_returns_error_tool_message_on_exception():
    """Async path: a handler exception becomes an error ``ToolMessage``.

    The returned message must keep the request's tool-call id and tool name,
    have ``status == "error"``, and include the original exception text.
    """
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="mcp_tool", tool_call_id="tc-async")

    async def _boom(_req):
        raise TimeoutError("request timed out")

    result = await middleware.awrap_tool_call(req, _boom)

    assert isinstance(result, ToolMessage)
    assert result.tool_call_id == "tc-async"
    assert result.name == "mcp_tool"
    assert result.status == "error"
    assert "request timed out" in result.text
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_tool_call_reraises_graph_interrupt():
    """Async path: ``GraphInterrupt`` from the handler propagates out of ``awrap_tool_call``."""
    middleware = ToolErrorHandlingMiddleware()
    req = _request(name="ask_clarification", tool_call_id="tc-int-async")

    async def _interrupt(_req):
        raise GraphInterrupt(())

    with pytest.raises(GraphInterrupt):
        await middleware.awrap_tool_call(req, _interrupt)
|
|
|
|
|
|
def test_subagent_runtime_middlewares_include_view_image_for_vision_model(monkeypatch):
    """A vision-capable model selected by name yields a ``ViewImageMiddleware``."""
    app_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")

    assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
|
|
|
|
|
def test_subagent_runtime_middlewares_include_view_image_for_default_vision_model(monkeypatch):
    """With ``model_name=None`` and a vision-capable sole model, ``ViewImageMiddleware`` is included."""
    app_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    # NOTE(review): presumably None resolves to the configured default model — confirm in the builder.
    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=None)

    assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|
|
|
|
|
|
def test_subagent_runtime_middlewares_skip_view_image_for_text_model(monkeypatch):
    """A model without vision support yields no ``ViewImageMiddleware``."""
    app_config = _make_app_config(supports_vision=False)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")

    assert not any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
|