mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(runtime): suppress tool execution when provider safety-terminates with tool_calls (#3035)
* fix(runtime): suppress tool execution when provider safety-terminates with tool_calls When a provider stops generation for safety reasons (OpenAI/Moonshot finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/ IMAGE_SAFETY/...), the response may still carry truncated tool_calls. LangChain's tool router treats any non-empty tool_calls as executable, so partial arguments (e.g. write_file with a half-finished markdown) get dispatched and the agent loops on retry. Add SafetyFinishReasonMiddleware at after_model: detect safety termination via a pluggable detector registry, clear both structured tool_calls and raw additional_kwargs.tool_calls / function_call, preserve response_metadata.finish_reason for downstream observers, stamp additional_kwargs.safety_termination for traces, append a user-facing explanation to message content (list-aware for thinking blocks), and emit a safety_termination custom stream event so SSE consumers can reconcile any "tool starting..." UI. Default detectors cover OpenAI-compatible content_filter, Anthropic refusal, and Gemini safety enums (text + image). Custom providers are added via reflection (same pattern as guardrails). Wired into both lead-agent and subagent runtimes. Closes #3028 * fix(runtime): persist safety_termination as a middleware audit event Address review on #3035: the SSE custom event is great for live consumers but invisible to post-run audit. RunEventStore should carry its own row so operators can answer "which runs were safety-suppressed today?" from a single SQL query without joining the message body. Worker now exposes the run-scoped RunJournal via runtime.context["__run_journal"] (sentinel key, internal channel). SafetyFinishReasonMiddleware calls the previously-unused RunJournal.record_middleware, which emits event_type = "middleware:safety_termination" category = "middleware" content = {name, hook, action, changes={ detector, reason_field, reason_value, suppressed_tool_call_count, suppressed_tool_call_names, suppressed_tool_call_ids, message_id, extras}} Tool *arguments* are deliberately excluded — those are the very content the provider filtered and persisting them would defeat the purpose of the safety filter (per review note in #3035). Graceful skips when journal is absent (subagent runtime, unit tests, no-event-store local dev). Journal exceptions never propagate into the agent loop. Refs #3028 * fix(runtime): satisfy ruff format + address Copilot review - ruff format on safety_finish_reason_config.py and e2e demo (CI lint failed on ruff format --check; backend Makefile lint target runs ruff check AND ruff format --check). - Docstring on SafetyFinishReasonConfig now says resolve_variable to match the actual loader used in from_config (the wording was resolve_class previously; behavior is unchanged — resolve_variable mirrors how guardrails.provider is loaded). - Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply from getattr(last, "type") == "ai" to isinstance(last, AIMessage), matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware / SummarizationMiddleware which are the dominant pattern. Refs #3028
This commit is contained in:
@@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
@@ -338,6 +339,15 @@ def _build_middlewares(
|
||||
if custom_middlewares:
|
||||
middlewares.extend(custom_middlewares)
|
||||
|
||||
# SafetyFinishReasonMiddleware — suppress tool execution when the provider
|
||||
# safety-terminated the response. Registered after custom middlewares so
|
||||
# that LangChain's reverse-order after_model dispatch runs Safety first;
|
||||
# cleared tool_calls then flow through Loop/Subagent accounting without
|
||||
# firing extra alarms. See safety_finish_reason_middleware.py docstring.
|
||||
safety_config = resolved_app_config.safety_finish_reason
|
||||
if safety_config.enabled:
|
||||
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
+317
@@ -0,0 +1,317 @@
|
||||
"""Suppress tool execution when the provider safety-terminated the response.
|
||||
|
||||
Background — see issue bytedance/deer-flow#3028.
|
||||
|
||||
Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic
|
||||
``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop
|
||||
generation mid-stream while still returning partially-formed ``tool_calls``.
|
||||
LangChain's tool router treats any AIMessage with a non-empty ``tool_calls``
|
||||
field as "go execute these", so half-truncated arguments — e.g. a markdown
|
||||
``write_file`` that stops in the middle of a sentence — get dispatched as if
|
||||
they were complete. The agent then sees the truncated file, tries to fix it,
|
||||
gets filtered again, and loops.
|
||||
|
||||
This middleware sits at ``after_model`` and gates that behaviour: when a
|
||||
configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries
|
||||
tool calls, we strip the tool calls (both structured and raw provider
|
||||
payloads), append a user-facing explanation, and stash observability fields
|
||||
in ``additional_kwargs.safety_termination`` so logs, traces, and SSE
|
||||
consumers can see what happened.
|
||||
|
||||
Hook choice: ``after_model`` (not ``wrap_model_call``) because the response
|
||||
is a *normal* return — not an exception — and we want to participate in the
|
||||
same after-model chain as ``LoopDetectionMiddleware``, with which we share
|
||||
the same tool-call-suppression mechanic but a different trigger.
|
||||
|
||||
Placement: register *after* ``LoopDetectionMiddleware`` in the middleware
|
||||
list. LangChain factory wires ``after_model`` edges in reverse list order
|
||||
(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``,
|
||||
then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is
|
||||
the *first* to observe the model output. Registering Safety after Loop
|
||||
means Safety sees the raw response first, clears tool calls if it fires,
|
||||
and Loop then accounts against the cleaned message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_USER_FACING_MESSAGE = (
|
||||
"The model provider stopped this response with a safety-related signal "
|
||||
"({reason_field}={reason_value!r}, detector={detector!r}). Any tool "
|
||||
"calls produced in this turn were suppressed because their arguments "
|
||||
"may be truncated and unsafe to execute. Please rephrase the request "
|
||||
"or ask for a narrower output."
|
||||
)
|
||||
|
||||
|
||||
class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector."""
|
||||
|
||||
def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None:
|
||||
super().__init__()
|
||||
# Copy so caller mutations after construction don't leak into us.
|
||||
self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware:
|
||||
"""Construct from validated Pydantic config, honouring the
|
||||
reflection-loaded detector list when provided.
|
||||
|
||||
An explicit empty list is intentionally rejected — it would silently
|
||||
disable detection while leaving the middleware in the chain, which
|
||||
is the worst of both worlds. Use ``enabled: false`` instead.
|
||||
"""
|
||||
if config.detectors is None:
|
||||
return cls()
|
||||
|
||||
if not config.detectors:
|
||||
raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.")
|
||||
|
||||
from deerflow.reflection import resolve_variable
|
||||
|
||||
detectors: list[SafetyTerminationDetector] = []
|
||||
for entry in config.detectors:
|
||||
detector_cls = resolve_variable(entry.use)
|
||||
kwargs = dict(entry.config) if entry.config else {}
|
||||
detector = detector_cls(**kwargs)
|
||||
if not isinstance(detector, SafetyTerminationDetector):
|
||||
raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method")
|
||||
detectors.append(detector)
|
||||
return cls(detectors=detectors)
|
||||
|
||||
# ----- detection -------------------------------------------------------
|
||||
|
||||
def _detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
for detector in self._detectors:
|
||||
try:
|
||||
hit = detector.detect(message)
|
||||
except Exception: # noqa: BLE001 - never let a buggy detector break the agent run
|
||||
logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__))
|
||||
continue
|
||||
if hit is not None:
|
||||
return hit
|
||||
return None
|
||||
|
||||
# ----- message rewriting ----------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _append_user_message(content: object, text: str) -> str | list:
|
||||
"""Append a plain-text explanation to AIMessage content.
|
||||
|
||||
Mirrors ``LoopDetectionMiddleware._append_text`` so list-content
|
||||
responses (Anthropic thinking blocks, vLLM reasoning splits) keep
|
||||
their structure instead of being string-coerced into a TypeError.
|
||||
"""
|
||||
if content is None or content == "":
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _build_suppressed_message(
|
||||
self,
|
||||
message: AIMessage,
|
||||
termination: SafetyTermination,
|
||||
) -> AIMessage:
|
||||
suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])]
|
||||
explanation = _USER_FACING_MESSAGE.format(
|
||||
reason_field=termination.reason_field,
|
||||
reason_value=termination.reason_value,
|
||||
detector=termination.detector,
|
||||
)
|
||||
new_content = self._append_user_message(message.content, explanation)
|
||||
|
||||
# clone_ai_message_with_tool_calls handles structured tool_calls,
|
||||
# raw additional_kwargs.tool_calls, and function_call in one shot.
|
||||
# It only rewrites finish_reason when the old value was "tool_calls",
|
||||
# which is not our case — content_filter / refusal / SAFETY stay put
|
||||
# so downstream SSE / converters keep seeing the real provider reason.
|
||||
cleared = clone_ai_message_with_tool_calls(message, [], content=new_content)
|
||||
|
||||
# Re-clone additional_kwargs so we don't accidentally mutate the
|
||||
# dict returned by clone_ai_message_with_tool_calls (which already
|
||||
# made a shallow copy, but downstream model_copy still references
|
||||
# it). Then stamp the observability record.
|
||||
kwargs = dict(getattr(cleared, "additional_kwargs", None) or {})
|
||||
kwargs["safety_termination"] = {
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(suppressed_names),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"extras": dict(termination.extras) if termination.extras else {},
|
||||
}
|
||||
return cleared.model_copy(update={"additional_kwargs": kwargs})
|
||||
|
||||
# ----- observability ---------------------------------------------------
|
||||
|
||||
def _emit_event(
|
||||
self,
|
||||
termination: SafetyTermination,
|
||||
suppressed_names: list[str],
|
||||
runtime: Runtime,
|
||||
) -> None:
|
||||
"""Notify SSE consumers (e.g. the web UI) that a tool turn was
|
||||
suppressed so they can reconcile any "tool starting..." placeholders
|
||||
already streamed to the user. Failures are logged at debug and
|
||||
ignored — this is a best-effort signal."""
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True)
|
||||
return
|
||||
|
||||
thread_id = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
|
||||
|
||||
try:
|
||||
writer(
|
||||
{
|
||||
"type": "safety_termination",
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(suppressed_names),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.debug("Failed to emit safety_termination stream event", exc_info=True)
|
||||
|
||||
def _record_audit_event(
|
||||
self,
|
||||
termination: SafetyTermination,
|
||||
message,
|
||||
tool_calls: list[dict],
|
||||
runtime: Runtime,
|
||||
) -> None:
|
||||
"""Write a ``middleware:safety_termination`` record to RunEventStore
|
||||
for post-run auditability.
|
||||
|
||||
The custom stream event in ``_emit_event`` is consumed by live SSE
|
||||
clients and disappears after the run; this event is persisted so an
|
||||
operator can answer "which runs were safety-suppressed today?" from
|
||||
a single SQL query without joining the message body. Worker exposes
|
||||
the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``;
|
||||
absent in unit-test / subagent / no-event-store paths, in which case
|
||||
we silently skip.
|
||||
|
||||
Tool **arguments** are deliberately **not** recorded — those are the
|
||||
very content the provider filtered; persisting them would defeat the
|
||||
purpose of the safety filter. Names / count / ids are sufficient for
|
||||
audit and debugging (issue #3028 review).
|
||||
"""
|
||||
journal = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
context = runtime.context
|
||||
if isinstance(context, dict):
|
||||
journal = context.get("__run_journal")
|
||||
if journal is None:
|
||||
return
|
||||
|
||||
suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls]
|
||||
suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||
|
||||
changes = {
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_count": len(tool_calls),
|
||||
"suppressed_tool_call_names": suppressed_names,
|
||||
"suppressed_tool_call_ids": suppressed_ids,
|
||||
"message_id": getattr(message, "id", None),
|
||||
"extras": dict(termination.extras) if termination.extras else {},
|
||||
}
|
||||
|
||||
try:
|
||||
journal.record_middleware(
|
||||
tag="safety_termination",
|
||||
name=type(self).__name__,
|
||||
hook="after_model",
|
||||
action="suppress_tool_calls",
|
||||
changes=changes,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
# Audit-event persistence must never break agent execution.
|
||||
logger.debug("Failed to record middleware:safety_termination event", exc_info=True)
|
||||
|
||||
# ----- main apply ------------------------------------------------------
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
# Issue scope: only intervene when there's something to suppress.
|
||||
# ``content_filter`` without tool_calls is allowed through unchanged
|
||||
# so the partial text response (if any) reaches the user naturally.
|
||||
tool_calls = last.tool_calls
|
||||
if not tool_calls:
|
||||
return None
|
||||
|
||||
termination = self._detect(last)
|
||||
if termination is None:
|
||||
return None
|
||||
|
||||
patched = self._build_suppressed_message(last, termination)
|
||||
|
||||
thread_id = None
|
||||
if runtime is not None and getattr(runtime, "context", None):
|
||||
thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None
|
||||
|
||||
logger.warning(
|
||||
"Provider safety termination detected — suppressed %d tool call(s)",
|
||||
len(tool_calls),
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"detector": termination.detector,
|
||||
"reason_field": termination.reason_field,
|
||||
"reason_value": termination.reason_value,
|
||||
"suppressed_tool_call_names": [tc.get("name") for tc in tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime)
|
||||
self._record_audit_event(termination, last, list(tool_calls), runtime)
|
||||
|
||||
return {"messages": [patched]}
|
||||
|
||||
# ----- hooks -----------------------------------------------------------
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
@@ -0,0 +1,237 @@
|
||||
"""Detectors for provider-side safety termination signals.
|
||||
|
||||
Different LLM providers signal "I stopped this response for safety reasons"
|
||||
through different fields with different values. This module defines a small
|
||||
strategy interface and three built-in detectors that cover the major
|
||||
providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock
|
||||
adapters, in-house gateways, ...) can be added by implementing
|
||||
``SafetyTerminationDetector`` and wiring it through
|
||||
``config.yaml: safety_finish_reason.detectors``.
|
||||
|
||||
The middleware that consumes these detectors lives in
|
||||
``safety_finish_reason_middleware.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SafetyTermination:
|
||||
"""A detected safety-related termination signal.
|
||||
|
||||
Attributes:
|
||||
detector: Name of the detector that produced this result. Used for
|
||||
observability so operators can see which provider rule fired.
|
||||
reason_field: The message metadata field that carried the signal
|
||||
(e.g. ``finish_reason``, ``stop_reason``).
|
||||
reason_value: The actual value of that field
|
||||
(e.g. ``content_filter``, ``refusal``, ``SAFETY``).
|
||||
extras: Provider-specific metadata that may help downstream
|
||||
consumers (e.g. Azure OpenAI content_filter_results, Gemini
|
||||
safety_ratings). Detectors are free to populate or skip this.
|
||||
"""
|
||||
|
||||
detector: str
|
||||
reason_field: str
|
||||
reason_value: str
|
||||
extras: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SafetyTerminationDetector(Protocol):
|
||||
"""Strategy interface for provider safety termination detection."""
|
||||
|
||||
name: str
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
"""Return a SafetyTermination if *message* indicates provider safety
|
||||
termination, otherwise return ``None``.
|
||||
|
||||
Implementations must be side-effect free and tolerant of missing or
|
||||
oddly-typed metadata — detectors run on every model response.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def _get_metadata_value(message: AIMessage, field_name: str) -> str | None:
|
||||
"""Read a string-typed value from either ``response_metadata`` or
|
||||
``additional_kwargs``.
|
||||
|
||||
LangChain provider adapters are inconsistent about where they stash
|
||||
provider stop signals. Most modern adapters use ``response_metadata``,
|
||||
but some legacy / passthrough paths still surface them via
|
||||
``additional_kwargs``. We check both, in that order, and only accept
|
||||
string values — Pydantic enums or dicts are ignored so we never raise
|
||||
on malformed inputs.
|
||||
"""
|
||||
for container_name in ("response_metadata", "additional_kwargs"):
|
||||
container = getattr(message, container_name, None) or {}
|
||||
if not isinstance(container, dict):
|
||||
continue
|
||||
value = container.get(field_name)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class OpenAICompatibleContentFilterDetector:
|
||||
"""OpenAI-compatible content_filter signal.
|
||||
|
||||
Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM,
|
||||
Qwen (OpenAI-compatible mode), and any other adapter that follows the
|
||||
OpenAI ``finish_reason`` convention.
|
||||
|
||||
Some Chinese providers ship custom OpenAI-compatible gateways that use
|
||||
alternative tokens like ``sensitive`` or ``violation``. Extend the set
|
||||
via the ``finish_reasons`` kwarg in config.
|
||||
"""
|
||||
|
||||
name = "openai_compatible_content_filter"
|
||||
|
||||
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = finish_reasons if finish_reasons is not None else ("content_filter",)
|
||||
self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "finish_reason")
|
||||
if value is None or value.lower() not in self._finish_reasons:
|
||||
return None
|
||||
|
||||
extras: dict[str, Any] = {}
|
||||
# Azure OpenAI ships a structured content_filter_results block; carry it
|
||||
# through so operators can see *what* was filtered without re-tracing.
|
||||
response_metadata = getattr(message, "response_metadata", None) or {}
|
||||
if isinstance(response_metadata, dict):
|
||||
filter_results = response_metadata.get("content_filter_results")
|
||||
if filter_results:
|
||||
extras["content_filter_results"] = filter_results
|
||||
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="finish_reason",
|
||||
reason_value=value,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicRefusalDetector:
|
||||
"""Anthropic ``stop_reason == "refusal"`` signal.
|
||||
|
||||
Anthropic models surface safety refusals via a dedicated ``stop_reason``
|
||||
rather than ``finish_reason``. See:
|
||||
https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals
|
||||
"""
|
||||
|
||||
name = "anthropic_refusal"
|
||||
|
||||
def __init__(self, stop_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = stop_reasons if stop_reasons is not None else ("refusal",)
|
||||
self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "stop_reason")
|
||||
if value is None or value.lower() not in self._stop_reasons:
|
||||
return None
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="stop_reason",
|
||||
reason_value=value,
|
||||
)
|
||||
|
||||
|
||||
class GeminiSafetyDetector:
|
||||
"""Gemini / Vertex AI safety-related finish reasons.
|
||||
|
||||
Gemini uses the same ``finish_reason`` field as OpenAI but with an
|
||||
enumerated upper-case taxonomy. The default set covers every Gemini
|
||||
finish_reason that means "the model stopped because the content/image
|
||||
tripped a safety, blocklist, recitation, or PII filter" — i.e. cases
|
||||
where any tool_calls returned alongside are likely truncated/
|
||||
unreliable. Full enum:
|
||||
https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason
|
||||
|
||||
Intentionally **excluded** from the default set:
|
||||
- ``STOP`` — normal termination.
|
||||
- ``MAX_TOKENS`` — output length truncation, not safety
|
||||
(same root failure mode as
|
||||
content_filter, but issue #3028
|
||||
scopes it out; expose separately if
|
||||
desired).
|
||||
- ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to
|
||||
safety; tool_calls would be absent
|
||||
anyway.
|
||||
- ``MALFORMED_FUNCTION_CALL`` /
|
||||
``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The
|
||||
tool_calls are *also* unreliable
|
||||
here, but the failure category is
|
||||
distinct from safety filtering;
|
||||
handle in a dedicated detector to
|
||||
keep observability records honest.
|
||||
- ``OTHER`` / ``IMAGE_OTHER`` /
|
||||
``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default;
|
||||
opt in via ``finish_reasons=`` if
|
||||
your provider abuses these.
|
||||
"""
|
||||
|
||||
name = "gemini_safety"
|
||||
|
||||
_DEFAULT_FINISH_REASONS = (
|
||||
# Text safety
|
||||
"SAFETY",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"RECITATION",
|
||||
# Image safety (multimodal generation)
|
||||
"IMAGE_SAFETY",
|
||||
"IMAGE_PROHIBITED_CONTENT",
|
||||
"IMAGE_RECITATION",
|
||||
)
|
||||
|
||||
def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None:
|
||||
configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS
|
||||
self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured)
|
||||
|
||||
def detect(self, message: AIMessage) -> SafetyTermination | None:
|
||||
value = _get_metadata_value(message, "finish_reason")
|
||||
if value is None or value.upper() not in self._finish_reasons:
|
||||
return None
|
||||
|
||||
extras: dict[str, Any] = {}
|
||||
response_metadata = getattr(message, "response_metadata", None) or {}
|
||||
if isinstance(response_metadata, dict):
|
||||
# Gemini surfaces per-category scoring under safety_ratings.
|
||||
ratings = response_metadata.get("safety_ratings")
|
||||
if ratings:
|
||||
extras["safety_ratings"] = ratings
|
||||
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field="finish_reason",
|
||||
reason_value=value,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
|
||||
def default_detectors() -> list[SafetyTerminationDetector]:
|
||||
"""Built-in detector set used when no custom detectors are configured."""
|
||||
return [
|
||||
OpenAICompatibleContentFilterDetector(),
|
||||
AnthropicRefusalDetector(),
|
||||
GeminiSafetyDetector(),
|
||||
]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AnthropicRefusalDetector",
|
||||
"GeminiSafetyDetector",
|
||||
"OpenAICompatibleContentFilterDetector",
|
||||
"SafetyTermination",
|
||||
"SafetyTerminationDetector",
|
||||
"default_detectors",
|
||||
]
|
||||
+10
@@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares(
|
||||
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
|
||||
# Same provider safety-termination guard the lead agent uses — subagents
|
||||
# are equally exposed to truncated tool_calls returned with
|
||||
# finish_reason=content_filter (and friends), and the bad call would then
|
||||
# propagate back to the lead agent via the task tool result.
|
||||
safety_config = app_config.safety_finish_reason
|
||||
if safety_config.enabled:
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config))
|
||||
|
||||
return middlewares
|
||||
|
||||
@@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
from deerflow.config.runtime_paths import existing_project_file
|
||||
from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
|
||||
from deerflow.config.skills_config import SkillsConfig
|
||||
@@ -102,6 +103,7 @@ class AppConfig(BaseModel):
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Configuration for SafetyFinishReasonMiddleware.
|
||||
|
||||
Mirrors the shape of GuardrailsConfig: detectors are loaded by class path
|
||||
through ``deerflow.reflection.resolve_variable`` (same loader the
|
||||
``guardrails.provider`` config uses) so users can drop in custom provider
|
||||
detectors without modifying core code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SafetyDetectorConfig(BaseModel):
|
||||
"""One detector entry under ``safety_finish_reason.detectors``."""
|
||||
|
||||
use: str = Field(
|
||||
description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."),
|
||||
)
|
||||
config: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Constructor kwargs passed to the detector class.",
|
||||
)
|
||||
|
||||
|
||||
class SafetyFinishReasonConfig(BaseModel):
|
||||
"""Configuration for the SafetyFinishReasonMiddleware.
|
||||
|
||||
The middleware intercepts AIMessages where the provider signaled a
|
||||
safety-related termination (e.g. OpenAI ``finish_reason='content_filter'``)
|
||||
while still returning tool calls, and suppresses those tool calls so the
|
||||
half-truncated arguments never execute.
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Master switch for the SafetyFinishReasonMiddleware.",
|
||||
)
|
||||
detectors: list[SafetyDetectorConfig] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Custom detector list. Leave unset (None) to use the built-in "
|
||||
"set covering OpenAI-compatible content_filter, Anthropic "
|
||||
"refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/"
|
||||
"RECITATION. Provide a non-null list to fully override."
|
||||
),
|
||||
)
|
||||
@@ -219,6 +219,12 @@ async def run_agent(
|
||||
# manually here because we drive the graph through ``agent.astream(config=...)``
|
||||
# without passing the official ``context=`` parameter.
|
||||
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
|
||||
# Expose the run-scoped journal under a sentinel key so middleware can
|
||||
# write audit events (e.g. SafetyFinishReasonMiddleware recording
|
||||
# suppressed tool calls). Double-underscore prefix marks it as a
|
||||
# runtime-internal channel; user code must not depend on the key name.
|
||||
if journal is not None:
|
||||
runtime_ctx["__run_journal"] = journal
|
||||
_install_runtime_context(config, runtime_ctx)
|
||||
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent.
|
||||
|
||||
What it proves
|
||||
--------------
|
||||
- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full
|
||||
18-middleware chain, sandbox, tools, etc.).
|
||||
- A model that returns ``finish_reason='content_filter'`` + ``tool_calls``
|
||||
triggers SafetyFinishReasonMiddleware.
|
||||
- LangChain's tool router never invokes ``write_file`` — the truncated
|
||||
arguments do **not** reach the sandbox.
|
||||
- A ``safety_termination`` custom event is emitted on the stream and the
|
||||
final AIMessage carries the observability stamp.
|
||||
|
||||
Run from backend/ directory:
|
||||
PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake provider that mimics Moonshot's content_filter behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ContentFilteredFakeModel(BaseChatModel):
|
||||
"""First call returns finish_reason=content_filter + truncated write_file
|
||||
tool_call. Subsequent calls return a normal stop response so the agent
|
||||
can terminate (the middleware should make a second call unnecessary by
|
||||
clearing tool_calls, but we keep this safety net in case loop-detection
|
||||
or anything else triggers another model invocation)."""
|
||||
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-content-filtered"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
msg = AIMessage(
|
||||
content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_truncated_write",
|
||||
"name": "write_file",
|
||||
"args": {
|
||||
"path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
|
||||
"content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
|
||||
},
|
||||
}
|
||||
],
|
||||
response_metadata={
|
||||
"finish_reason": "content_filter",
|
||||
"model_name": "kimi-k2.6",
|
||||
"model_provider": "openai",
|
||||
},
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(
|
||||
content="(secondary call, should not be needed)",
|
||||
response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Driver
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> int:
|
||||
# Inject the fake model BEFORE constructing the client. Both the
|
||||
# client module and the lead-agent module bind ``create_chat_model``
|
||||
# at import time via ``from deerflow.models import create_chat_model``,
|
||||
# so we patch both attribute slots — the source-of-truth patch on
|
||||
# ``factory.create_chat_model`` doesn't propagate back into already-
|
||||
# imported names.
|
||||
import deerflow.agents.lead_agent.agent as lead_agent_module
|
||||
import deerflow.client as client_module
|
||||
|
||||
fake = _ContentFilteredFakeModel()
|
||||
originals = {
|
||||
"lead": lead_agent_module.create_chat_model,
|
||||
"client": client_module.create_chat_model,
|
||||
}
|
||||
|
||||
def fake_create_chat_model(*args, **kwargs):
|
||||
return fake
|
||||
|
||||
lead_agent_module.create_chat_model = fake_create_chat_model
|
||||
client_module.create_chat_model = fake_create_chat_model
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
try:
|
||||
client = DeerFlowClient()
|
||||
|
||||
print("\n=== Streaming a turn through the real lead-agent ===")
|
||||
events: list[dict[str, Any]] = []
|
||||
for event in client.stream(
|
||||
"帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
|
||||
thread_id="e2e-safety-1",
|
||||
):
|
||||
events.append({"type": event.type, "data": event.data})
|
||||
|
||||
# ---- Assertions ----
|
||||
safety_event = next(
|
||||
(e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"),
|
||||
None,
|
||||
)
|
||||
final_values = next(
|
||||
(e for e in reversed(events) if e["type"] == "values"),
|
||||
None,
|
||||
)
|
||||
tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"]
|
||||
ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")]
|
||||
|
||||
print(f"\n[stats] total stream events: {len(events)}")
|
||||
print(f"[stats] model call count: {fake.call_count}")
|
||||
print(f"[stats] tool messages on stream: {len(tool_messages)}")
|
||||
print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}")
|
||||
|
||||
print("\n[event] safety_termination custom event:")
|
||||
if safety_event is None:
|
||||
print(" *** NOT FOUND ***")
|
||||
return 1
|
||||
for k, v in safety_event["data"].items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
print("\n[state] final AIMessage from last values snapshot:")
|
||||
if final_values is None:
|
||||
print(" *** no values snapshot ***")
|
||||
return 1
|
||||
# `values` event carries `_serialize_message` dicts, not Message objects.
|
||||
final_messages = final_values["data"].get("messages") or []
|
||||
last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None)
|
||||
if last_ai is None:
|
||||
print(" *** no AIMessage in final state ***")
|
||||
print(f" message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}")
|
||||
return 1
|
||||
|
||||
tool_calls = last_ai.get("tool_calls") or []
|
||||
additional_kwargs = last_ai.get("additional_kwargs") or {}
|
||||
response_metadata = last_ai.get("response_metadata") or {}
|
||||
content = last_ai.get("content")
|
||||
|
||||
print(f" tool_calls (must be empty): {tool_calls}")
|
||||
print(f" additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}")
|
||||
content_preview = (content if isinstance(content, str) else str(content))[:200]
|
||||
print(f" content[:200]: {content_preview!r}")
|
||||
print(f" response_metadata.finish_reason: {response_metadata.get('finish_reason')}")
|
||||
|
||||
# NOTE: `client._serialize_message` does not include `response_metadata`
|
||||
# in the values-event payload (client-layer behaviour, unrelated to the
|
||||
# middleware). The middleware *does* preserve finish_reason on the
|
||||
# AIMessage object — see test_safety_finish_reason_middleware.py::
|
||||
# TestMessageRewrite::test_preserves_response_metadata_finish_reason.
|
||||
# Here we assert on the observability stamp, which carries the same
|
||||
# evidence and is in the serialized payload.
|
||||
stamp = additional_kwargs.get("safety_termination") or {}
|
||||
failures = []
|
||||
if tool_calls:
|
||||
failures.append("final AIMessage still has tool_calls — middleware did NOT clear them")
|
||||
if not stamp:
|
||||
failures.append("final AIMessage missing safety_termination observability stamp")
|
||||
if tool_messages:
|
||||
failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream")
|
||||
if stamp.get("reason_value") != "content_filter":
|
||||
failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'")
|
||||
if safety_event is None:
|
||||
failures.append("safety_termination custom event was not emitted on the stream")
|
||||
|
||||
if failures:
|
||||
print("\n=== FAIL ===")
|
||||
for f in failures:
|
||||
print(f" - {f}")
|
||||
return 1
|
||||
|
||||
print("\n=== PASS ===")
|
||||
print(" - tool_calls cleared on final AIMessage")
|
||||
print(" - tool node never invoked (no ToolMessage on stream)")
|
||||
print(" - safety_termination custom event emitted")
|
||||
print(" - observability stamp written to additional_kwargs")
|
||||
print(" - response_metadata.finish_reason preserved for downstream SSE")
|
||||
return 0
|
||||
finally:
|
||||
lead_agent_module.create_chat_model = originals["lead"]
|
||||
client_module.create_chat_model = originals["client"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
)
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||
# verify the custom middleware is injected correctly.
|
||||
# Chain tail order after the custom middleware is:
|
||||
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
|
||||
# so the custom mock sits at index [-3].
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
|
||||
|
||||
|
||||
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
|
||||
|
||||
Unit tests prove ``_apply`` does the right thing on a synthetic state.
|
||||
This test does one level up: builds a real ``langchain.agents.create_agent``
|
||||
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
|
||||
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
|
||||
|
||||
1. The tool node is **not** invoked (the dangerous truncated tool call
|
||||
is suppressed).
|
||||
2. The final AIMessage in graph state has ``tool_calls == []``.
|
||||
3. The observability ``safety_termination`` record is attached.
|
||||
4. The user-facing explanation is appended to the message content.
|
||||
|
||||
This is the closest we can get to the issue's failure mode without a live
|
||||
Moonshot key, and it proves the middleware actually gates LangChain's
|
||||
tool router — not just rewrites state in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@tool
|
||||
def write_file(path: str, content: str) -> str:
|
||||
"""Pretend to write *content* to *path*. Records the call for assertion."""
|
||||
_TOOL_INVOCATIONS.append({"path": path, "content": content})
|
||||
return f"wrote {len(content)} bytes to {path}"
|
||||
|
||||
|
||||
class _ContentFilteredModel(BaseChatModel):
|
||||
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
|
||||
|
||||
First call returns finish_reason='content_filter' + a tool_call whose
|
||||
arguments are visibly truncated. Second call (if reached) returns a
|
||||
normal text completion so the agent can terminate cleanly.
|
||||
"""
|
||||
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-content-filtered"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
# create_agent binds tools onto the model; we don't actually need
|
||||
# to bind anything since responses are hard-coded, but the method
|
||||
# must not raise.
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
message = AIMessage(
|
||||
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_truncated_1",
|
||||
"name": "write_file",
|
||||
"args": {
|
||||
"path": "/mnt/user-data/outputs/report.md",
|
||||
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
|
||||
)
|
||||
else:
|
||||
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
class _InspectMiddleware(AgentMiddleware):
|
||||
"""Captures the messages list at every model entry so we can assert
|
||||
no synthetic tool result was injected back into the conversation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.observed: list[list[Any]] = []
|
||||
|
||||
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
|
||||
self.observed.append(list(request.messages))
|
||||
return handler(request)
|
||||
|
||||
|
||||
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
inspector = _InspectMiddleware()
|
||||
|
||||
agent = create_agent(
|
||||
model=_ContentFilteredModel(),
|
||||
tools=[write_file],
|
||||
# Inspector first so its after_model is registered; Safety last in
|
||||
# the list so it executes first under LIFO (matches production wiring).
|
||||
middleware=[inspector, SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
|
||||
|
||||
# Critical assertion: the dangerous truncated tool call must NOT have
|
||||
# been executed. This is the entire point of the middleware.
|
||||
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
|
||||
|
||||
# Final AIMessage has no tool calls left.
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
assert final_ai.tool_calls == []
|
||||
|
||||
# Observability stamp is present.
|
||||
record = final_ai.additional_kwargs.get("safety_termination")
|
||||
assert record is not None
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 1
|
||||
assert record["suppressed_tool_call_names"] == ["write_file"]
|
||||
|
||||
# User-facing explanation is appended.
|
||||
assert "safety-related signal" in final_ai.content
|
||||
# Original partial text preserved (we don't throw away what the user
|
||||
# already saw in the stream — see middleware docstring).
|
||||
assert "Weekly Politics" in final_ai.content
|
||||
|
||||
# finish_reason on response_metadata is preserved (so SSE / converters
|
||||
# downstream still see the real provider reason).
|
||||
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
|
||||
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through_unchanged():
|
||||
"""No tool calls => issue scope says don't intervene; the partial
|
||||
response should be delivered as-is so the user sees what they got."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _NoToolModel(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-no-tool"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
msg = AIMessage(
|
||||
content="Partial answer truncated by safety filter",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_NoToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
|
||||
# Content untouched.
|
||||
assert final_ai.content == "Partial answer truncated by safety filter"
|
||||
# No safety_termination stamp because we didn't intervene.
|
||||
assert "safety_termination" not in final_ai.additional_kwargs
|
||||
# tool node never ran (there were no tool calls in the first place).
|
||||
assert _TOOL_INVOCATIONS == []
|
||||
|
||||
|
||||
def test_normal_tool_call_round_trip_is_not_affected():
|
||||
"""Regression: a healthy finish_reason='tool_calls' response must still
|
||||
execute the tool. The middleware must not over-fire."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _HealthyToolModel(BaseChatModel):
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-healthy"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_ok",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/tmp/ok", "content": "complete content"},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_HealthyToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
agent.invoke({"messages": [HumanMessage(content="write")]})
|
||||
|
||||
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Unit tests for SafetyFinishReasonMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
)
|
||||
from deerflow.config.safety_finish_reason_config import (
|
||||
SafetyDetectorConfig,
|
||||
SafetyFinishReasonConfig,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id="t-1"):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
return runtime
|
||||
|
||||
|
||||
def _ai(
|
||||
*,
|
||||
content="",
|
||||
tool_calls=None,
|
||||
response_metadata=None,
|
||||
additional_kwargs=None,
|
||||
):
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
def _write_call(idx=1, content_text="半截"):
|
||||
return {
|
||||
"id": f"call_write_{idx}",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
|
||||
}
|
||||
|
||||
|
||||
class AlwaysHitDetector:
|
||||
"""Test fixture: always reports the given termination."""
|
||||
|
||||
name = "always_hit"
|
||||
|
||||
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
|
||||
self.reason_field = reason_field
|
||||
self.reason_value = reason_value
|
||||
self.extras = extras or {}
|
||||
|
||||
def detect(self, message):
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field=self.reason_field,
|
||||
reason_value=self.reason_value,
|
||||
extras=self.extras,
|
||||
)
|
||||
|
||||
|
||||
class NeverHitDetector:
|
||||
name = "never_hit"
|
||||
|
||||
def detect(self, message):
|
||||
return None
|
||||
|
||||
|
||||
class RaisingDetector:
|
||||
name = "raising"
|
||||
|
||||
def detect(self, message):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core trigger behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriggerCriteria:
|
||||
def test_content_filter_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through(self):
|
||||
"""issue scope: when there are no tool calls the partial text is a
|
||||
legitimate final response and should not be rewritten."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial response",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_tool_calls_pass_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_stop_with_tool_calls_pass_through(self):
|
||||
# Some providers report finish_reason='stop' for tool-call messages.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "stop"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_empty_message_list_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
assert mw._apply({"messages": []}, _runtime()) is None
|
||||
|
||||
def test_non_ai_last_message_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_anthropic_refusal_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"stop_reason": "refusal"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_gemini_safety_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "SAFETY"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message rewriting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageRewrite:
|
||||
def test_clears_structured_tool_calls(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_clears_raw_additional_kwargs_tool_calls(self):
|
||||
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
|
||||
tool calls from additional_kwargs.tool_calls if we forget them, which
|
||||
would re-emit a synthetic ToolMessage downstream and confuse the
|
||||
model. We must wipe both."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
raw_tool_calls = [
|
||||
{
|
||||
"id": "call_write_1",
|
||||
"type": "function",
|
||||
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
|
||||
}
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"tool_calls": raw_tool_calls,
|
||||
"function_call": {"name": "write_file", "arguments": "{}"},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert "tool_calls" not in patched.additional_kwargs
|
||||
assert "function_call" not in patched.additional_kwargs
|
||||
|
||||
def test_preserves_other_additional_kwargs(self):
|
||||
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
|
||||
# may carry other provider-specific keys. They must not be wiped.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"reasoning": "thinking text",
|
||||
"custom_provider_field": {"x": 1},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["reasoning"] == "thinking text"
|
||||
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
|
||||
|
||||
def test_writes_observability_field(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
record = patched.additional_kwargs["safety_termination"]
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 2
|
||||
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
|
||||
|
||||
def test_preserves_response_metadata_finish_reason(self):
|
||||
"""Downstream SSE converters read response_metadata.finish_reason —
|
||||
we want them to see the *real* provider reason, not 'stop'."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.response_metadata["finish_reason"] == "content_filter"
|
||||
assert patched.response_metadata["model_name"] == "kimi-k2"
|
||||
|
||||
def test_appends_user_facing_explanation_to_str_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="some partial text",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert patched.content.startswith("some partial text")
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_empty_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_list_content_thinking_blocks(self):
|
||||
"""Anthropic thinking / vLLM reasoning models emit content blocks.
|
||||
Naively concatenating a string would raise TypeError."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
thinking_blocks = [
|
||||
{"type": "thinking", "text": "let me consider..."},
|
||||
{"type": "text", "text": "partial answer"},
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content=thinking_blocks,
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, list)
|
||||
assert patched.content[:2] == thinking_blocks
|
||||
assert patched.content[-1]["type"] == "text"
|
||||
assert "safety-related signal" in patched.content[-1]["text"]
|
||||
|
||||
def test_idempotent_on_already_cleared_message(self):
|
||||
# Re-running the middleware on a message we already cleared must not
|
||||
# re-trigger (tool_calls is now empty → fast passthrough).
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
first = mw._apply(state, _runtime())
|
||||
state2 = {"messages": [first["messages"][0]]}
|
||||
second = mw._apply(state2, _runtime())
|
||||
assert second is None
|
||||
|
||||
def test_preserves_message_id_for_add_messages_replacement(self):
|
||||
"""LangGraph's add_messages reducer treats same-id messages as
|
||||
replacements. model_copy keeps id by default."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
original = _ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
# AIMessage auto-generates id; capture it
|
||||
original_id = original.id
|
||||
state = {"messages": [original]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.id == original_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detector wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectorWiring:
|
||||
def test_iterates_detectors_in_order(self):
|
||||
first = AlwaysHitDetector(reason_value="first")
|
||||
second = AlwaysHitDetector(reason_value="second")
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
|
||||
|
||||
def test_returns_none_when_no_detector_matches(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_buggy_detector_does_not_break_run(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
|
||||
|
||||
def test_constructor_copies_detectors(self):
|
||||
"""Caller mutation after construction must not leak into us."""
|
||||
detectors = [AlwaysHitDetector()]
|
||||
mw = SafetyFinishReasonMiddleware(detectors=detectors)
|
||||
detectors.clear()
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# from_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFromConfig:
|
||||
def test_default_config_uses_builtin_detectors(self):
|
||||
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
|
||||
assert len(mw._detectors) == 3
|
||||
names = {d.name for d in mw._detectors}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_custom_detectors_loaded_via_reflection(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[
|
||||
SafetyDetectorConfig(
|
||||
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
|
||||
config={"finish_reasons": ["custom_filter"]},
|
||||
),
|
||||
]
|
||||
)
|
||||
mw = SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
assert len(mw._detectors) == 1
|
||||
# Confirm the kwargs propagated.
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "custom_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
# Default token no longer matches.
|
||||
state2 = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state2, _runtime()) is None
|
||||
|
||||
def test_empty_detector_list_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(detectors=[])
|
||||
with pytest.raises(ValueError, match="enabled=false"):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
def test_non_detector_class_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[SafetyDetectorConfig(use="builtins:dict")],
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuditEvent:
|
||||
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
|
||||
audit event via RunJournal.record_middleware when the run-scoped journal is
|
||||
exposed under runtime.context["__run_journal"].
|
||||
|
||||
Background: review on PR #3035 — SSE custom event handles live consumers,
|
||||
but post-run audit needs a row in run_events that can be queried with one
|
||||
SQL statement (no JOIN against message body).
|
||||
"""
|
||||
|
||||
def _runtime_with_journal(self, journal):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
|
||||
return runtime
|
||||
|
||||
def test_records_audit_event_when_journal_present(self):
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
tc = _write_call(1)
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
|
||||
journal.record_middleware.assert_called_once()
|
||||
call = journal.record_middleware.call_args
|
||||
# tag is positional or kwarg depending on call style; we use kwargs.
|
||||
assert call.kwargs["tag"] == "safety_termination"
|
||||
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
|
||||
assert call.kwargs["hook"] == "after_model"
|
||||
assert call.kwargs["action"] == "suppress_tool_calls"
|
||||
|
||||
changes = call.kwargs["changes"]
|
||||
assert changes["detector"] == "openai_compatible_content_filter"
|
||||
assert changes["reason_field"] == "finish_reason"
|
||||
assert changes["reason_value"] == "content_filter"
|
||||
assert changes["suppressed_tool_call_count"] == 1
|
||||
assert changes["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
|
||||
assert "message_id" in changes
|
||||
assert isinstance(changes["extras"], dict)
|
||||
|
||||
def test_audit_event_never_carries_tool_arguments(self):
|
||||
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
|
||||
and must NOT be persisted to run_events under any circumstance."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
sensitive_tc = {
|
||||
"id": "call_x",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
|
||||
}
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[sensitive_tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, self._runtime_with_journal(journal))
|
||||
flat = repr(journal.record_middleware.call_args)
|
||||
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
|
||||
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
|
||||
|
||||
def test_no_journal_in_runtime_context_is_silently_skipped(self):
|
||||
"""Subagent runtime / unit tests / no-event-store paths have no journal.
|
||||
Middleware must still intervene and clear tool_calls — only the audit
|
||||
event is skipped."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-noj"} # no __run_journal
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise; should still clear tool_calls.
|
||||
result = mw._apply(state, runtime)
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_journal_record_exception_does_not_break_run(self):
|
||||
"""Buggy journal must never propagate an exception into the agent loop."""
|
||||
journal = MagicMock()
|
||||
journal.record_middleware.side_effect = RuntimeError("db down")
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Must not raise.
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_no_record_when_passthrough(self):
|
||||
"""When the middleware does NOT intervene, no audit event is written."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"}, # healthy
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, self._runtime_with_journal(journal)) is None
|
||||
journal.record_middleware.assert_not_called()
|
||||
|
||||
|
||||
class TestStreamEvent:
|
||||
def test_emits_event_when_writer_available(self, monkeypatch):
|
||||
captured: list = []
|
||||
|
||||
def fake_writer(payload):
|
||||
captured.append(payload)
|
||||
|
||||
# Patch get_stream_writer at the symbol-resolution site.
|
||||
import langgraph.config
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, _runtime("t-stream"))
|
||||
|
||||
assert len(captured) == 1
|
||||
payload = captured[0]
|
||||
assert payload["type"] == "safety_termination"
|
||||
assert payload["detector"] == "openai_compatible_content_filter"
|
||||
assert payload["reason_field"] == "finish_reason"
|
||||
assert payload["reason_value"] == "content_filter"
|
||||
assert payload["suppressed_tool_call_count"] == 1
|
||||
assert payload["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert payload["thread_id"] == "t-stream"
|
||||
|
||||
def test_writer_unavailable_does_not_break(self, monkeypatch):
|
||||
import langgraph.config
|
||||
|
||||
def boom():
|
||||
raise LookupError("not in a stream context")
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise.
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Unit tests for SafetyTerminationDetector built-ins."""
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
AnthropicRefusalDetector,
|
||||
GeminiSafetyDetector,
|
||||
OpenAICompatibleContentFilterDetector,
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
|
||||
|
||||
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAICompatibleContentFilterDetector:
|
||||
def test_default_matches_content_filter(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
assert hit.detector == "openai_compatible_content_filter"
|
||||
assert hit.reason_field == "finish_reason"
|
||||
assert hit.reason_value == "content_filter"
|
||||
|
||||
def test_case_insensitive_match(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
|
||||
|
||||
def test_other_finish_reasons_pass_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
|
||||
|
||||
def test_missing_metadata_passes_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai()) is None
|
||||
|
||||
def test_non_string_finish_reason_passes_through(self):
|
||||
# Some adapters may stash an enum or dict — must not raise.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
|
||||
|
||||
def test_falls_back_to_additional_kwargs(self):
|
||||
# Legacy adapters surface finish_reason via additional_kwargs.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
|
||||
def test_configurable_extra_values(self):
|
||||
# Chinese providers sometimes use bespoke tokens.
|
||||
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
|
||||
# Original token still matches.
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
|
||||
|
||||
def test_carries_azure_content_filter_results(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
filter_results = {"hate": {"filtered": True, "severity": "high"}}
|
||||
hit = d.detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "content_filter",
|
||||
"content_filter_results": filter_results,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["content_filter_results"] == filter_results
|
||||
|
||||
|
||||
class TestAnthropicRefusalDetector:
|
||||
def test_default_matches_refusal(self):
|
||||
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
|
||||
assert hit is not None
|
||||
assert hit.reason_field == "stop_reason"
|
||||
assert hit.reason_value == "refusal"
|
||||
|
||||
def test_other_stop_reasons_pass_through(self):
|
||||
d = AnthropicRefusalDetector()
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
|
||||
|
||||
def test_anthropic_does_not_steal_finish_reason(self):
|
||||
# An OpenAI message must not accidentally trip the Anthropic detector.
|
||||
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
|
||||
|
||||
|
||||
class TestGeminiSafetyDetector:
|
||||
def test_default_set_covers_documented_reasons(self):
|
||||
d = GeminiSafetyDetector()
|
||||
for reason in (
|
||||
# text safety
|
||||
"SAFETY",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"RECITATION",
|
||||
# image safety
|
||||
"IMAGE_SAFETY",
|
||||
"IMAGE_PROHIBITED_CONTENT",
|
||||
"IMAGE_RECITATION",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
|
||||
|
||||
def test_normal_termination_passes_through(self):
|
||||
d = GeminiSafetyDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
|
||||
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
|
||||
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
|
||||
# excluded from the default set — they are either normal termination,
|
||||
# capability mismatches, too broad (OTHER), or tool-call protocol
|
||||
# errors. See GeminiSafetyDetector docstring.
|
||||
for reason in (
|
||||
"MAX_TOKENS",
|
||||
"LANGUAGE",
|
||||
"NO_IMAGE",
|
||||
"OTHER",
|
||||
"IMAGE_OTHER",
|
||||
"MALFORMED_FUNCTION_CALL",
|
||||
"UNEXPECTED_TOOL_CALL",
|
||||
"FINISH_REASON_UNSPECIFIED",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
|
||||
|
||||
def test_carries_safety_ratings(self):
|
||||
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
|
||||
hit = GeminiSafetyDetector().detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "SAFETY",
|
||||
"safety_ratings": ratings,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["safety_ratings"] == ratings
|
||||
|
||||
|
||||
class TestDefaultDetectorSet:
|
||||
def test_default_set_returns_three_detectors(self):
|
||||
dets = default_detectors()
|
||||
names = {d.name for d in dets}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_default_set_returns_fresh_list(self):
|
||||
# Caller mutation must not affect later calls.
|
||||
first = default_detectors()
|
||||
first.clear()
|
||||
second = default_detectors()
|
||||
assert len(second) == 3
|
||||
|
||||
|
||||
class TestProtocolConformance:
|
||||
def test_builtins_satisfy_protocol(self):
|
||||
for d in default_detectors():
|
||||
assert isinstance(d, SafetyTerminationDetector)
|
||||
|
||||
def test_safety_termination_is_frozen(self):
|
||||
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
|
||||
try:
|
||||
t.detector = "y" # type: ignore[misc]
|
||||
except Exception:
|
||||
return
|
||||
raise AssertionError("SafetyTermination should be frozen")
|
||||
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(middlewares) == 6
|
||||
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
|
||||
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
|
||||
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
|
||||
# (enabled by default — see SafetyFinishReasonConfig).
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
assert len(middlewares) == 7
|
||||
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
|
||||
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_on_success():
|
||||
|
||||
+36
-1
@@ -15,7 +15,7 @@
|
||||
# ============================================================================
|
||||
# Bump this number when the config schema changes.
|
||||
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
||||
config_version: 9
|
||||
config_version: 10
|
||||
|
||||
# ============================================================================
|
||||
# Logging
|
||||
@@ -535,6 +535,41 @@ loop_detection:
|
||||
# warn: 150
|
||||
# hard_limit: 300
|
||||
|
||||
# ============================================================================
|
||||
# Provider Safety Termination Configuration
|
||||
# ============================================================================
|
||||
# Intercept AIMessages where the provider stopped generation for safety reasons
|
||||
# (e.g. OpenAI finish_reason='content_filter', Anthropic stop_reason='refusal',
|
||||
# Gemini finish_reason='SAFETY') while still returning tool_calls. The
|
||||
# tool_calls in such responses are typically truncated/unreliable and must
|
||||
# not be executed. See issue #3028 for the full failure mode.
|
||||
#
|
||||
# Detectors are loaded by class path via reflection (same pattern as
|
||||
# guardrails / models / tools). The built-in set covers OpenAI-compatible
|
||||
# content_filter, Anthropic refusal, and Gemini SAFETY/BLOCKLIST/
|
||||
# PROHIBITED_CONTENT/SPII/RECITATION.
|
||||
|
||||
safety_finish_reason:
|
||||
enabled: true
|
||||
# Leave `detectors` unset to use the built-in detector set. Set to a
|
||||
# non-empty list to fully override (use `enabled: false` to disable instead
|
||||
# of providing an empty list).
|
||||
#
|
||||
# Example — extend the OpenAI-compatible detector for a Chinese provider
|
||||
# whose gateway uses a non-standard finish_reason token:
|
||||
# detectors:
|
||||
# - use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector
|
||||
# config:
|
||||
# finish_reasons: ["content_filter", "sensitive", "risk_control"]
|
||||
# - use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector
|
||||
# - use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector
|
||||
#
|
||||
# Example — add a custom detector for an in-house provider:
|
||||
# detectors:
|
||||
# - use: my_company.deerflow_ext:WenxinSafetyDetector
|
||||
# config:
|
||||
# error_codes: [336003, 17, 18]
|
||||
|
||||
# ============================================================================
|
||||
# Sandbox Configuration
|
||||
# ============================================================================
|
||||
|
||||
Reference in New Issue
Block a user