fix(runtime): suppress tool execution when provider safety-terminates with tool_calls (#3035)

* fix(runtime): suppress tool execution when provider safety-terminates with tool_calls

When a provider stops generation for safety reasons (OpenAI/Moonshot
finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini
finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/
IMAGE_SAFETY/...), the response may still carry truncated tool_calls.
LangChain's tool router treats any non-empty tool_calls as executable,
so partial arguments (e.g. write_file with a half-finished markdown)
get dispatched and the agent loops on retry.

Add SafetyFinishReasonMiddleware at after_model: detect safety
termination via a pluggable detector registry, clear both structured
tool_calls and raw additional_kwargs.tool_calls / function_call,
preserve response_metadata.finish_reason for downstream observers,
stamp additional_kwargs.safety_termination for traces, append a
user-facing explanation to message content (list-aware for thinking
blocks), and emit a safety_termination custom stream event so SSE
consumers can reconcile any "tool starting..." UI.

Default detectors cover OpenAI-compatible content_filter, Anthropic
refusal, and Gemini safety enums (text + image). Custom providers are
added via reflection (same pattern as guardrails). Wired into both
lead-agent and subagent runtimes.

Closes #3028

* fix(runtime): persist safety_termination as a middleware audit event

Address review on #3035: the SSE custom event is great for live
consumers but invisible to post-run audit. RunEventStore should carry
its own row so operators can answer "which runs were safety-suppressed
today?" from a single SQL query without joining the message body.

Worker now exposes the run-scoped RunJournal via
runtime.context["__run_journal"] (sentinel key, internal channel).
SafetyFinishReasonMiddleware calls the previously-unused
RunJournal.record_middleware, which emits

  event_type = "middleware:safety_termination"
  category   = "middleware"
  content    = {name, hook, action, changes={
                  detector, reason_field, reason_value,
                  suppressed_tool_call_count,
                  suppressed_tool_call_names,
                  suppressed_tool_call_ids,
                  message_id, extras}}

Tool *arguments* are deliberately excluded — those are the very content
the provider filtered and persisting them would defeat the purpose of
the safety filter (per review note in #3035).

Graceful skips when journal is absent (subagent runtime, unit tests,
no-event-store local dev). Journal exceptions never propagate into the
agent loop.

Refs #3028

* fix(runtime): satisfy ruff format + address Copilot review

- ruff format on safety_finish_reason_config.py and e2e demo (CI lint
  failed on ruff format --check; backend Makefile lint target runs
  ruff check AND ruff format --check).
- Docstring on SafetyFinishReasonConfig now says resolve_variable to
  match the actual loader used in from_config (the wording was
  resolve_class previously; behavior is unchanged — resolve_variable
  mirrors how guardrails.provider is loaded).
- Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply
  from getattr(last, "type") == "ai" to isinstance(last, AIMessage),
  matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware
  / SummarizationMiddleware which are the dominant pattern.

Refs #3028
This commit is contained in:
Xinmin Zeng
2026-05-22 21:20:28 +08:00
committed by GitHub
parent 253542ea0d
commit be0eae9825
14 changed files with 1936 additions and 5 deletions
@@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
@@ -338,6 +339,15 @@ def _build_middlewares(
if custom_middlewares:
middlewares.extend(custom_middlewares)
# SafetyFinishReasonMiddleware — suppress tool execution when the provider
# safety-terminated the response. Registered after custom middlewares so
# that LangChain's reverse-order after_model dispatch runs Safety first;
# cleared tool_calls then flow through Loop/Subagent accounting without
# firing extra alarms. See safety_finish_reason_middleware.py docstring.
safety_config = resolved_app_config.safety_finish_reason
if safety_config.enabled:
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
# ClarificationMiddleware should always be last
middlewares.append(ClarificationMiddleware())
return middlewares
@@ -0,0 +1,317 @@
"""Suppress tool execution when the provider safety-terminated the response.
Background — see issue bytedance/deer-flow#3028.
Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic
``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop
generation mid-stream while still returning partially-formed ``tool_calls``.
LangChain's tool router treats any AIMessage with a non-empty ``tool_calls``
field as "go execute these", so half-truncated arguments — e.g. a markdown
``write_file`` that stops in the middle of a sentence — get dispatched as if
they were complete. The agent then sees the truncated file, tries to fix it,
gets filtered again, and loops.
This middleware sits at ``after_model`` and gates that behaviour: when a
configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries
tool calls, we strip the tool calls (both structured and raw provider
payloads), append a user-facing explanation, and stash observability fields
in ``additional_kwargs.safety_termination`` so logs, traces, and SSE
consumers can see what happened.
Hook choice: ``after_model`` (not ``wrap_model_call``) because the response
is a *normal* return — not an exception — and we want to participate in the
same after-model chain as ``LoopDetectionMiddleware``, with which we share
the same tool-call-suppression mechanic but a different trigger.
Placement: register *after* ``LoopDetectionMiddleware`` in the middleware
list. LangChain factory wires ``after_model`` edges in reverse list order
(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``,
then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is
the *first* to observe the model output. Registering Safety after Loop
means Safety sees the raw response first, clears tool calls if it fires,
and Loop then accounts against the cleaned message.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
if TYPE_CHECKING:
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
logger = logging.getLogger(__name__)
_USER_FACING_MESSAGE = (
"The model provider stopped this response with a safety-related signal "
"({reason_field}={reason_value!r}, detector={detector!r}). Any tool "
"calls produced in this turn were suppressed because their arguments "
"may be truncated and unsafe to execute. Please rephrase the request "
"or ask for a narrower output."
)
class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]):
"""Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector."""
def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None:
super().__init__()
# Copy so caller mutations after construction don't leak into us.
self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors()
@classmethod
def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware:
"""Construct from validated Pydantic config, honouring the
reflection-loaded detector list when provided.
An explicit empty list is intentionally rejected — it would silently
disable detection while leaving the middleware in the chain, which
is the worst of both worlds. Use ``enabled: false`` instead.
"""
if config.detectors is None:
return cls()
if not config.detectors:
raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.")
from deerflow.reflection import resolve_variable
detectors: list[SafetyTerminationDetector] = []
for entry in config.detectors:
detector_cls = resolve_variable(entry.use)
kwargs = dict(entry.config) if entry.config else {}
detector = detector_cls(**kwargs)
if not isinstance(detector, SafetyTerminationDetector):
raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method")
detectors.append(detector)
return cls(detectors=detectors)
# ----- detection -------------------------------------------------------
def _detect(self, message: AIMessage) -> SafetyTermination | None:
for detector in self._detectors:
try:
hit = detector.detect(message)
except Exception: # noqa: BLE001 - never let a buggy detector break the agent run
logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__))
continue
if hit is not None:
return hit
return None
# ----- message rewriting ----------------------------------------------
@staticmethod
def _append_user_message(content: object, text: str) -> str | list:
"""Append a plain-text explanation to AIMessage content.
Mirrors ``LoopDetectionMiddleware._append_text`` so list-content
responses (Anthropic thinking blocks, vLLM reasoning splits) keep
their structure instead of being string-coerced into a TypeError.
"""
if content is None or content == "":
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
return str(content) + f"\n\n{text}"
def _build_suppressed_message(
self,
message: AIMessage,
termination: SafetyTermination,
) -> AIMessage:
suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])]
explanation = _USER_FACING_MESSAGE.format(
reason_field=termination.reason_field,
reason_value=termination.reason_value,
detector=termination.detector,
)
new_content = self._append_user_message(message.content, explanation)
# clone_ai_message_with_tool_calls handles structured tool_calls,
# raw additional_kwargs.tool_calls, and function_call in one shot.
# It only rewrites finish_reason when the old value was "tool_calls",
# which is not our case — content_filter / refusal / SAFETY stay put
# so downstream SSE / converters keep seeing the real provider reason.
cleared = clone_ai_message_with_tool_calls(message, [], content=new_content)
# Re-clone additional_kwargs so we don't accidentally mutate the
# dict returned by clone_ai_message_with_tool_calls (which already
# made a shallow copy, but downstream model_copy still references
# it). Then stamp the observability record.
kwargs = dict(getattr(cleared, "additional_kwargs", None) or {})
kwargs["safety_termination"] = {
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(suppressed_names),
"suppressed_tool_call_names": suppressed_names,
"extras": dict(termination.extras) if termination.extras else {},
}
return cleared.model_copy(update={"additional_kwargs": kwargs})
# ----- observability ---------------------------------------------------
def _emit_event(
self,
termination: SafetyTermination,
suppressed_names: list[str],
runtime: Runtime,
) -> None:
"""Notify SSE consumers (e.g. the web UI) that a tool turn was
suppressed so they can reconcile any "tool starting..." placeholders
already streamed to the user. Failures are logged at debug and
ignored — this is a best-effort signal."""
try:
from langgraph.config import get_stream_writer
writer = get_stream_writer()
except Exception: # noqa: BLE001
logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True)
return
thread_id = None
if runtime is not None and getattr(runtime, "context", None):
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
try:
writer(
{
"type": "safety_termination",
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(suppressed_names),
"suppressed_tool_call_names": suppressed_names,
"thread_id": thread_id,
}
)
except Exception: # noqa: BLE001
logger.debug("Failed to emit safety_termination stream event", exc_info=True)
def _record_audit_event(
self,
termination: SafetyTermination,
message,
tool_calls: list[dict],
runtime: Runtime,
) -> None:
"""Write a ``middleware:safety_termination`` record to RunEventStore
for post-run auditability.
The custom stream event in ``_emit_event`` is consumed by live SSE
clients and disappears after the run; this event is persisted so an
operator can answer "which runs were safety-suppressed today?" from
a single SQL query without joining the message body. Worker exposes
the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``;
absent in unit-test / subagent / no-event-store paths, in which case
we silently skip.
Tool **arguments** are deliberately **not** recorded — those are the
very content the provider filtered; persisting them would defeat the
purpose of the safety filter. Names / count / ids are sufficient for
audit and debugging (issue #3028 review).
"""
journal = None
if runtime is not None and getattr(runtime, "context", None):
context = runtime.context
if isinstance(context, dict):
journal = context.get("__run_journal")
if journal is None:
return
suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls]
suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
changes = {
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_count": len(tool_calls),
"suppressed_tool_call_names": suppressed_names,
"suppressed_tool_call_ids": suppressed_ids,
"message_id": getattr(message, "id", None),
"extras": dict(termination.extras) if termination.extras else {},
}
try:
journal.record_middleware(
tag="safety_termination",
name=type(self).__name__,
hook="after_model",
action="suppress_tool_calls",
changes=changes,
)
except Exception: # noqa: BLE001
# Audit-event persistence must never break agent execution.
logger.debug("Failed to record middleware:safety_termination event", exc_info=True)
# ----- main apply ------------------------------------------------------
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
# Issue scope: only intervene when there's something to suppress.
# ``content_filter`` without tool_calls is allowed through unchanged
# so the partial text response (if any) reaches the user naturally.
tool_calls = last.tool_calls
if not tool_calls:
return None
termination = self._detect(last)
if termination is None:
return None
patched = self._build_suppressed_message(last, termination)
thread_id = None
if runtime is not None and getattr(runtime, "context", None):
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
logger.warning(
"Provider safety termination detected — suppressed %d tool call(s)",
len(tool_calls),
extra={
"thread_id": thread_id,
"detector": termination.detector,
"reason_field": termination.reason_field,
"reason_value": termination.reason_value,
"suppressed_tool_call_names": [tc.get("name") for tc in tool_calls],
},
)
self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime)
self._record_audit_event(termination, last, list(tool_calls), runtime)
return {"messages": [patched]}
# ----- hooks -----------------------------------------------------------
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@@ -0,0 +1,237 @@
"""Detectors for provider-side safety termination signals.
Different LLM providers signal "I stopped this response for safety reasons"
through different fields with different values. This module defines a small
strategy interface and three built-in detectors that cover the major
providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock
adapters, in-house gateways, ...) can be added by implementing
``SafetyTerminationDetector`` and wiring it through
``config.yaml: safety_finish_reason.detectors``.
The middleware that consumes these detectors lives in
``safety_finish_reason_middleware.py``.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable
from langchain_core.messages import AIMessage
@dataclass(frozen=True)
class SafetyTermination:
"""A detected safety-related termination signal.
Attributes:
detector: Name of the detector that produced this result. Used for
observability so operators can see which provider rule fired.
reason_field: The message metadata field that carried the signal
(e.g. ``finish_reason``, ``stop_reason``).
reason_value: The actual value of that field
(e.g. ``content_filter``, ``refusal``, ``SAFETY``).
extras: Provider-specific metadata that may help downstream
consumers (e.g. Azure OpenAI content_filter_results, Gemini
safety_ratings). Detectors are free to populate or skip this.
"""
detector: str
reason_field: str
reason_value: str
extras: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class SafetyTerminationDetector(Protocol):
"""Strategy interface for provider safety termination detection."""
name: str
def detect(self, message: AIMessage) -> SafetyTermination | None:
"""Return a SafetyTermination if *message* indicates provider safety
termination, otherwise return ``None``.
Implementations must be side-effect free and tolerant of missing or
oddly-typed metadata — detectors run on every model response.
"""
...
def _get_metadata_value(message: AIMessage, field_name: str) -> str | None:
"""Read a string-typed value from either ``response_metadata`` or
``additional_kwargs``.
LangChain provider adapters are inconsistent about where they stash
provider stop signals. Most modern adapters use ``response_metadata``,
but some legacy / passthrough paths still surface them via
``additional_kwargs``. We check both, in that order, and only accept
string values — Pydantic enums or dicts are ignored so we never raise
on malformed inputs.
"""
for container_name in ("response_metadata", "additional_kwargs"):
container = getattr(message, container_name, None) or {}
if not isinstance(container, dict):
continue
value = container.get(field_name)
if isinstance(value, str) and value:
return value
return None
class OpenAICompatibleContentFilterDetector:
"""OpenAI-compatible content_filter signal.
Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM,
Qwen (OpenAI-compatible mode), and any other adapter that follows the
OpenAI ``finish_reason`` convention.
Some Chinese providers ship custom OpenAI-compatible gateways that use
alternative tokens like ``sensitive`` or ``violation``. Extend the set
via the ``finish_reasons`` kwarg in config.
"""
name = "openai_compatible_content_filter"
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = finish_reasons if finish_reasons is not None else ("content_filter",)
self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "finish_reason")
if value is None or value.lower() not in self._finish_reasons:
return None
extras: dict[str, Any] = {}
# Azure OpenAI ships a structured content_filter_results block; carry it
# through so operators can see *what* was filtered without re-tracing.
response_metadata = getattr(message, "response_metadata", None) or {}
if isinstance(response_metadata, dict):
filter_results = response_metadata.get("content_filter_results")
if filter_results:
extras["content_filter_results"] = filter_results
return SafetyTermination(
detector=self.name,
reason_field="finish_reason",
reason_value=value,
extras=extras,
)
class AnthropicRefusalDetector:
"""Anthropic ``stop_reason == "refusal"`` signal.
Anthropic models surface safety refusals via a dedicated ``stop_reason``
rather than ``finish_reason``. See:
https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals
"""
name = "anthropic_refusal"
def __init__(self, stop_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = stop_reasons if stop_reasons is not None else ("refusal",)
self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "stop_reason")
if value is None or value.lower() not in self._stop_reasons:
return None
return SafetyTermination(
detector=self.name,
reason_field="stop_reason",
reason_value=value,
)
class GeminiSafetyDetector:
"""Gemini / Vertex AI safety-related finish reasons.
Gemini uses the same ``finish_reason`` field as OpenAI but with an
enumerated upper-case taxonomy. The default set covers every Gemini
finish_reason that means "the model stopped because the content/image
tripped a safety, blocklist, recitation, or PII filter" — i.e. cases
where any tool_calls returned alongside are likely truncated/
unreliable. Full enum:
https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason
Intentionally **excluded** from the default set:
- ``STOP`` — normal termination.
- ``MAX_TOKENS`` — output length truncation, not safety
(same root failure mode as
content_filter, but issue #3028
scopes it out; expose separately if
desired).
- ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to
safety; tool_calls would be absent
anyway.
- ``MALFORMED_FUNCTION_CALL`` /
``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The
tool_calls are *also* unreliable
here, but the failure category is
distinct from safety filtering;
handle in a dedicated detector to
keep observability records honest.
- ``OTHER`` / ``IMAGE_OTHER`` /
``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default;
opt in via ``finish_reasons=`` if
your provider abuses these.
"""
name = "gemini_safety"
_DEFAULT_FINISH_REASONS = (
# Text safety
"SAFETY",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"RECITATION",
# Image safety (multimodal generation)
"IMAGE_SAFETY",
"IMAGE_PROHIBITED_CONTENT",
"IMAGE_RECITATION",
)
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS
self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured)
def detect(self, message: AIMessage) -> SafetyTermination | None:
value = _get_metadata_value(message, "finish_reason")
if value is None or value.upper() not in self._finish_reasons:
return None
extras: dict[str, Any] = {}
response_metadata = getattr(message, "response_metadata", None) or {}
if isinstance(response_metadata, dict):
# Gemini surfaces per-category scoring under safety_ratings.
ratings = response_metadata.get("safety_ratings")
if ratings:
extras["safety_ratings"] = ratings
return SafetyTermination(
detector=self.name,
reason_field="finish_reason",
reason_value=value,
extras=extras,
)
def default_detectors() -> list[SafetyTerminationDetector]:
"""Built-in detector set used when no custom detectors are configured."""
return [
OpenAICompatibleContentFilterDetector(),
AnthropicRefusalDetector(),
GeminiSafetyDetector(),
]
__all__ = [
"AnthropicRefusalDetector",
"GeminiSafetyDetector",
"OpenAICompatibleContentFilterDetector",
"SafetyTermination",
"SafetyTerminationDetector",
"default_detectors",
]
@@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares(
middlewares.append(ViewImageMiddleware())
# Same provider safety-termination guard the lead agent uses — subagents
# are equally exposed to truncated tool_calls returned with
# finish_reason=content_filter (and friends), and the bad call would then
# propagate back to the lead agent via the task tool result.
safety_config = app_config.safety_finish_reason
if safety_config.enabled:
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
return middlewares
@@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.runtime_paths import existing_project_file
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
@@ -102,6 +103,7 @@ class AppConfig(BaseModel):
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -0,0 +1,47 @@
"""Configuration for SafetyFinishReasonMiddleware.
Mirrors the shape of GuardrailsConfig: detectors are loaded by class path
through ``deerflow.reflection.resolve_variable`` (same loader the
``guardrails.provider`` config uses) so users can drop in custom provider
detectors without modifying core code.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
class SafetyDetectorConfig(BaseModel):
"""One detector entry under ``safety_finish_reason.detectors``."""
use: str = Field(
description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."),
)
config: dict = Field(
default_factory=dict,
description="Constructor kwargs passed to the detector class.",
)
class SafetyFinishReasonConfig(BaseModel):
"""Configuration for the SafetyFinishReasonMiddleware.
The middleware intercepts AIMessages where the provider signaled a
safety-related termination (e.g. OpenAI ``finish_reason='content_filter'``)
while still returning tool calls, and suppresses those tool calls so the
half-truncated arguments never execute.
"""
enabled: bool = Field(
default=True,
description="Master switch for the SafetyFinishReasonMiddleware.",
)
detectors: list[SafetyDetectorConfig] | None = Field(
default=None,
description=(
"Custom detector list. Leave unset (None) to use the built-in "
"set covering OpenAI-compatible content_filter, Anthropic "
"refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/"
"RECITATION. Provide a non-null list to fully override."
),
)
@@ -219,6 +219,12 @@ async def run_agent(
# manually here because we drive the graph through ``agent.astream(config=...)``
# without passing the official ``context=`` parameter.
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
# Expose the run-scoped journal under a sentinel key so middleware can
# write audit events (e.g. SafetyFinishReasonMiddleware recording
# suppressed tool calls). Double-underscore prefix marks it as a
# runtime-internal channel; user code must not depend on the key name.
if journal is not None:
runtime_ctx["__run_journal"] = journal
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
@@ -0,0 +1,206 @@
"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent.
What it proves
--------------
- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full
18-middleware chain, sandbox, tools, etc.).
- A model that returns ``finish_reason='content_filter'`` + ``tool_calls``
triggers SafetyFinishReasonMiddleware.
- LangChain's tool router never invokes ``write_file`` — the truncated
arguments do **not** reach the sandbox.
- A ``safety_termination`` custom event is emitted on the stream and the
final AIMessage carries the observability stamp.
Run from backend/ directory:
PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py
"""
from __future__ import annotations
import sys
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
# ---------------------------------------------------------------------------
# Fake provider that mimics Moonshot's content_filter behaviour
# ---------------------------------------------------------------------------
class _ContentFilteredFakeModel(BaseChatModel):
"""First call returns finish_reason=content_filter + truncated write_file
tool_call. Subsequent calls return a normal stop response so the agent
can terminate (the middleware should make a second call unnecessary by
clearing tool_calls, but we keep this safety net in case loop-detection
or anything else triggers another model invocation)."""
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-content-filtered"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
msg = AIMessage(
content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
tool_calls=[
{
"id": "call_truncated_write",
"name": "write_file",
"args": {
"path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
"content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
},
}
],
response_metadata={
"finish_reason": "content_filter",
"model_name": "kimi-k2.6",
"model_provider": "openai",
},
)
else:
msg = AIMessage(
content="(secondary call, should not be needed)",
response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"},
)
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------
def main() -> int:
# Inject the fake model BEFORE constructing the client. Both the
# client module and the lead-agent module bind ``create_chat_model``
# at import time via ``from deerflow.models import create_chat_model``,
# so we patch both attribute slots — the source-of-truth patch on
# ``factory.create_chat_model`` doesn't propagate back into already-
# imported names.
import deerflow.agents.lead_agent.agent as lead_agent_module
import deerflow.client as client_module
fake = _ContentFilteredFakeModel()
originals = {
"lead": lead_agent_module.create_chat_model,
"client": client_module.create_chat_model,
}
def fake_create_chat_model(*args, **kwargs):
return fake
lead_agent_module.create_chat_model = fake_create_chat_model
client_module.create_chat_model = fake_create_chat_model
from deerflow.client import DeerFlowClient
try:
client = DeerFlowClient()
print("\n=== Streaming a turn through the real lead-agent ===")
events: list[dict[str, Any]] = []
for event in client.stream(
"帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
thread_id="e2e-safety-1",
):
events.append({"type": event.type, "data": event.data})
# ---- Assertions ----
safety_event = next(
(e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"),
None,
)
final_values = next(
(e for e in reversed(events) if e["type"] == "values"),
None,
)
tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"]
ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")]
print(f"\n[stats] total stream events: {len(events)}")
print(f"[stats] model call count: {fake.call_count}")
print(f"[stats] tool messages on stream: {len(tool_messages)}")
print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}")
print("\n[event] safety_termination custom event:")
if safety_event is None:
print(" *** NOT FOUND ***")
return 1
for k, v in safety_event["data"].items():
print(f" {k}: {v}")
print("\n[state] final AIMessage from last values snapshot:")
if final_values is None:
print(" *** no values snapshot ***")
return 1
# `values` event carries `_serialize_message` dicts, not Message objects.
final_messages = final_values["data"].get("messages") or []
last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None)
if last_ai is None:
print(" *** no AIMessage in final state ***")
print(f" message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}")
return 1
tool_calls = last_ai.get("tool_calls") or []
additional_kwargs = last_ai.get("additional_kwargs") or {}
response_metadata = last_ai.get("response_metadata") or {}
content = last_ai.get("content")
print(f" tool_calls (must be empty): {tool_calls}")
print(f" additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}")
content_preview = (content if isinstance(content, str) else str(content))[:200]
print(f" content[:200]: {content_preview!r}")
print(f" response_metadata.finish_reason: {response_metadata.get('finish_reason')}")
# NOTE: `client._serialize_message` does not include `response_metadata`
# in the values-event payload (client-layer behaviour, unrelated to the
# middleware). The middleware *does* preserve finish_reason on the
# AIMessage object — see test_safety_finish_reason_middleware.py::
# TestMessageRewrite::test_preserves_response_metadata_finish_reason.
# Here we assert on the observability stamp, which carries the same
# evidence and is in the serialized payload.
stamp = additional_kwargs.get("safety_termination") or {}
failures = []
if tool_calls:
failures.append("final AIMessage still has tool_calls — middleware did NOT clear them")
if not stamp:
failures.append("final AIMessage missing safety_termination observability stamp")
if tool_messages:
failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream")
if stamp.get("reason_value") != "content_filter":
failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'")
if safety_event is None:
failures.append("safety_termination custom event was not emitted on the stream")
if failures:
print("\n=== FAIL ===")
for f in failures:
print(f" - {f}")
return 1
print("\n=== PASS ===")
print(" - tool_calls cleared on final AIMessage")
print(" - tool node never invoked (no ToolMessage on stream)")
print(" - safety_termination custom event emitted")
print(" - observability stamp written to additional_kwargs")
print(" - response_metadata.finish_reason preserved for downstream SSE")
return 0
finally:
lead_agent_module.create_chat_model = originals["lead"]
client_module.create_chat_model = originals["client"]
if __name__ == "__main__":
sys.exit(main())
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
)
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
# verify the custom middleware is injected correctly
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
# verify the custom middleware is injected correctly.
# Chain tail order after the custom middleware is:
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
# so the custom mock sits at index [-3].
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
@@ -0,0 +1,225 @@
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
Unit tests prove ``_apply`` does the right thing on a synthetic state.
This test does one level up: builds a real ``langchain.agents.create_agent``
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
1. The tool node is **not** invoked (the dangerous truncated tool call
is suppressed).
2. The final AIMessage in graph state has ``tool_calls == []``.
3. The observability ``safety_termination`` record is attached.
4. The user-facing explanation is appended to the message content.
This is the closest we can get to the issue's failure mode without a live
Moonshot key, and it proves the middleware actually gates LangChain's
tool router — not just rewrites state in isolation.
"""
from __future__ import annotations
from typing import Any
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import tool
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
@tool
def write_file(path: str, content: str) -> str:
"""Pretend to write *content* to *path*. Records the call for assertion."""
_TOOL_INVOCATIONS.append({"path": path, "content": content})
return f"wrote {len(content)} bytes to {path}"
class _ContentFilteredModel(BaseChatModel):
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
First call returns finish_reason='content_filter' + a tool_call whose
arguments are visibly truncated. Second call (if reached) returns a
normal text completion so the agent can terminate cleanly.
"""
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-content-filtered"
def bind_tools(self, tools, **kwargs):
# create_agent binds tools onto the model; we don't actually need
# to bind anything since responses are hard-coded, but the method
# must not raise.
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
message = AIMessage(
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
tool_calls=[
{
"id": "call_truncated_1",
"name": "write_file",
"args": {
"path": "/mnt/user-data/outputs/report.md",
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
},
}
],
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
)
else:
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
class _InspectMiddleware(AgentMiddleware):
"""Captures the messages list at every model entry so we can assert
no synthetic tool result was injected back into the conversation."""
def __init__(self) -> None:
super().__init__()
self.observed: list[list[Any]] = []
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
self.observed.append(list(request.messages))
return handler(request)
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
_TOOL_INVOCATIONS.clear()
inspector = _InspectMiddleware()
agent = create_agent(
model=_ContentFilteredModel(),
tools=[write_file],
# Inspector first so its after_model is registered; Safety last in
# the list so it executes first under LIFO (matches production wiring).
middleware=[inspector, SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
# Critical assertion: the dangerous truncated tool call must NOT have
# been executed. This is the entire point of the middleware.
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
# Final AIMessage has no tool calls left.
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
assert final_ai.tool_calls == []
# Observability stamp is present.
record = final_ai.additional_kwargs.get("safety_termination")
assert record is not None
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 1
assert record["suppressed_tool_call_names"] == ["write_file"]
# User-facing explanation is appended.
assert "safety-related signal" in final_ai.content
# Original partial text preserved (we don't throw away what the user
# already saw in the stream — see middleware docstring).
assert "Weekly Politics" in final_ai.content
# finish_reason on response_metadata is preserved (so SSE / converters
# downstream still see the real provider reason).
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
def test_content_filter_without_tool_calls_passes_through_unchanged():
"""No tool calls => issue scope says don't intervene; the partial
response should be delivered as-is so the user sees what they got."""
_TOOL_INVOCATIONS.clear()
class _NoToolModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "fake-no-tool"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
msg = AIMessage(
content="Partial answer truncated by safety filter",
response_metadata={"finish_reason": "content_filter"},
)
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_NoToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
# Content untouched.
assert final_ai.content == "Partial answer truncated by safety filter"
# No safety_termination stamp because we didn't intervene.
assert "safety_termination" not in final_ai.additional_kwargs
# tool node never ran (there were no tool calls in the first place).
assert _TOOL_INVOCATIONS == []
def test_normal_tool_call_round_trip_is_not_affected():
"""Regression: a healthy finish_reason='tool_calls' response must still
execute the tool. The middleware must not over-fire."""
_TOOL_INVOCATIONS.clear()
class _HealthyToolModel(BaseChatModel):
call_count: int = 0
@property
def _llm_type(self) -> str:
return "fake-healthy"
def bind_tools(self, tools, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self.call_count += 1
if self.call_count == 1:
msg = AIMessage(
content="",
tool_calls=[
{
"id": "call_ok",
"name": "write_file",
"args": {"path": "/tmp/ok", "content": "complete content"},
}
],
response_metadata={"finish_reason": "tool_calls"},
)
else:
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
return ChatResult(generations=[ChatGeneration(message=msg)])
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
agent = create_agent(
model=_HealthyToolModel(),
tools=[write_file],
middleware=[SafetyFinishReasonMiddleware()],
)
agent.invoke({"messages": [HumanMessage(content="write")]})
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
@@ -0,0 +1,651 @@
"""Unit tests for SafetyFinishReasonMiddleware."""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
from deerflow.agents.middlewares.safety_termination_detectors import (
SafetyTermination,
)
from deerflow.config.safety_finish_reason_config import (
SafetyDetectorConfig,
SafetyFinishReasonConfig,
)
def _runtime(thread_id="t-1"):
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
return runtime
def _ai(
*,
content="",
tool_calls=None,
response_metadata=None,
additional_kwargs=None,
):
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
def _write_call(idx=1, content_text="半截"):
return {
"id": f"call_write_{idx}",
"name": "write_file",
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
}
class AlwaysHitDetector:
"""Test fixture: always reports the given termination."""
name = "always_hit"
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
self.reason_field = reason_field
self.reason_value = reason_value
self.extras = extras or {}
def detect(self, message):
return SafetyTermination(
detector=self.name,
reason_field=self.reason_field,
reason_value=self.reason_value,
extras=self.extras,
)
class NeverHitDetector:
name = "never_hit"
def detect(self, message):
return None
class RaisingDetector:
name = "raising"
def detect(self, message):
raise RuntimeError("boom")
# ---------------------------------------------------------------------------
# Core trigger behaviour
# ---------------------------------------------------------------------------
class TestTriggerCriteria:
def test_content_filter_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
patched = result["messages"][0]
assert patched.tool_calls == []
def test_content_filter_without_tool_calls_passes_through(self):
"""issue scope: when there are no tool calls the partial text is a
legitimate final response and should not be rewritten."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="partial response",
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_tool_calls_pass_through(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_normal_stop_with_tool_calls_pass_through(self):
# Some providers report finish_reason='stop' for tool-call messages.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "stop"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_empty_message_list_passes_through(self):
mw = SafetyFinishReasonMiddleware()
assert mw._apply({"messages": []}, _runtime()) is None
def test_non_ai_last_message_passes_through(self):
mw = SafetyFinishReasonMiddleware()
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
assert mw._apply(state, _runtime()) is None
def test_anthropic_refusal_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"stop_reason": "refusal"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
def test_gemini_safety_with_tool_calls_triggers(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "SAFETY"},
)
]
}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].tool_calls == []
# ---------------------------------------------------------------------------
# Message rewriting
# ---------------------------------------------------------------------------
class TestMessageRewrite:
def test_clears_structured_tool_calls(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert patched.tool_calls == []
def test_clears_raw_additional_kwargs_tool_calls(self):
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
tool calls from additional_kwargs.tool_calls if we forget them, which
would re-emit a synthetic ToolMessage downstream and confuse the
model. We must wipe both."""
mw = SafetyFinishReasonMiddleware()
raw_tool_calls = [
{
"id": "call_write_1",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
}
]
state = {
"messages": [
_ai(
tool_calls=[_write_call(1)],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"tool_calls": raw_tool_calls,
"function_call": {"name": "write_file", "arguments": "{}"},
},
)
]
}
result = mw._apply(state, _runtime())
patched = result["messages"][0]
assert "tool_calls" not in patched.additional_kwargs
assert "function_call" not in patched.additional_kwargs
def test_preserves_other_additional_kwargs(self):
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
# may carry other provider-specific keys. They must not be wiped.
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
additional_kwargs={
"reasoning": "thinking text",
"custom_provider_field": {"x": 1},
},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["reasoning"] == "thinking text"
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
def test_writes_observability_field(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call(1), _write_call(2)],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
record = patched.additional_kwargs["safety_termination"]
assert record["detector"] == "openai_compatible_content_filter"
assert record["reason_field"] == "finish_reason"
assert record["reason_value"] == "content_filter"
assert record["suppressed_tool_call_count"] == 2
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
def test_preserves_response_metadata_finish_reason(self):
"""Downstream SSE converters read response_metadata.finish_reason —
we want them to see the *real* provider reason, not 'stop'."""
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.response_metadata["finish_reason"] == "content_filter"
assert patched.response_metadata["model_name"] == "kimi-k2"
def test_appends_user_facing_explanation_to_str_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="some partial text",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert patched.content.startswith("some partial text")
assert "safety-related signal" in patched.content
def test_handles_empty_content(self):
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
content="",
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, str)
assert "safety-related signal" in patched.content
def test_handles_list_content_thinking_blocks(self):
"""Anthropic thinking / vLLM reasoning models emit content blocks.
Naively concatenating a string would raise TypeError."""
mw = SafetyFinishReasonMiddleware()
thinking_blocks = [
{"type": "thinking", "text": "let me consider..."},
{"type": "text", "text": "partial answer"},
]
state = {
"messages": [
_ai(
content=thinking_blocks,
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
patched = mw._apply(state, _runtime())["messages"][0]
assert isinstance(patched.content, list)
assert patched.content[:2] == thinking_blocks
assert patched.content[-1]["type"] == "text"
assert "safety-related signal" in patched.content[-1]["text"]
def test_idempotent_on_already_cleared_message(self):
# Re-running the middleware on a message we already cleared must not
# re-trigger (tool_calls is now empty → fast passthrough).
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
first = mw._apply(state, _runtime())
state2 = {"messages": [first["messages"][0]]}
second = mw._apply(state2, _runtime())
assert second is None
def test_preserves_message_id_for_add_messages_replacement(self):
"""LangGraph's add_messages reducer treats same-id messages as
replacements. model_copy keeps id by default."""
mw = SafetyFinishReasonMiddleware()
original = _ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
# AIMessage auto-generates id; capture it
original_id = original.id
state = {"messages": [original]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.id == original_id
# ---------------------------------------------------------------------------
# Detector wiring
# ---------------------------------------------------------------------------
class TestDetectorWiring:
def test_iterates_detectors_in_order(self):
first = AlwaysHitDetector(reason_value="first")
second = AlwaysHitDetector(reason_value="second")
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
patched = mw._apply(state, _runtime())["messages"][0]
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
def test_returns_none_when_no_detector_matches(self):
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state, _runtime()) is None
def test_buggy_detector_does_not_break_run(self):
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
state = {"messages": [_ai(tool_calls=[_write_call()])]}
result = mw._apply(state, _runtime())
assert result is not None
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
def test_constructor_copies_detectors(self):
"""Caller mutation after construction must not leak into us."""
detectors = [AlwaysHitDetector()]
mw = SafetyFinishReasonMiddleware(detectors=detectors)
detectors.clear()
state = {"messages": [_ai(tool_calls=[_write_call()])]}
assert mw._apply(state, _runtime()) is not None
# ---------------------------------------------------------------------------
# from_config
# ---------------------------------------------------------------------------
class TestFromConfig:
def test_default_config_uses_builtin_detectors(self):
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
assert len(mw._detectors) == 3
names = {d.name for d in mw._detectors}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_custom_detectors_loaded_via_reflection(self):
cfg = SafetyFinishReasonConfig(
detectors=[
SafetyDetectorConfig(
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
config={"finish_reasons": ["custom_filter"]},
),
]
)
mw = SafetyFinishReasonMiddleware.from_config(cfg)
assert len(mw._detectors) == 1
# Confirm the kwargs propagated.
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "custom_filter"},
)
]
}
assert mw._apply(state, _runtime()) is not None
# Default token no longer matches.
state2 = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
assert mw._apply(state2, _runtime()) is None
def test_empty_detector_list_rejected(self):
cfg = SafetyFinishReasonConfig(detectors=[])
with pytest.raises(ValueError, match="enabled=false"):
SafetyFinishReasonMiddleware.from_config(cfg)
def test_non_detector_class_rejected(self):
cfg = SafetyFinishReasonConfig(
detectors=[SafetyDetectorConfig(use="builtins:dict")],
)
with pytest.raises(TypeError):
SafetyFinishReasonMiddleware.from_config(cfg)
# ---------------------------------------------------------------------------
# Stream event
# ---------------------------------------------------------------------------
class TestAuditEvent:
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
audit event via RunJournal.record_middleware when the run-scoped journal is
exposed under runtime.context["__run_journal"].
Background: review on PR #3035 — SSE custom event handles live consumers,
but post-run audit needs a row in run_events that can be queried with one
SQL statement (no JOIN against message body).
"""
def _runtime_with_journal(self, journal):
runtime = MagicMock()
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
return runtime
def test_records_audit_event_when_journal_present(self):
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
tc = _write_call(1)
state = {
"messages": [
_ai(
content="partial",
tool_calls=[tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
journal.record_middleware.assert_called_once()
call = journal.record_middleware.call_args
# tag is positional or kwarg depending on call style; we use kwargs.
assert call.kwargs["tag"] == "safety_termination"
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
assert call.kwargs["hook"] == "after_model"
assert call.kwargs["action"] == "suppress_tool_calls"
changes = call.kwargs["changes"]
assert changes["detector"] == "openai_compatible_content_filter"
assert changes["reason_field"] == "finish_reason"
assert changes["reason_value"] == "content_filter"
assert changes["suppressed_tool_call_count"] == 1
assert changes["suppressed_tool_call_names"] == ["write_file"]
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
assert "message_id" in changes
assert isinstance(changes["extras"], dict)
def test_audit_event_never_carries_tool_arguments(self):
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
and must NOT be persisted to run_events under any circumstance."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
sensitive_tc = {
"id": "call_x",
"name": "write_file",
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
}
state = {
"messages": [
_ai(
tool_calls=[sensitive_tc],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, self._runtime_with_journal(journal))
flat = repr(journal.record_middleware.call_args)
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
def test_no_journal_in_runtime_context_is_silently_skipped(self):
"""Subagent runtime / unit tests / no-event-store paths have no journal.
Middleware must still intervene and clear tool_calls — only the audit
event is skipped."""
mw = SafetyFinishReasonMiddleware()
runtime = MagicMock()
runtime.context = {"thread_id": "t-noj"} # no __run_journal
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise; should still clear tool_calls.
result = mw._apply(state, runtime)
assert result is not None
assert result["messages"][0].tool_calls == []
def test_journal_record_exception_does_not_break_run(self):
"""Buggy journal must never propagate an exception into the agent loop."""
journal = MagicMock()
journal.record_middleware.side_effect = RuntimeError("db down")
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Must not raise.
result = mw._apply(state, self._runtime_with_journal(journal))
assert result is not None
assert result["messages"][0].tool_calls == []
def test_no_record_when_passthrough(self):
"""When the middleware does NOT intervene, no audit event is written."""
journal = MagicMock()
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "tool_calls"}, # healthy
)
]
}
assert mw._apply(state, self._runtime_with_journal(journal)) is None
journal.record_middleware.assert_not_called()
class TestStreamEvent:
def test_emits_event_when_writer_available(self, monkeypatch):
captured: list = []
def fake_writer(payload):
captured.append(payload)
# Patch get_stream_writer at the symbol-resolution site.
import langgraph.config
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
mw._apply(state, _runtime("t-stream"))
assert len(captured) == 1
payload = captured[0]
assert payload["type"] == "safety_termination"
assert payload["detector"] == "openai_compatible_content_filter"
assert payload["reason_field"] == "finish_reason"
assert payload["reason_value"] == "content_filter"
assert payload["suppressed_tool_call_count"] == 1
assert payload["suppressed_tool_call_names"] == ["write_file"]
assert payload["thread_id"] == "t-stream"
def test_writer_unavailable_does_not_break(self, monkeypatch):
import langgraph.config
def boom():
raise LookupError("not in a stream context")
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
mw = SafetyFinishReasonMiddleware()
state = {
"messages": [
_ai(
tool_calls=[_write_call()],
response_metadata={"finish_reason": "content_filter"},
)
]
}
# Should not raise.
result = mw._apply(state, _runtime())
assert result is not None
@@ -0,0 +1,176 @@
"""Unit tests for SafetyTerminationDetector built-ins."""
from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.safety_termination_detectors import (
AnthropicRefusalDetector,
GeminiSafetyDetector,
OpenAICompatibleContentFilterDetector,
SafetyTermination,
SafetyTerminationDetector,
default_detectors,
)
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
return AIMessage(
content=content,
tool_calls=tool_calls or [],
response_metadata=response_metadata or {},
additional_kwargs=additional_kwargs or {},
)
class TestOpenAICompatibleContentFilterDetector:
def test_default_matches_content_filter(self):
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
assert hit is not None
assert hit.detector == "openai_compatible_content_filter"
assert hit.reason_field == "finish_reason"
assert hit.reason_value == "content_filter"
def test_case_insensitive_match(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
def test_other_finish_reasons_pass_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
def test_missing_metadata_passes_through(self):
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai()) is None
def test_non_string_finish_reason_passes_through(self):
# Some adapters may stash an enum or dict — must not raise.
d = OpenAICompatibleContentFilterDetector()
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
def test_falls_back_to_additional_kwargs(self):
# Legacy adapters surface finish_reason via additional_kwargs.
d = OpenAICompatibleContentFilterDetector()
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
assert hit is not None
def test_configurable_extra_values(self):
# Chinese providers sometimes use bespoke tokens.
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
# Original token still matches.
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
def test_carries_azure_content_filter_results(self):
d = OpenAICompatibleContentFilterDetector()
filter_results = {"hate": {"filtered": True, "severity": "high"}}
hit = d.detect(
_ai(
response_metadata={
"finish_reason": "content_filter",
"content_filter_results": filter_results,
},
)
)
assert hit is not None
assert hit.extras["content_filter_results"] == filter_results
class TestAnthropicRefusalDetector:
def test_default_matches_refusal(self):
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
assert hit is not None
assert hit.reason_field == "stop_reason"
assert hit.reason_value == "refusal"
def test_other_stop_reasons_pass_through(self):
d = AnthropicRefusalDetector()
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
def test_anthropic_does_not_steal_finish_reason(self):
# An OpenAI message must not accidentally trip the Anthropic detector.
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
class TestGeminiSafetyDetector:
def test_default_set_covers_documented_reasons(self):
d = GeminiSafetyDetector()
for reason in (
# text safety
"SAFETY",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"RECITATION",
# image safety
"IMAGE_SAFETY",
"IMAGE_PROHIBITED_CONTENT",
"IMAGE_RECITATION",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
def test_normal_termination_passes_through(self):
d = GeminiSafetyDetector()
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
# excluded from the default set — they are either normal termination,
# capability mismatches, too broad (OTHER), or tool-call protocol
# errors. See GeminiSafetyDetector docstring.
for reason in (
"MAX_TOKENS",
"LANGUAGE",
"NO_IMAGE",
"OTHER",
"IMAGE_OTHER",
"MALFORMED_FUNCTION_CALL",
"UNEXPECTED_TOOL_CALL",
"FINISH_REASON_UNSPECIFIED",
):
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
def test_carries_safety_ratings(self):
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
hit = GeminiSafetyDetector().detect(
_ai(
response_metadata={
"finish_reason": "SAFETY",
"safety_ratings": ratings,
},
)
)
assert hit is not None
assert hit.extras["safety_ratings"] == ratings
class TestDefaultDetectorSet:
def test_default_set_returns_three_detectors(self):
dets = default_detectors()
names = {d.name for d in dets}
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
def test_default_set_returns_fresh_list(self):
# Caller mutation must not affect later calls.
first = default_detectors()
first.clear()
second = default_detectors()
assert len(second) == 3
class TestProtocolConformance:
def test_builtins_satisfy_protocol(self):
for d in default_detectors():
assert isinstance(d, SafetyTerminationDetector)
def test_safety_termination_is_frozen(self):
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
try:
t.detector = "y" # type: ignore[misc]
except Exception:
return
raise AssertionError("SafetyTermination should be frozen")
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
assert captured["app_config"] is app_config
assert len(middlewares) == 6
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
# (enabled by default — see SafetyFinishReasonConfig).
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
assert len(middlewares) == 7
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
def test_wrap_tool_call_passthrough_on_success():
+36 -1
View File
@@ -15,7 +15,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 9
config_version: 10
# ============================================================================
# Logging
@@ -535,6 +535,41 @@ loop_detection:
# warn: 150
# hard_limit: 300
# ============================================================================
# Provider Safety Termination Configuration
# ============================================================================
# Intercept AIMessages where the provider stopped generation for safety reasons
# (e.g. OpenAI finish_reason='content_filter', Anthropic stop_reason='refusal',
# Gemini finish_reason='SAFETY') while still returning tool_calls. The
# tool_calls in such responses are typically truncated/unreliable and must
# not be executed. See issue #3028 for the full failure mode.
#
# Detectors are loaded by class path via reflection (same pattern as
# guardrails / models / tools). The built-in set covers OpenAI-compatible
# content_filter, Anthropic refusal, and Gemini SAFETY/BLOCKLIST/
# PROHIBITED_CONTENT/SPII/RECITATION.
safety_finish_reason:
enabled: true
# Leave `detectors` unset to use the built-in detector set. Set to a
# non-empty list to fully override (use `enabled: false` to disable instead
# of providing an empty list).
#
# Example — extend the OpenAI-compatible detector for a Chinese provider
# whose gateway uses a non-standard finish_reason token:
# detectors:
# - use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector
# config:
# finish_reasons: ["content_filter", "sensitive", "risk_control"]
# - use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector
# - use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector
#
# Example — add a custom detector for an in-house provider:
# detectors:
# - use: my_company.deerflow_ext:WenxinSafetyDetector
# config:
# error_codes: [336003, 17, 18]
# ============================================================================
# Sandbox Configuration
# ============================================================================