Files
deer-flow/backend/tests/test_safety_finish_reason_middleware.py
T
Xinmin Zeng be0eae9825 fix(runtime): suppress tool execution when provider safety-terminates with tool_calls (#3035)
* fix(runtime): suppress tool execution when provider safety-terminates with tool_calls

When a provider stops generation for safety reasons (OpenAI/Moonshot
finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini
finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/
IMAGE_SAFETY/...), the response may still carry truncated tool_calls.
LangChain's tool router treats any non-empty tool_calls as executable,
so partial arguments (e.g. write_file with a half-finished markdown)
get dispatched and the agent loops on retry.

Add SafetyFinishReasonMiddleware at after_model: detect safety
termination via a pluggable detector registry, clear both structured
tool_calls and raw additional_kwargs.tool_calls / function_call,
preserve response_metadata.finish_reason for downstream observers,
stamp additional_kwargs.safety_termination for traces, append a
user-facing explanation to message content (list-aware for thinking
blocks), and emit a safety_termination custom stream event so SSE
consumers can reconcile any "tool starting..." UI.

Default detectors cover OpenAI-compatible content_filter, Anthropic
refusal, and Gemini safety enums (text + image). Custom providers are
added via reflection (same pattern as guardrails). Wired into both
lead-agent and subagent runtimes.

Closes #3028

* fix(runtime): persist safety_termination as a middleware audit event

Address review on #3035: the SSE custom event is great for live
consumers but invisible to post-run audit. RunEventStore should carry
its own row so operators can answer "which runs were safety-suppressed
today?" from a single SQL query without joining the message body.

Worker now exposes the run-scoped RunJournal via
runtime.context["__run_journal"] (sentinel key, internal channel).
SafetyFinishReasonMiddleware calls the previously-unused
RunJournal.record_middleware, which emits

  event_type = "middleware:safety_termination"
  category   = "middleware"
  content    = {name, hook, action, changes={
                  detector, reason_field, reason_value,
                  suppressed_tool_call_count,
                  suppressed_tool_call_names,
                  suppressed_tool_call_ids,
                  message_id, extras}}

Tool *arguments* are deliberately excluded — those are the very content
the provider filtered and persisting them would defeat the purpose of
the safety filter (per review note in #3035).

Graceful skips when journal is absent (subagent runtime, unit tests,
no-event-store local dev). Journal exceptions never propagate into the
agent loop.

Refs #3028

* fix(runtime): satisfy ruff format + address Copilot review

- ruff format on safety_finish_reason_config.py and e2e demo (CI lint
  failed on ruff format --check; backend Makefile lint target runs
  ruff check AND ruff format --check).
- Docstring on SafetyFinishReasonConfig now says resolve_variable to
  match the actual loader used in from_config (the wording was
  resolve_class previously; behavior is unchanged — resolve_variable
  mirrors how guardrails.provider is loaded).
- Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply
  from getattr(last, "type") == "ai" to isinstance(last, AIMessage),
  matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware
  / SummarizationMiddleware which are the dominant pattern.

Refs #3028
2026-05-22 21:20:28 +08:00

652 lines
24 KiB
Python

"""Unit tests for SafetyFinishReasonMiddleware."""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
)
from deerflow.config.safety_finish_reason_config import (
SafetyDetectorConfig,
SafetyFinishReasonConfig,
)
def _runtime(thread_id="t-1"):
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
return runtime
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None):
    """Build an ``AIMessage`` for tests.

    Any falsy ``tool_calls`` / ``response_metadata`` / ``additional_kwargs``
    is replaced by a fresh empty container before being handed to the
    message constructor.
    """
    fields = {
        "content": content,
        "tool_calls": tool_calls or [],
        "response_metadata": response_metadata or {},
        "additional_kwargs": additional_kwargs or {},
    }
    return AIMessage(**fields)
def _write_call(idx=1, content_text="半截"):
return {
"id": f"call_write_{idx}",
"name": "write_file",
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
}
class AlwaysHitDetector:
    """Detector stub for tests: every call to ``detect`` reports a termination."""

    name = "always_hit"

    def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
        # A falsy ``extras`` (None or an empty mapping) becomes a fresh dict.
        self.extras = extras if extras else {}
        self.reason_value = reason_value
        self.reason_field = reason_field

    def detect(self, message):
        """Ignore ``message`` and return a termination built from the stored fields."""
        termination_fields = {
            "detector": self.name,
            "reason_field": self.reason_field,
            "reason_value": self.reason_value,
            "extras": self.extras,
        }
        return SafetyTermination(**termination_fields)
class NeverHitDetector:
    """Detector stub for tests: never reports a termination."""

    name = "never_hit"

    def detect(self, message):
        """Ignore ``message`` and report no match."""
        no_match = None
        return no_match
class RaisingDetector:
    """Detector stub for tests: ``detect`` always fails with ``RuntimeError``."""

    name = "raising"

    def detect(self, message):
        """Raise ``RuntimeError("boom")`` regardless of ``message``."""
        failure = RuntimeError("boom")
        raise failure
# ---------------------------------------------------------------------------
# Core trigger behaviour
# ---------------------------------------------------------------------------
class TestTriggerCriteria:
    """Conditions under which the middleware rewrites the last message."""

    @staticmethod
    def _run(messages):
        """Feed ``messages`` to a default-configured middleware and return its result."""
        middleware = SafetyFinishReasonMiddleware()
        return middleware._apply({"messages": messages}, _runtime())

    def test_content_filter_with_tool_calls_triggers(self):
        flagged = _ai(
            content="partial",
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        outcome = self._run([flagged])
        assert outcome is not None
        assert outcome["messages"][0].tool_calls == []

    def test_content_filter_without_tool_calls_passes_through(self):
        """Issue scope: with no tool calls, the partial text counts as a
        legitimate final response and must not be rewritten."""
        plain = _ai(
            content="partial response",
            response_metadata={"finish_reason": "content_filter"},
        )
        assert self._run([plain]) is None

    def test_normal_tool_calls_pass_through(self):
        healthy = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "tool_calls"},
        )
        assert self._run([healthy]) is None

    def test_normal_stop_with_tool_calls_pass_through(self):
        # finish_reason='stop' is what some providers report for tool-call messages.
        stopped = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "stop"},
        )
        assert self._run([stopped]) is None

    def test_empty_message_list_passes_through(self):
        assert self._run([]) is None

    def test_non_ai_last_message_passes_through(self):
        history = [HumanMessage(content="hi"), SystemMessage(content="sys")]
        assert self._run(history) is None

    def test_anthropic_refusal_with_tool_calls_triggers(self):
        refused = _ai(
            tool_calls=[_write_call()],
            response_metadata={"stop_reason": "refusal"},
        )
        outcome = self._run([refused])
        assert outcome is not None
        assert outcome["messages"][0].tool_calls == []

    def test_gemini_safety_with_tool_calls_triggers(self):
        blocked = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "SAFETY"},
        )
        outcome = self._run([blocked])
        assert outcome is not None
        assert outcome["messages"][0].tool_calls == []
# ---------------------------------------------------------------------------
# Message rewriting
# ---------------------------------------------------------------------------
class TestMessageRewrite:
    """How a safety-terminated AIMessage is rewritten by the middleware.

    Each test feeds a single message with ``finish_reason=content_filter``
    and at least one tool call through ``_apply`` and inspects the patched
    message returned under ``result["messages"][0]``.
    """

    def test_clears_structured_tool_calls(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call(1), _write_call(2)],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        result = mw._apply(state, _runtime())
        patched = result["messages"][0]
        assert patched.tool_calls == []

    def test_clears_raw_additional_kwargs_tool_calls(self):
        """Critical defence-in-depth: DanglingToolCallMiddleware will recover
        tool calls from additional_kwargs.tool_calls if we forget them, which
        would re-emit a synthetic ToolMessage downstream and confuse the
        model. We must wipe both."""
        mw = SafetyFinishReasonMiddleware()
        raw_tool_calls = [
            {
                "id": "call_write_1",
                "type": "function",
                "function": {"name": "write_file", "arguments": '{"path": "/x"}'},
            }
        ]
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call(1)],
                    response_metadata={"finish_reason": "content_filter"},
                    additional_kwargs={
                        "tool_calls": raw_tool_calls,
                        "function_call": {"name": "write_file", "arguments": "{}"},
                    },
                )
            ]
        }
        result = mw._apply(state, _runtime())
        patched = result["messages"][0]
        assert "tool_calls" not in patched.additional_kwargs
        assert "function_call" not in patched.additional_kwargs

    def test_preserves_other_additional_kwargs(self):
        # vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
        # may carry other provider-specific keys. They must not be wiped.
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                    additional_kwargs={
                        "reasoning": "thinking text",
                        "custom_provider_field": {"x": 1},
                    },
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert patched.additional_kwargs["reasoning"] == "thinking text"
        assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}

    def test_writes_observability_field(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call(1), _write_call(2)],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        record = patched.additional_kwargs["safety_termination"]
        assert record["detector"] == "openai_compatible_content_filter"
        assert record["reason_field"] == "finish_reason"
        assert record["reason_value"] == "content_filter"
        assert record["suppressed_tool_call_count"] == 2
        assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]

    def test_preserves_response_metadata_finish_reason(self):
        """Downstream SSE converters read response_metadata.finish_reason —
        we want them to see the *real* provider reason, not 'stop'."""
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert patched.response_metadata["finish_reason"] == "content_filter"
        assert patched.response_metadata["model_name"] == "kimi-k2"

    def test_appends_user_facing_explanation_to_str_content(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    content="some partial text",
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert isinstance(patched.content, str)
        assert patched.content.startswith("some partial text")
        assert "safety-related signal" in patched.content

    def test_handles_empty_content(self):
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    content="",
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert isinstance(patched.content, str)
        assert "safety-related signal" in patched.content

    def test_handles_list_content_thinking_blocks(self):
        """Anthropic thinking / vLLM reasoning models emit content blocks.
        Naively concatenating a string would raise TypeError."""
        mw = SafetyFinishReasonMiddleware()
        thinking_blocks = [
            {"type": "thinking", "text": "let me consider..."},
            {"type": "text", "text": "partial answer"},
        ]
        state = {
            "messages": [
                _ai(
                    content=thinking_blocks,
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        patched = mw._apply(state, _runtime())["messages"][0]
        assert isinstance(patched.content, list)
        assert patched.content[:2] == thinking_blocks
        assert patched.content[-1]["type"] == "text"
        assert "safety-related signal" in patched.content[-1]["text"]

    def test_idempotent_on_already_cleared_message(self):
        # Re-running the middleware on a message we already cleared must not
        # re-trigger (tool_calls is now empty → fast passthrough).
        mw = SafetyFinishReasonMiddleware()
        state = {
            "messages": [
                _ai(
                    tool_calls=[_write_call()],
                    response_metadata={"finish_reason": "content_filter"},
                )
            ]
        }
        first = mw._apply(state, _runtime())
        state2 = {"messages": [first["messages"][0]]}
        second = mw._apply(state2, _runtime())
        assert second is None

    def test_preserves_message_id_for_add_messages_replacement(self):
        """LangGraph's add_messages reducer treats same-id messages as
        replacements, so the patched message must keep the original id."""
        mw = SafetyFinishReasonMiddleware()
        # A freshly constructed message has id=None unless one is supplied.
        # Pin an explicit id so the check cannot pass as ``None == None``
        # if the middleware were to drop it.
        original_id = "msg-safety-termination-1"
        original = AIMessage(
            content="",
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
            id=original_id,
        )
        assert original.id == original_id
        state = {"messages": [original]}
        patched = mw._apply(state, _runtime())["messages"][0]
        assert patched.id == original_id
# ---------------------------------------------------------------------------
# Detector wiring
# ---------------------------------------------------------------------------
class TestDetectorWiring:
    """How the middleware consumes the detector list it was constructed with."""

    @staticmethod
    def _bare_tool_call_state():
        """State with one tool-calling AIMessage and no response metadata."""
        return {"messages": [_ai(tool_calls=[_write_call()])]}

    def test_iterates_detectors_in_order(self):
        winner = AlwaysHitDetector(reason_value="first")
        runner_up = AlwaysHitDetector(reason_value="second")
        middleware = SafetyFinishReasonMiddleware(detectors=[winner, runner_up])
        outcome = middleware._apply(self._bare_tool_call_state(), _runtime())
        record = outcome["messages"][0].additional_kwargs["safety_termination"]
        assert record["reason_value"] == "first"

    def test_returns_none_when_no_detector_matches(self):
        silent = [NeverHitDetector(), NeverHitDetector()]
        middleware = SafetyFinishReasonMiddleware(detectors=silent)
        flagged = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        assert middleware._apply({"messages": [flagged]}, _runtime()) is None

    def test_buggy_detector_does_not_break_run(self):
        middleware = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
        outcome = middleware._apply(self._bare_tool_call_state(), _runtime())
        assert outcome is not None
        record = outcome["messages"][0].additional_kwargs["safety_termination"]
        assert record["detector"] == "always_hit"

    def test_constructor_copies_detectors(self):
        """Mutating the caller's list after construction must not affect the middleware."""
        supplied = [AlwaysHitDetector()]
        middleware = SafetyFinishReasonMiddleware(detectors=supplied)
        supplied.clear()
        assert middleware._apply(self._bare_tool_call_state(), _runtime()) is not None
# ---------------------------------------------------------------------------
# from_config
# ---------------------------------------------------------------------------
class TestFromConfig:
    """Building the middleware from a ``SafetyFinishReasonConfig``."""

    @staticmethod
    def _state_with_finish_reason(reason):
        """State holding one tool-calling AIMessage with the given finish_reason."""
        message = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": reason},
        )
        return {"messages": [message]}

    def test_default_config_uses_builtin_detectors(self):
        middleware = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
        detectors = middleware._detectors
        assert len(detectors) == 3
        assert {detector.name for detector in detectors} == {
            "openai_compatible_content_filter",
            "anthropic_refusal",
            "gemini_safety",
        }

    def test_custom_detectors_loaded_via_reflection(self):
        custom = SafetyDetectorConfig(
            use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
            config={"finish_reasons": ["custom_filter"]},
        )
        middleware = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig(detectors=[custom]))
        assert len(middleware._detectors) == 1
        # The configured kwargs reached the detector: the custom token matches...
        assert middleware._apply(self._state_with_finish_reason("custom_filter"), _runtime()) is not None
        # ...and the default token no longer does.
        assert middleware._apply(self._state_with_finish_reason("content_filter"), _runtime()) is None

    def test_empty_detector_list_rejected(self):
        empty_cfg = SafetyFinishReasonConfig(detectors=[])
        with pytest.raises(ValueError, match="enabled=false"):
            SafetyFinishReasonMiddleware.from_config(empty_cfg)

    def test_non_detector_class_rejected(self):
        bogus_cfg = SafetyFinishReasonConfig(
            detectors=[SafetyDetectorConfig(use="builtins:dict")],
        )
        with pytest.raises(TypeError):
            SafetyFinishReasonMiddleware.from_config(bogus_cfg)
# ---------------------------------------------------------------------------
# Audit event and stream event
# ---------------------------------------------------------------------------
class TestAuditEvent:
    """Audit-event recording through ``RunJournal.record_middleware``.

    The tests place a mock journal under ``runtime.context["__run_journal"]``
    and check what the middleware hands to ``record_middleware``.

    Background: review on PR #3035 — the SSE custom event serves live
    consumers, while post-run audit needs a row in run_events that one SQL
    statement can query (no JOIN against the message body).
    """

    def _runtime_with_journal(self, journal):
        """Return a mock runtime whose context carries ``journal``."""
        return MagicMock(context={"thread_id": "t-audit", "__run_journal": journal})

    @staticmethod
    def _single_message_state(**message_fields):
        """State holding one AIMessage built from ``message_fields``."""
        return {"messages": [_ai(**message_fields)]}

    def test_records_audit_event_when_journal_present(self):
        journal = MagicMock()
        suppressed = _write_call(1)
        state = self._single_message_state(
            content="partial",
            tool_calls=[suppressed],
            response_metadata={"finish_reason": "content_filter"},
        )
        outcome = SafetyFinishReasonMiddleware()._apply(state, self._runtime_with_journal(journal))
        assert outcome is not None
        journal.record_middleware.assert_called_once()
        # Only keyword arguments are inspected; the checks below expect every
        # field to be passed by name.
        recorded = journal.record_middleware.call_args.kwargs
        expected_call = {
            "tag": "safety_termination",
            "name": "SafetyFinishReasonMiddleware",
            "hook": "after_model",
            "action": "suppress_tool_calls",
        }
        for key, value in expected_call.items():
            assert recorded[key] == value
        changes = recorded["changes"]
        expected_changes = {
            "detector": "openai_compatible_content_filter",
            "reason_field": "finish_reason",
            "reason_value": "content_filter",
            "suppressed_tool_call_count": 1,
            "suppressed_tool_call_names": ["write_file"],
            "suppressed_tool_call_ids": ["call_write_1"],
        }
        for key, value in expected_changes.items():
            assert changes[key] == value
        assert "message_id" in changes
        assert isinstance(changes["extras"], dict)

    def test_audit_event_never_carries_tool_arguments(self):
        """PR #3035 review IMPORTANT: tool args are the filtered content
        itself and must never be persisted to run_events."""
        journal = MagicMock()
        secret_call = {
            "id": "call_x",
            "name": "write_file",
            "args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
        }
        state = self._single_message_state(
            tool_calls=[secret_call],
            response_metadata={"finish_reason": "content_filter"},
        )
        SafetyFinishReasonMiddleware()._apply(state, self._runtime_with_journal(journal))
        recorded_call = journal.record_middleware.call_args
        rendered = repr(recorded_call)
        assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in rendered, "tool arguments must not leak into audit event"
        assert "args" not in recorded_call.kwargs["changes"]

    def test_no_journal_in_runtime_context_is_silently_skipped(self):
        """Subagent runtime / unit tests / no-event-store paths have no
        journal. The middleware must still clear tool_calls; only the audit
        event is skipped."""
        state = self._single_message_state(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        # The context has no "__run_journal" key. Must not raise.
        outcome = SafetyFinishReasonMiddleware()._apply(state, _runtime("t-noj"))
        assert outcome is not None
        assert outcome["messages"][0].tool_calls == []

    def test_journal_record_exception_does_not_break_run(self):
        """A failing journal must never propagate its exception into the agent loop."""
        journal = MagicMock()
        journal.record_middleware.side_effect = RuntimeError("db down")
        state = self._single_message_state(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        # Must not raise.
        outcome = SafetyFinishReasonMiddleware()._apply(state, self._runtime_with_journal(journal))
        assert outcome is not None
        assert outcome["messages"][0].tool_calls == []

    def test_no_record_when_passthrough(self):
        """No audit event is written when the middleware does not intervene."""
        journal = MagicMock()
        state = self._single_message_state(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "tool_calls"},  # healthy
        )
        outcome = SafetyFinishReasonMiddleware()._apply(state, self._runtime_with_journal(journal))
        assert outcome is None
        journal.record_middleware.assert_not_called()
class TestStreamEvent:
    """The ``safety_termination`` payload sent to the LangGraph stream writer."""

    @staticmethod
    def _flagged_state():
        """State with one tool-calling AIMessage marked ``content_filter``."""
        flagged = _ai(
            tool_calls=[_write_call()],
            response_metadata={"finish_reason": "content_filter"},
        )
        return {"messages": [flagged]}

    def test_emits_event_when_writer_available(self, monkeypatch):
        import langgraph.config

        payloads = []
        # Patch get_stream_writer where it is looked up; the stand-in writer
        # just collects whatever it is called with.
        monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: payloads.append)
        SafetyFinishReasonMiddleware()._apply(self._flagged_state(), _runtime("t-stream"))
        assert len(payloads) == 1
        event = payloads[0]
        expected = {
            "type": "safety_termination",
            "detector": "openai_compatible_content_filter",
            "reason_field": "finish_reason",
            "reason_value": "content_filter",
            "suppressed_tool_call_count": 1,
            "suppressed_tool_call_names": ["write_file"],
            "thread_id": "t-stream",
        }
        for key, value in expected.items():
            assert event[key] == value

    def test_writer_unavailable_does_not_break(self, monkeypatch):
        import langgraph.config

        def no_stream_context():
            raise LookupError("not in a stream context")

        monkeypatch.setattr(langgraph.config, "get_stream_writer", no_stream_context)
        # Should not raise.
        outcome = SafetyFinishReasonMiddleware()._apply(self._flagged_state(), _runtime())
        assert outcome is not None