mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
be0eae9825
* fix(runtime): suppress tool execution when provider safety-terminates with tool_calls When a provider stops generation for safety reasons (OpenAI/Moonshot finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/ IMAGE_SAFETY/...), the response may still carry truncated tool_calls. LangChain's tool router treats any non-empty tool_calls as executable, so partial arguments (e.g. write_file with a half-finished markdown) get dispatched and the agent loops on retry. Add SafetyFinishReasonMiddleware at after_model: detect safety termination via a pluggable detector registry, clear both structured tool_calls and raw additional_kwargs.tool_calls / function_call, preserve response_metadata.finish_reason for downstream observers, stamp additional_kwargs.safety_termination for traces, append a user-facing explanation to message content (list-aware for thinking blocks), and emit a safety_termination custom stream event so SSE consumers can reconcile any "tool starting..." UI. Default detectors cover OpenAI-compatible content_filter, Anthropic refusal, and Gemini safety enums (text + image). Custom providers are added via reflection (same pattern as guardrails). Wired into both lead-agent and subagent runtimes. Closes #3028 * fix(runtime): persist safety_termination as a middleware audit event Address review on #3035: the SSE custom event is great for live consumers but invisible to post-run audit. RunEventStore should carry its own row so operators can answer "which runs were safety-suppressed today?" from a single SQL query without joining the message body. Worker now exposes the run-scoped RunJournal via runtime.context["__run_journal"] (sentinel key, internal channel). 
SafetyFinishReasonMiddleware calls the previously-unused RunJournal.record_middleware, which emits event_type = "middleware:safety_termination" category = "middleware" content = {name, hook, action, changes={ detector, reason_field, reason_value, suppressed_tool_call_count, suppressed_tool_call_names, suppressed_tool_call_ids, message_id, extras}} Tool *arguments* are deliberately excluded — those are the very content the provider filtered and persisting them would defeat the purpose of the safety filter (per review note in #3035). Graceful skips when journal is absent (subagent runtime, unit tests, no-event-store local dev). Journal exceptions never propagate into the agent loop. Refs #3028 * fix(runtime): satisfy ruff format + address Copilot review - ruff format on safety_finish_reason_config.py and e2e demo (CI lint failed on ruff format --check; backend Makefile lint target runs ruff check AND ruff format --check). - Docstring on SafetyFinishReasonConfig now says resolve_variable to match the actual loader used in from_config (the wording was resolve_class previously; behavior is unchanged — resolve_variable mirrors how guardrails.provider is loaded). - Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply from getattr(last, "type") == "ai" to isinstance(last, AIMessage), matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware / SummarizationMiddleware which are the dominant pattern. Refs #3028
226 lines
9.0 KiB
Python
226 lines
9.0 KiB
Python
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
|
|
|
|
Unit tests prove ``_apply`` does the right thing on a synthetic state.
|
|
This test goes one level up: it builds a real ``langchain.agents.create_agent``
|
|
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
|
|
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
|
|
|
|
1. The tool node is **not** invoked (the dangerous truncated tool call
|
|
is suppressed).
|
|
2. The final AIMessage in graph state has ``tool_calls == []``.
|
|
3. The observability ``safety_termination`` record is attached.
|
|
4. The user-facing explanation is appended to the message content.
|
|
|
|
This is the closest we can get to the issue's failure mode without a live
|
|
Moonshot key, and it proves the middleware actually gates LangChain's
|
|
tool router — not just rewrites state in isolation.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from langchain.agents import create_agent
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
|
from langchain_core.tools import tool
|
|
|
|
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
|
|
|
# Module-level log of ``write_file`` calls. The tool appends one
# ``{"path": ..., "content": ...}`` entry per invocation; each test clears the
# list up front and then asserts on it to decide whether the tool node ran.
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
|
|
|
|
|
|
@tool
def write_file(path: str, content: str) -> str:
    """Pretend to write *content* to *path*. Records the call for assertion."""
    # NOTE(review): the docstring is left untouched on purpose -- the ``@tool``
    # decorator presumably uses it as the tool description; confirm before
    # rewording it.
    invocation = {"path": path, "content": content}
    _TOOL_INVOCATIONS.append(invocation)
    size = len(content)
    return f"wrote {size} bytes to {path}"
|
|
|
|
|
|
class _ContentFilteredModel(BaseChatModel):
    """Stand-in chat model reproducing an OpenAI/Moonshot ``content_filter`` stop.

    The first generation carries ``finish_reason='content_filter'`` together
    with a ``write_file`` tool call whose arguments are cut off mid-sentence.
    Every later generation is a plain ``"ack"`` text reply with
    ``finish_reason='stop'``, so the agent can finish cleanly if it asks the
    model again.
    """

    # Number of generations served so far; selects which canned reply is used.
    call_count: int = 0

    @property
    def _llm_type(self) -> str:
        return "fake-content-filtered"

    def bind_tools(self, tools, **kwargs):
        # ``create_agent`` binds the tool list onto the model. The replies are
        # canned, so there is nothing to bind -- the call only has to succeed
        # and hand back a usable model.
        return self

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        self.call_count += 1

        # Any call after the first: ordinary text completion.
        if self.call_count != 1:
            reply = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
            return ChatResult(generations=[ChatGeneration(message=reply)])

        # First call: safety-terminated reply that still carries a tool call
        # with visibly truncated arguments.
        truncated_call = {
            "id": "call_truncated_1",
            "name": "write_file",
            "args": {
                "path": "/mnt/user-data/outputs/report.md",
                "content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
            },
        }
        reply = AIMessage(
            content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
            tool_calls=[truncated_call],
            response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
        )
        return ChatResult(generations=[ChatGeneration(message=reply)])

    async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
        # The async path serves the same canned replies as the sync one.
        return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
|
|
|
|
|
class _InspectMiddleware(AgentMiddleware):
    """Record a snapshot of the request messages on every model call.

    Snapshots accumulate in ``observed`` (one list per model call, in call
    order) so a test can check what the model was actually shown -- e.g. that
    no synthetic tool result was fed back into the conversation.
    """

    def __init__(self) -> None:
        super().__init__()
        self.observed: list[list[Any]] = []

    def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
        # Copy the list so later changes to the request's messages cannot
        # alter what was recorded for this call.
        snapshot = [*request.messages]
        self.observed.append(snapshot)
        return handler(request)
|
|
|
|
|
|
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
    """A safety-terminated reply carrying tool calls must never reach the tool node.

    Drives a real ``create_agent`` graph with ``_ContentFilteredModel`` and
    checks that ``write_file`` is not executed, that the final ``AIMessage``
    has its tool calls cleared and carries the ``safety_termination`` record,
    that the partial text and the user-facing explanation are both present,
    and that no tool result appears anywhere in the conversation.
    """
    # Local import: ToolMessage is only needed by this test's assertions.
    from langchain_core.messages import ToolMessage

    _TOOL_INVOCATIONS.clear()
    inspector = _InspectMiddleware()

    agent = create_agent(
        model=_ContentFilteredModel(),
        tools=[write_file],
        # The inspector only wraps the model call to snapshot its input.
        # Safety goes last in the list so its after_model hook executes first
        # under LIFO (matches production wiring).
        middleware=[inspector, SafetyFinishReasonMiddleware()],
    )

    result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})

    # Critical assertion: the dangerous truncated tool call must NOT have
    # been executed. This is the entire point of the middleware.
    assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"

    # Final AIMessage has no tool calls left.
    final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
    assert final_ai.tool_calls == []

    # Observability stamp is present.
    record = final_ai.additional_kwargs.get("safety_termination")
    assert record is not None
    assert record["detector"] == "openai_compatible_content_filter"
    assert record["reason_field"] == "finish_reason"
    assert record["reason_value"] == "content_filter"
    assert record["suppressed_tool_call_count"] == 1
    assert record["suppressed_tool_call_names"] == ["write_file"]

    # User-facing explanation is appended.
    assert "safety-related signal" in final_ai.content
    # Original partial text preserved (we don't throw away what the user
    # already saw in the stream — see middleware docstring).
    assert "Weekly Politics" in final_ai.content

    # finish_reason on response_metadata is preserved (so SSE / converters
    # downstream still see the real provider reason).
    assert final_ai.response_metadata.get("finish_reason") == "content_filter"

    # The inspector must have seen the model being called, and no tool result
    # (real or synthetic) may appear either in what the model was shown or in
    # the final graph state.
    assert inspector.observed, "the model was never invoked"
    shown_to_model = [m for snapshot in inspector.observed for m in snapshot]
    leaked = [m for m in [*shown_to_model, *result["messages"]] if isinstance(m, ToolMessage)]
    assert leaked == [], f"tool result found in conversation despite content_filter: {leaked}"
|
|
|
|
|
|
def test_content_filter_without_tool_calls_passes_through_unchanged():
    """A ``content_filter`` reply with no tool calls is delivered untouched.

    Intervening is out of the issue's scope when there is nothing to
    suppress: the partial answer should reach the user exactly as the
    provider produced it, with no ``safety_termination`` stamp.
    """
    _TOOL_INVOCATIONS.clear()

    class _NoToolModel(BaseChatModel):
        # Always answers with the same safety-truncated text and no tool calls.

        @property
        def _llm_type(self) -> str:
            return "fake-no-tool"

        def bind_tools(self, tools, **kwargs):
            return self

        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            reply = AIMessage(
                content="Partial answer truncated by safety filter",
                response_metadata={"finish_reason": "content_filter"},
            )
            generation = ChatGeneration(message=reply)
            return ChatResult(generations=[generation])

        async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
            return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)

    graph = create_agent(
        model=_NoToolModel(),
        tools=[write_file],
        middleware=[SafetyFinishReasonMiddleware()],
    )
    outcome = graph.invoke({"messages": [HumanMessage(content="hi")]})

    # Newest AIMessage in the resulting state.
    ai_newest_first = (m for m in reversed(outcome["messages"]) if isinstance(m, AIMessage))
    last_ai = next(ai_newest_first)

    # Content untouched.
    assert last_ai.content == "Partial answer truncated by safety filter"
    # No safety_termination stamp because the middleware did not intervene.
    assert "safety_termination" not in last_ai.additional_kwargs
    # Tool node never ran (there were no tool calls in the first place).
    assert _TOOL_INVOCATIONS == []
|
|
|
|
|
|
def test_normal_tool_call_round_trip_is_not_affected():
    """Regression guard: the middleware must not over-fire.

    A healthy ``finish_reason='tool_calls'`` reply still has to reach the
    tool node and execute ``write_file`` with its complete arguments.
    """
    _TOOL_INVOCATIONS.clear()

    class _HealthyToolModel(BaseChatModel):
        # First reply requests ``write_file``; every later reply is plain text.
        call_count: int = 0

        @property
        def _llm_type(self) -> str:
            return "fake-healthy"

        def bind_tools(self, tools, **kwargs):
            return self

        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            self.call_count += 1

            if self.call_count != 1:
                reply = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
                return ChatResult(generations=[ChatGeneration(message=reply)])

            complete_call = {
                "id": "call_ok",
                "name": "write_file",
                "args": {"path": "/tmp/ok", "content": "complete content"},
            }
            reply = AIMessage(
                content="",
                tool_calls=[complete_call],
                response_metadata={"finish_reason": "tool_calls"},
            )
            return ChatResult(generations=[ChatGeneration(message=reply)])

        async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
            return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)

    graph = create_agent(
        model=_HealthyToolModel(),
        tools=[write_file],
        middleware=[SafetyFinishReasonMiddleware()],
    )
    graph.invoke({"messages": [HumanMessage(content="write")]})

    # Exactly one tool execution, carrying the full untruncated arguments.
    expected_calls = [{"path": "/tmp/ok", "content": "complete content"}]
    assert _TOOL_INVOCATIONS == expected_calls
|