Files
deer-flow/backend/tests/test_safety_termination_detectors.py
T
Xinmin Zeng be0eae9825 fix(runtime): suppress tool execution when provider safety-terminates with tool_calls (#3035)
* fix(runtime): suppress tool execution when provider safety-terminates with tool_calls

When a provider stops generation for safety reasons (OpenAI/Moonshot
finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini
finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/
IMAGE_SAFETY/...), the response may still carry truncated tool_calls.
LangChain's tool router treats any non-empty tool_calls as executable,
so partial arguments (e.g. write_file with a half-finished markdown)
get dispatched and the agent loops on retry.

Add SafetyFinishReasonMiddleware at after_model: detect safety
termination via a pluggable detector registry, clear both structured
tool_calls and raw additional_kwargs.tool_calls / function_call,
preserve response_metadata.finish_reason for downstream observers,
stamp additional_kwargs.safety_termination for traces, append a
user-facing explanation to message content (list-aware for thinking
blocks), and emit a safety_termination custom stream event so SSE
consumers can reconcile any "tool starting..." UI.

Default detectors cover OpenAI-compatible content_filter, Anthropic
refusal, and Gemini safety enums (text + image). Custom providers are
added via reflection (same pattern as guardrails). Wired into both
lead-agent and subagent runtimes.

Closes #3028

* fix(runtime): persist safety_termination as a middleware audit event

Address review on #3035: the SSE custom event is great for live
consumers but invisible to post-run audit. RunEventStore should carry
its own row so operators can answer "which runs were safety-suppressed
today?" from a single SQL query without joining the message body.

Worker now exposes the run-scoped RunJournal via
runtime.context["__run_journal"] (sentinel key, internal channel).
SafetyFinishReasonMiddleware calls the previously-unused
RunJournal.record_middleware, which emits

  event_type = "middleware:safety_termination"
  category   = "middleware"
  content    = {name, hook, action, changes={
                  detector, reason_field, reason_value,
                  suppressed_tool_call_count,
                  suppressed_tool_call_names,
                  suppressed_tool_call_ids,
                  message_id, extras}}

Tool *arguments* are deliberately excluded — those are the very content
the provider filtered and persisting them would defeat the purpose of
the safety filter (per review note in #3035).

Graceful skips when journal is absent (subagent runtime, unit tests,
no-event-store local dev). Journal exceptions never propagate into the
agent loop.

Refs #3028

* fix(runtime): satisfy ruff format + address Copilot review

- ruff format on safety_finish_reason_config.py and e2e demo (CI lint
  failed on ruff format --check; backend Makefile lint target runs
  ruff check AND ruff format --check).
- Docstring on SafetyFinishReasonConfig now says resolve_variable to
  match the actual loader used in from_config (the wording was
  resolve_class previously; behavior is unchanged — resolve_variable
  mirrors how guardrails.provider is loaded).
- Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply
  from getattr(last, "type") == "ai" to isinstance(last, AIMessage),
  matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware
  / SummarizationMiddleware which are the dominant pattern.

Refs #3028
2026-05-22 21:20:28 +08:00

177 lines
7.1 KiB
Python

"""Unit tests for SafetyTerminationDetector built-ins."""
from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.safety_termination_detectors import (
AnthropicRefusalDetector,
GeminiSafetyDetector,
OpenAICompatibleContentFilterDetector,
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
    """Build an ``AIMessage`` fixture for the detector tests.

    All arguments are keyword-only. Falsy ``tool_calls`` is replaced by an
    empty list, and falsy ``response_metadata`` / ``additional_kwargs`` are
    replaced by empty dicts before being handed to ``AIMessage``.
    """
    message_fields = {
        "content": content,
        "tool_calls": tool_calls if tool_calls else [],
        "response_metadata": response_metadata if response_metadata else {},
        "additional_kwargs": additional_kwargs if additional_kwargs else {},
    }
    return AIMessage(**message_fields)
class TestOpenAICompatibleContentFilterDetector:
    """Tests for the OpenAI-style ``finish_reason`` content-filter detector."""

    def test_default_matches_content_filter(self):
        """A ``content_filter`` finish reason produces a fully described hit."""
        detector = OpenAICompatibleContentFilterDetector()
        message = _ai(response_metadata={"finish_reason": "content_filter"})
        result = detector.detect(message)
        assert result is not None
        assert result.detector == "openai_compatible_content_filter"
        assert result.reason_field == "finish_reason"
        assert result.reason_value == "content_filter"

    def test_case_insensitive_match(self):
        """An upper-cased token is still detected."""
        detector = OpenAICompatibleContentFilterDetector()
        message = _ai(response_metadata={"finish_reason": "CONTENT_FILTER"})
        assert detector.detect(message) is not None

    def test_other_finish_reasons_pass_through(self):
        """Ordinary finish reasons are not reported as safety terminations."""
        detector = OpenAICompatibleContentFilterDetector()
        for finish_reason in ("stop", "tool_calls", "length"):
            assert detector.detect(_ai(response_metadata={"finish_reason": finish_reason})) is None

    def test_missing_metadata_passes_through(self):
        """A message with no metadata at all yields no hit."""
        detector = OpenAICompatibleContentFilterDetector()
        assert detector.detect(_ai()) is None

    def test_non_string_finish_reason_passes_through(self):
        """Non-string finish reasons are ignored rather than raising."""
        # Adapters may store an enum or a dict here instead of a string.
        detector = OpenAICompatibleContentFilterDetector()
        for odd_value in (42, {"value": "content_filter"}):
            assert detector.detect(_ai(response_metadata={"finish_reason": odd_value})) is None

    def test_falls_back_to_additional_kwargs(self):
        """``finish_reason`` found only in ``additional_kwargs`` is honoured."""
        # Legacy adapters expose finish_reason through additional_kwargs.
        detector = OpenAICompatibleContentFilterDetector()
        message = _ai(additional_kwargs={"finish_reason": "content_filter"})
        result = detector.detect(message)
        assert result is not None

    def test_configurable_extra_values(self):
        """Extra tokens passed via ``finish_reasons`` are matched as well."""
        # Some providers use bespoke tokens instead of "content_filter".
        detector = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
        # The two custom tokens come first, then the original token, which
        # must keep matching.
        for finish_reason in ("sensitive", "violation", "content_filter"):
            assert detector.detect(_ai(response_metadata={"finish_reason": finish_reason})) is not None

    def test_carries_azure_content_filter_results(self):
        """``content_filter_results`` metadata is copied into the hit's extras."""
        detector = OpenAICompatibleContentFilterDetector()
        filter_results = {"hate": {"filtered": True, "severity": "high"}}
        metadata = {
            "finish_reason": "content_filter",
            "content_filter_results": filter_results,
        }
        result = detector.detect(_ai(response_metadata=metadata))
        assert result is not None
        assert result.extras["content_filter_results"] == filter_results
class TestAnthropicRefusalDetector:
    """Tests for the Anthropic ``stop_reason`` refusal detector."""

    def test_default_matches_refusal(self):
        """A ``refusal`` stop reason is reported with its field and value."""
        message = _ai(response_metadata={"stop_reason": "refusal"})
        result = AnthropicRefusalDetector().detect(message)
        assert result is not None
        assert result.reason_field == "stop_reason"
        assert result.reason_value == "refusal"

    def test_other_stop_reasons_pass_through(self):
        """Non-refusal stop reasons produce no hit."""
        detector = AnthropicRefusalDetector()
        for stop_reason in ("end_turn", "tool_use", "max_tokens"):
            assert detector.detect(_ai(response_metadata={"stop_reason": stop_reason})) is None

    def test_anthropic_does_not_steal_finish_reason(self):
        """An OpenAI-shaped ``finish_reason`` must not trip this detector."""
        openai_style_message = _ai(response_metadata={"finish_reason": "content_filter"})
        assert AnthropicRefusalDetector().detect(openai_style_message) is None
class TestGeminiSafetyDetector:
    """Tests for the Gemini ``finish_reason`` safety detector."""

    def test_default_set_covers_documented_reasons(self):
        """Every text and image safety reason in the default set is detected."""
        text_safety_reasons = (
            "SAFETY",
            "BLOCKLIST",
            "PROHIBITED_CONTENT",
            "SPII",
            "RECITATION",
        )
        image_safety_reasons = (
            "IMAGE_SAFETY",
            "IMAGE_PROHIBITED_CONTENT",
            "IMAGE_RECITATION",
        )
        detector = GeminiSafetyDetector()
        for reason in text_safety_reasons + image_safety_reasons:
            assert detector.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason

    def test_normal_termination_passes_through(self):
        """Reasons outside the default safety set produce no hit."""
        detector = GeminiSafetyDetector()
        assert detector.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
        # These reasons are deliberately left out of the default set: they
        # are normal termination, capability mismatches, too broad (OTHER),
        # or tool-call protocol errors. See the GeminiSafetyDetector
        # docstring for the rationale.
        excluded_reasons = (
            "MAX_TOKENS",
            "LANGUAGE",
            "NO_IMAGE",
            "OTHER",
            "IMAGE_OTHER",
            "MALFORMED_FUNCTION_CALL",
            "UNEXPECTED_TOOL_CALL",
            "FINISH_REASON_UNSPECIFIED",
        )
        for reason in excluded_reasons:
            assert detector.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason

    def test_carries_safety_ratings(self):
        """``safety_ratings`` metadata is copied into the hit's extras."""
        ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
        metadata = {
            "finish_reason": "SAFETY",
            "safety_ratings": ratings,
        }
        result = GeminiSafetyDetector().detect(_ai(response_metadata=metadata))
        assert result is not None
        assert result.extras["safety_ratings"] == ratings
class TestDefaultDetectorSet:
    """Tests for the list returned by ``default_detectors()``."""

    def test_default_set_returns_three_detectors(self):
        """The default set holds exactly three detectors with the expected names."""
        dets = default_detectors()
        # Check the count explicitly: the set comparison below collapses
        # duplicates, so on its own it would still pass if one detector
        # appeared in the list twice.
        assert len(dets) == 3
        names = {d.name for d in dets}
        assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}

    def test_default_set_returns_fresh_list(self):
        """Clearing one returned list leaves the next call's result intact."""
        # Caller mutation must not affect later calls.
        first = default_detectors()
        first.clear()
        second = default_detectors()
        assert len(second) == 3
class TestProtocolConformance:
    """Structural checks on the built-in detectors and the result type."""

    def test_builtins_satisfy_protocol(self):
        """Each default detector passes an ``isinstance`` check against the protocol."""
        assert all(isinstance(detector, SafetyTerminationDetector) for detector in default_detectors())

    def test_safety_termination_is_frozen(self):
        """Assigning to a ``SafetyTermination`` attribute must raise."""
        record = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
        try:
            record.detector = "y"  # type: ignore[misc]
        except Exception:
            # Any exception type counts as the mutation being rejected.
            return
        else:
            raise AssertionError("SafetyTermination should be frozen")