From be0eae9825619b63ca0c67b253d40b5eee76a2d6 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Fri, 22 May 2026 21:20:28 +0800 Subject: [PATCH] fix(runtime): suppress tool execution when provider safety-terminates with tool_calls (#3035) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(runtime): suppress tool execution when provider safety-terminates with tool_calls When a provider stops generation for safety reasons (OpenAI/Moonshot finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/ IMAGE_SAFETY/...), the response may still carry truncated tool_calls. LangChain's tool router treats any non-empty tool_calls as executable, so partial arguments (e.g. write_file with a half-finished markdown) get dispatched and the agent loops on retry. Add SafetyFinishReasonMiddleware at after_model: detect safety termination via a pluggable detector registry, clear both structured tool_calls and raw additional_kwargs.tool_calls / function_call, preserve response_metadata.finish_reason for downstream observers, stamp additional_kwargs.safety_termination for traces, append a user-facing explanation to message content (list-aware for thinking blocks), and emit a safety_termination custom stream event so SSE consumers can reconcile any "tool starting..." UI. Default detectors cover OpenAI-compatible content_filter, Anthropic refusal, and Gemini safety enums (text + image). Custom providers are added via reflection (same pattern as guardrails). Wired into both lead-agent and subagent runtimes. Closes #3028 * fix(runtime): persist safety_termination as a middleware audit event Address review on #3035: the SSE custom event is great for live consumers but invisible to post-run audit. RunEventStore should carry its own row so operators can answer "which runs were safety-suppressed today?" 
from a single SQL query without joining the message body. Worker now exposes the run-scoped RunJournal via runtime.context["__run_journal"] (sentinel key, internal channel). SafetyFinishReasonMiddleware calls the previously unused RunJournal.record_middleware, which emits event_type = "middleware:safety_termination", category = "middleware", content = {name, hook, action, changes={ detector, reason_field, reason_value, suppressed_tool_call_count, suppressed_tool_call_names, suppressed_tool_call_ids, message_id, extras}}. Tool *arguments* are deliberately excluded — those are the very content the provider filtered and persisting them would defeat the purpose of the safety filter (per review note in #3035). Gracefully skips when journal is absent (subagent runtime, unit tests, no-event-store local dev). Journal exceptions never propagate into the agent loop. Refs #3028 * fix(runtime): satisfy ruff format + address Copilot review - ruff format on safety_finish_reason_config.py and e2e demo (CI lint failed on ruff format --check; backend Makefile lint target runs ruff check AND ruff format --check). - Docstring on SafetyFinishReasonConfig now says resolve_variable to match the actual loader used in from_config (the wording was resolve_class previously; behavior is unchanged — resolve_variable mirrors how guardrails.provider is loaded). - Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply from getattr(last, "type") == "ai" to isinstance(last, AIMessage), matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware / SummarizationMiddleware, which are the dominant pattern. 
Refs #3028 --- .../deerflow/agents/lead_agent/agent.py | 10 + .../safety_finish_reason_middleware.py | 317 +++++++++ .../safety_termination_detectors.py | 237 +++++++ .../tool_error_handling_middleware.py | 10 + .../harness/deerflow/config/app_config.py | 2 + .../config/safety_finish_reason_config.py | 47 ++ .../harness/deerflow/runtime/runs/worker.py | 6 + .../scripts/e2e_safety_termination_demo.py | 206 ++++++ .../tests/test_lead_agent_model_resolution.py | 7 +- ..._safety_finish_reason_graph_integration.py | 225 ++++++ .../test_safety_finish_reason_middleware.py | 651 ++++++++++++++++++ .../test_safety_termination_detectors.py | 176 +++++ .../test_tool_error_handling_middleware.py | 10 +- config.example.yaml | 37 +- 14 files changed, 1936 insertions(+), 5 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py create mode 100644 backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py create mode 100644 backend/packages/harness/deerflow/config/safety_finish_reason_config.py create mode 100644 backend/scripts/e2e_safety_termination_demo.py create mode 100644 backend/tests/test_safety_finish_reason_graph_integration.py create mode 100644 backend/tests/test_safety_finish_reason_middleware.py create mode 100644 backend/tests/test_safety_termination_detectors.py diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index 328a8a6e1..e03ff33ad 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.memory_middleware 
import MemoryMiddleware +from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware @@ -338,6 +339,15 @@ def _build_middlewares( if custom_middlewares: middlewares.extend(custom_middlewares) + # SafetyFinishReasonMiddleware — suppress tool execution when the provider + # safety-terminated the response. Registered after custom middlewares so + # that LangChain's reverse-order after_model dispatch runs Safety first; + # cleared tool_calls then flow through Loop/Subagent accounting without + # firing extra alarms. See safety_finish_reason_middleware.py docstring. + safety_config = resolved_app_config.safety_finish_reason + if safety_config.enabled: + middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config)) + # ClarificationMiddleware should always be last middlewares.append(ClarificationMiddleware()) return middlewares diff --git a/backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py new file mode 100644 index 000000000..8fd733c23 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py @@ -0,0 +1,317 @@ +"""Suppress tool execution when the provider safety-terminated the response. + +Background — see issue bytedance/deer-flow#3028. + +Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic +``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop +generation mid-stream while still returning partially-formed ``tool_calls``. 
+LangChain's tool router treats any AIMessage with a non-empty ``tool_calls`` +field as "go execute these", so half-truncated arguments — e.g. a markdown +``write_file`` that stops in the middle of a sentence — get dispatched as if +they were complete. The agent then sees the truncated file, tries to fix it, +gets filtered again, and loops. + +This middleware sits at ``after_model`` and gates that behaviour: when a +configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries +tool calls, we strip the tool calls (both structured and raw provider +payloads), append a user-facing explanation, and stash observability fields +in ``additional_kwargs.safety_termination`` so logs, traces, and SSE +consumers can see what happened. + +Hook choice: ``after_model`` (not ``wrap_model_call``) because the response +is a *normal* return — not an exception — and we want to participate in the +same after-model chain as ``LoopDetectionMiddleware``, with which we share +the same tool-call-suppression mechanic but a different trigger. + +Placement: register *after* ``LoopDetectionMiddleware`` in the middleware +list. LangChain factory wires ``after_model`` edges in reverse list order +(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``, +then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is +the *first* to observe the model output. Registering Safety after Loop +means Safety sees the raw response first, clears tool calls if it fires, +and Loop then accounts against the cleaned message. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, override + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import AIMessage +from langgraph.runtime import Runtime + +from deerflow.agents.middlewares.safety_termination_detectors import ( + SafetyTermination, + SafetyTerminationDetector, + default_detectors, +) +from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls + +if TYPE_CHECKING: + from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig + +logger = logging.getLogger(__name__) + + +_USER_FACING_MESSAGE = ( + "The model provider stopped this response with a safety-related signal " + "({reason_field}={reason_value!r}, detector={detector!r}). Any tool " + "calls produced in this turn were suppressed because their arguments " + "may be truncated and unsafe to execute. Please rephrase the request " + "or ask for a narrower output." +) + + +class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]): + """Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector.""" + + def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None: + super().__init__() + # Copy so caller mutations after construction don't leak into us. + self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors() + + @classmethod + def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware: + """Construct from validated Pydantic config, honouring the + reflection-loaded detector list when provided. + + An explicit empty list is intentionally rejected — it would silently + disable detection while leaving the middleware in the chain, which + is the worst of both worlds. Use ``enabled: false`` instead. 
+ """ + if config.detectors is None: + return cls() + + if not config.detectors: + raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.") + + from deerflow.reflection import resolve_variable + + detectors: list[SafetyTerminationDetector] = [] + for entry in config.detectors: + detector_cls = resolve_variable(entry.use) + kwargs = dict(entry.config) if entry.config else {} + detector = detector_cls(**kwargs) + if not isinstance(detector, SafetyTerminationDetector): + raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method") + detectors.append(detector) + return cls(detectors=detectors) + + # ----- detection ------------------------------------------------------- + + def _detect(self, message: AIMessage) -> SafetyTermination | None: + for detector in self._detectors: + try: + hit = detector.detect(message) + except Exception: # noqa: BLE001 - never let a buggy detector break the agent run + logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__)) + continue + if hit is not None: + return hit + return None + + # ----- message rewriting ---------------------------------------------- + + @staticmethod + def _append_user_message(content: object, text: str) -> str | list: + """Append a plain-text explanation to AIMessage content. + + Mirrors ``LoopDetectionMiddleware._append_text`` so list-content + responses (Anthropic thinking blocks, vLLM reasoning splits) keep + their structure instead of being string-coerced into a TypeError. 
+ """ + if content is None or content == "": + return text + if isinstance(content, list): + return [*content, {"type": "text", "text": f"\n\n{text}"}] + if isinstance(content, str): + return content + f"\n\n{text}" + return str(content) + f"\n\n{text}" + + def _build_suppressed_message( + self, + message: AIMessage, + termination: SafetyTermination, + ) -> AIMessage: + suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])] + explanation = _USER_FACING_MESSAGE.format( + reason_field=termination.reason_field, + reason_value=termination.reason_value, + detector=termination.detector, + ) + new_content = self._append_user_message(message.content, explanation) + + # clone_ai_message_with_tool_calls handles structured tool_calls, + # raw additional_kwargs.tool_calls, and function_call in one shot. + # It only rewrites finish_reason when the old value was "tool_calls", + # which is not our case — content_filter / refusal / SAFETY stay put + # so downstream SSE / converters keep seeing the real provider reason. + cleared = clone_ai_message_with_tool_calls(message, [], content=new_content) + + # Re-clone additional_kwargs so we don't accidentally mutate the + # dict returned by clone_ai_message_with_tool_calls (which already + # made a shallow copy, but downstream model_copy still references + # it). Then stamp the observability record. 
+ kwargs = dict(getattr(cleared, "additional_kwargs", None) or {}) + kwargs["safety_termination"] = { + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_count": len(suppressed_names), + "suppressed_tool_call_names": suppressed_names, + "extras": dict(termination.extras) if termination.extras else {}, + } + return cleared.model_copy(update={"additional_kwargs": kwargs}) + + # ----- observability --------------------------------------------------- + + def _emit_event( + self, + termination: SafetyTermination, + suppressed_names: list[str], + runtime: Runtime, + ) -> None: + """Notify SSE consumers (e.g. the web UI) that a tool turn was + suppressed so they can reconcile any "tool starting..." placeholders + already streamed to the user. Failures are logged at debug and + ignored — this is a best-effort signal.""" + try: + from langgraph.config import get_stream_writer + + writer = get_stream_writer() + except Exception: # noqa: BLE001 + logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True) + return + + thread_id = None + if runtime is not None and getattr(runtime, "context", None): + thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None + + try: + writer( + { + "type": "safety_termination", + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_count": len(suppressed_names), + "suppressed_tool_call_names": suppressed_names, + "thread_id": thread_id, + } + ) + except Exception: # noqa: BLE001 + logger.debug("Failed to emit safety_termination stream event", exc_info=True) + + def _record_audit_event( + self, + termination: SafetyTermination, + message, + tool_calls: list[dict], + runtime: Runtime, + ) -> None: + """Write a ``middleware:safety_termination`` record to RunEventStore + for post-run 
auditability. + + The custom stream event in ``_emit_event`` is consumed by live SSE + clients and disappears after the run; this event is persisted so an + operator can answer "which runs were safety-suppressed today?" from + a single SQL query without joining the message body. Worker exposes + the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``; + absent in unit-test / subagent / no-event-store paths, in which case + we silently skip. + + Tool **arguments** are deliberately **not** recorded — those are the + very content the provider filtered; persisting them would defeat the + purpose of the safety filter. Names / count / ids are sufficient for + audit and debugging (issue #3028 review). + """ + journal = None + if runtime is not None and getattr(runtime, "context", None): + context = runtime.context + if isinstance(context, dict): + journal = context.get("__run_journal") + if journal is None: + return + + suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls] + suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")] + + changes = { + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_count": len(tool_calls), + "suppressed_tool_call_names": suppressed_names, + "suppressed_tool_call_ids": suppressed_ids, + "message_id": getattr(message, "id", None), + "extras": dict(termination.extras) if termination.extras else {}, + } + + try: + journal.record_middleware( + tag="safety_termination", + name=type(self).__name__, + hook="after_model", + action="suppress_tool_calls", + changes=changes, + ) + except Exception: # noqa: BLE001 + # Audit-event persistence must never break agent execution. 
+ logger.debug("Failed to record middleware:safety_termination event", exc_info=True) + + # ----- main apply ------------------------------------------------------ + + def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: + messages = state.get("messages", []) + if not messages: + return None + + last = messages[-1] + if not isinstance(last, AIMessage): + return None + + # Issue scope: only intervene when there's something to suppress. + # ``content_filter`` without tool_calls is allowed through unchanged + # so the partial text response (if any) reaches the user naturally. + tool_calls = last.tool_calls + if not tool_calls: + return None + + termination = self._detect(last) + if termination is None: + return None + + patched = self._build_suppressed_message(last, termination) + + thread_id = None + if runtime is not None and getattr(runtime, "context", None): + thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None + + logger.warning( + "Provider safety termination detected — suppressed %d tool call(s)", + len(tool_calls), + extra={ + "thread_id": thread_id, + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_names": [tc.get("name") for tc in tool_calls], + }, + ) + + self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime) + self._record_audit_event(termination, last, list(tool_calls), runtime) + + return {"messages": [patched]} + + # ----- hooks ----------------------------------------------------------- + + @override + def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return self._apply(state, runtime) + + @override + async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return self._apply(state, runtime) diff --git a/backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py 
b/backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py new file mode 100644 index 000000000..b98e9f4d7 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py @@ -0,0 +1,237 @@ +"""Detectors for provider-side safety termination signals. + +Different LLM providers signal "I stopped this response for safety reasons" +through different fields with different values. This module defines a small +strategy interface and three built-in detectors that cover the major +providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock +adapters, in-house gateways, ...) can be added by implementing +``SafetyTerminationDetector`` and wiring it through +``config.yaml: safety_finish_reason.detectors``. + +The middleware that consumes these detectors lives in +``safety_finish_reason_middleware.py``. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +from langchain_core.messages import AIMessage + + +@dataclass(frozen=True) +class SafetyTermination: + """A detected safety-related termination signal. + + Attributes: + detector: Name of the detector that produced this result. Used for + observability so operators can see which provider rule fired. + reason_field: The message metadata field that carried the signal + (e.g. ``finish_reason``, ``stop_reason``). + reason_value: The actual value of that field + (e.g. ``content_filter``, ``refusal``, ``SAFETY``). + extras: Provider-specific metadata that may help downstream + consumers (e.g. Azure OpenAI content_filter_results, Gemini + safety_ratings). Detectors are free to populate or skip this. 
+ """ + + detector: str + reason_field: str + reason_value: str + extras: dict[str, Any] = field(default_factory=dict) + + +@runtime_checkable +class SafetyTerminationDetector(Protocol): + """Strategy interface for provider safety termination detection.""" + + name: str + + def detect(self, message: AIMessage) -> SafetyTermination | None: + """Return a SafetyTermination if *message* indicates provider safety + termination, otherwise return ``None``. + + Implementations must be side-effect free and tolerant of missing or + oddly-typed metadata — detectors run on every model response. + """ + ... + + +def _get_metadata_value(message: AIMessage, field_name: str) -> str | None: + """Read a string-typed value from either ``response_metadata`` or + ``additional_kwargs``. + + LangChain provider adapters are inconsistent about where they stash + provider stop signals. Most modern adapters use ``response_metadata``, + but some legacy / passthrough paths still surface them via + ``additional_kwargs``. We check both, in that order, and only accept + string values — Pydantic enums or dicts are ignored so we never raise + on malformed inputs. + """ + for container_name in ("response_metadata", "additional_kwargs"): + container = getattr(message, container_name, None) or {} + if not isinstance(container, dict): + continue + value = container.get(field_name) + if isinstance(value, str) and value: + return value + return None + + +class OpenAICompatibleContentFilterDetector: + """OpenAI-compatible content_filter signal. + + Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM, + Qwen (OpenAI-compatible mode), and any other adapter that follows the + OpenAI ``finish_reason`` convention. + + Some Chinese providers ship custom OpenAI-compatible gateways that use + alternative tokens like ``sensitive`` or ``violation``. Extend the set + via the ``finish_reasons`` kwarg in config. 
+ """ + + name = "openai_compatible_content_filter" + + def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None: + configured = finish_reasons if finish_reasons is not None else ("content_filter",) + self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured) + + def detect(self, message: AIMessage) -> SafetyTermination | None: + value = _get_metadata_value(message, "finish_reason") + if value is None or value.lower() not in self._finish_reasons: + return None + + extras: dict[str, Any] = {} + # Azure OpenAI ships a structured content_filter_results block; carry it + # through so operators can see *what* was filtered without re-tracing. + response_metadata = getattr(message, "response_metadata", None) or {} + if isinstance(response_metadata, dict): + filter_results = response_metadata.get("content_filter_results") + if filter_results: + extras["content_filter_results"] = filter_results + + return SafetyTermination( + detector=self.name, + reason_field="finish_reason", + reason_value=value, + extras=extras, + ) + + +class AnthropicRefusalDetector: + """Anthropic ``stop_reason == "refusal"`` signal. + + Anthropic models surface safety refusals via a dedicated ``stop_reason`` + rather than ``finish_reason``. See: + https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals + """ + + name = "anthropic_refusal" + + def __init__(self, stop_reasons: list[str] | tuple[str, ...] 
| None = None) -> None: + configured = stop_reasons if stop_reasons is not None else ("refusal",) + self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured) + + def detect(self, message: AIMessage) -> SafetyTermination | None: + value = _get_metadata_value(message, "stop_reason") + if value is None or value.lower() not in self._stop_reasons: + return None + return SafetyTermination( + detector=self.name, + reason_field="stop_reason", + reason_value=value, + ) + + +class GeminiSafetyDetector: + """Gemini / Vertex AI safety-related finish reasons. + + Gemini uses the same ``finish_reason`` field as OpenAI but with an + enumerated upper-case taxonomy. The default set covers every Gemini + finish_reason that means "the model stopped because the content/image + tripped a safety, blocklist, recitation, or PII filter" — i.e. cases + where any tool_calls returned alongside are likely truncated/ + unreliable. Full enum: + https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason + + Intentionally **excluded** from the default set: + - ``STOP`` — normal termination. + - ``MAX_TOKENS`` — output length truncation, not safety + (same root failure mode as + content_filter, but issue #3028 + scopes it out; expose separately if + desired). + - ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to + safety; tool_calls would be absent + anyway. + - ``MALFORMED_FUNCTION_CALL`` / + ``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The + tool_calls are *also* unreliable + here, but the failure category is + distinct from safety filtering; + handle in a dedicated detector to + keep observability records honest. + - ``OTHER`` / ``IMAGE_OTHER`` / + ``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default; + opt in via ``finish_reasons=`` if + your provider abuses these. 
+ """ + + name = "gemini_safety" + + _DEFAULT_FINISH_REASONS = ( + # Text safety + "SAFETY", + "BLOCKLIST", + "PROHIBITED_CONTENT", + "SPII", + "RECITATION", + # Image safety (multimodal generation) + "IMAGE_SAFETY", + "IMAGE_PROHIBITED_CONTENT", + "IMAGE_RECITATION", + ) + + def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None: + configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS + self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured) + + def detect(self, message: AIMessage) -> SafetyTermination | None: + value = _get_metadata_value(message, "finish_reason") + if value is None or value.upper() not in self._finish_reasons: + return None + + extras: dict[str, Any] = {} + response_metadata = getattr(message, "response_metadata", None) or {} + if isinstance(response_metadata, dict): + # Gemini surfaces per-category scoring under safety_ratings. + ratings = response_metadata.get("safety_ratings") + if ratings: + extras["safety_ratings"] = ratings + + return SafetyTermination( + detector=self.name, + reason_field="finish_reason", + reason_value=value, + extras=extras, + ) + + +def default_detectors() -> list[SafetyTerminationDetector]: + """Built-in detector set used when no custom detectors are configured.""" + return [ + OpenAICompatibleContentFilterDetector(), + AnthropicRefusalDetector(), + GeminiSafetyDetector(), + ] + + +__all__ = [ + "AnthropicRefusalDetector", + "GeminiSafetyDetector", + "OpenAICompatibleContentFilterDetector", + "SafetyTermination", + "SafetyTerminationDetector", + "default_detectors", +] diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py index 4393bd360..ae3522454 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py +++ 
b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py @@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares( middlewares.append(ViewImageMiddleware()) + # Same provider safety-termination guard the lead agent uses — subagents + # are equally exposed to truncated tool_calls returned with + # finish_reason=content_filter (and friends), and the bad call would then + # propagate back to the lead agent via the task tool result. + safety_config = app_config.safety_finish_reason + if safety_config.enabled: + from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware + + middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config)) + return middlewares diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index d470d6558..931c95757 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_ from deerflow.config.model_config import ModelConfig from deerflow.config.run_events_config import RunEventsConfig from deerflow.config.runtime_paths import existing_project_file +from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.skill_evolution_config import SkillEvolutionConfig from deerflow.config.skills_config import SkillsConfig @@ -102,6 +103,7 @@ class AppConfig(BaseModel): guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware 
configuration") + safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration") model_config = ConfigDict(extra="allow") database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration") diff --git a/backend/packages/harness/deerflow/config/safety_finish_reason_config.py b/backend/packages/harness/deerflow/config/safety_finish_reason_config.py new file mode 100644 index 000000000..0e8adebc5 --- /dev/null +++ b/backend/packages/harness/deerflow/config/safety_finish_reason_config.py @@ -0,0 +1,47 @@ +"""Configuration for SafetyFinishReasonMiddleware. + +Mirrors the shape of GuardrailsConfig: detectors are loaded by class path +through ``deerflow.reflection.resolve_variable`` (same loader the +``guardrails.provider`` config uses) so users can drop in custom provider +detectors without modifying core code. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class SafetyDetectorConfig(BaseModel): + """One detector entry under ``safety_finish_reason.detectors``.""" + + use: str = Field( + description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."), + ) + config: dict = Field( + default_factory=dict, + description="Constructor kwargs passed to the detector class.", + ) + + +class SafetyFinishReasonConfig(BaseModel): + """Configuration for the SafetyFinishReasonMiddleware. + + The middleware intercepts AIMessages where the provider signaled a + safety-related termination (e.g. 
OpenAI ``finish_reason='content_filter'``) + while still returning tool calls, and suppresses those tool calls so the + half-truncated arguments never execute. + """ + + enabled: bool = Field( + default=True, + description="Master switch for the SafetyFinishReasonMiddleware.", + ) + detectors: list[SafetyDetectorConfig] | None = Field( + default=None, + description=( + "Custom detector list. Leave unset (None) to use the built-in " + "set covering OpenAI-compatible content_filter, Anthropic " + "refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/" + "RECITATION. Provide a non-null list to fully override." + ), + ) diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index aa47cd39b..694464fe3 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -219,6 +219,12 @@ async def run_agent( # manually here because we drive the graph through ``agent.astream(config=...)`` # without passing the official ``context=`` parameter. runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config) + # Expose the run-scoped journal under a sentinel key so middleware can + # write audit events (e.g. SafetyFinishReasonMiddleware recording + # suppressed tool calls). Double-underscore prefix marks it as a + # runtime-internal channel; user code must not depend on the key name. 
+ if journal is not None: + runtime_ctx["__run_journal"] = journal _install_runtime_context(config, runtime_ctx) runtime = Runtime(context=cast(Any, runtime_ctx), store=store) config.setdefault("configurable", {})["__pregel_runtime"] = runtime diff --git a/backend/scripts/e2e_safety_termination_demo.py b/backend/scripts/e2e_safety_termination_demo.py new file mode 100644 index 000000000..7fd27b23f --- /dev/null +++ b/backend/scripts/e2e_safety_termination_demo.py @@ -0,0 +1,206 @@ +"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent. + +What it proves +-------------- +- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full + 18-middleware chain, sandbox, tools, etc.). +- A model that returns ``finish_reason='content_filter'`` + ``tool_calls`` + triggers SafetyFinishReasonMiddleware. +- LangChain's tool router never invokes ``write_file`` — the truncated + arguments do **not** reach the sandbox. +- A ``safety_termination`` custom event is emitted on the stream and the + final AIMessage carries the observability stamp. + +Run from backend/ directory: + PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py +""" + +from __future__ import annotations + +import sys +from typing import Any + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, ChatResult + +# --------------------------------------------------------------------------- +# Fake provider that mimics Moonshot's content_filter behaviour +# --------------------------------------------------------------------------- + + +class _ContentFilteredFakeModel(BaseChatModel): + """First call returns finish_reason=content_filter + truncated write_file + tool_call. 
Subsequent calls return a normal stop response so the agent + can terminate (the middleware should make a second call unnecessary by + clearing tool_calls, but we keep this safety net in case loop-detection + or anything else triggers another model invocation).""" + + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-content-filtered" + + def bind_tools(self, tools, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self.call_count += 1 + if self.call_count == 1: + msg = AIMessage( + content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与", + tool_calls=[ + { + "id": "call_truncated_write", + "name": "write_file", + "args": { + "path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md", + "content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与", + }, + } + ], + response_metadata={ + "finish_reason": "content_filter", + "model_name": "kimi-k2.6", + "model_provider": "openai", + }, + ) + else: + msg = AIMessage( + content="(secondary call, should not be needed)", + response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"}, + ) + return ChatResult(generations=[ChatGeneration(message=msg)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + + +def main() -> int: + # Inject the fake model BEFORE constructing the client. Both the + # client module and the lead-agent module bind ``create_chat_model`` + # at import time via ``from deerflow.models import create_chat_model``, + # so we patch both attribute slots — the source-of-truth patch on + # ``factory.create_chat_model`` doesn't propagate back into already- + # imported names. 
+ import deerflow.agents.lead_agent.agent as lead_agent_module + import deerflow.client as client_module + + fake = _ContentFilteredFakeModel() + originals = { + "lead": lead_agent_module.create_chat_model, + "client": client_module.create_chat_model, + } + + def fake_create_chat_model(*args, **kwargs): + return fake + + lead_agent_module.create_chat_model = fake_create_chat_model + client_module.create_chat_model = fake_create_chat_model + + from deerflow.client import DeerFlowClient + + try: + client = DeerFlowClient() + + print("\n=== Streaming a turn through the real lead-agent ===") + events: list[dict[str, Any]] = [] + for event in client.stream( + "帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md", + thread_id="e2e-safety-1", + ): + events.append({"type": event.type, "data": event.data}) + + # ---- Assertions ---- + safety_event = next( + (e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"), + None, + ) + final_values = next( + (e for e in reversed(events) if e["type"] == "values"), + None, + ) + tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"] + ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")] + + print(f"\n[stats] total stream events: {len(events)}") + print(f"[stats] model call count: {fake.call_count}") + print(f"[stats] tool messages on stream: {len(tool_messages)}") + print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}") + + print("\n[event] safety_termination custom event:") + if safety_event is None: + print(" *** NOT FOUND ***") + return 1 + for k, v in safety_event["data"].items(): + print(f" {k}: {v}") + + print("\n[state] final AIMessage from last values snapshot:") + if final_values is None: + 
print(" *** no values snapshot ***") + return 1 + # `values` event carries `_serialize_message` dicts, not Message objects. + final_messages = final_values["data"].get("messages") or [] + last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None) + if last_ai is None: + print(" *** no AIMessage in final state ***") + print(f" message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}") + return 1 + + tool_calls = last_ai.get("tool_calls") or [] + additional_kwargs = last_ai.get("additional_kwargs") or {} + response_metadata = last_ai.get("response_metadata") or {} + content = last_ai.get("content") + + print(f" tool_calls (must be empty): {tool_calls}") + print(f" additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}") + content_preview = (content if isinstance(content, str) else str(content))[:200] + print(f" content[:200]: {content_preview!r}") + print(f" response_metadata.finish_reason: {response_metadata.get('finish_reason')}") + + # NOTE: `client._serialize_message` does not include `response_metadata` + # in the values-event payload (client-layer behaviour, unrelated to the + # middleware). The middleware *does* preserve finish_reason on the + # AIMessage object — see test_safety_finish_reason_middleware.py:: + # TestMessageRewrite::test_preserves_response_metadata_finish_reason. + # Here we assert on the observability stamp, which carries the same + # evidence and is in the serialized payload. 
+ stamp = additional_kwargs.get("safety_termination") or {} + failures = [] + if tool_calls: + failures.append("final AIMessage still has tool_calls — middleware did NOT clear them") + if not stamp: + failures.append("final AIMessage missing safety_termination observability stamp") + if tool_messages: + failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream") + if stamp.get("reason_value") != "content_filter": + failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'") + if safety_event is None: + failures.append("safety_termination custom event was not emitted on the stream") + + if failures: + print("\n=== FAIL ===") + for f in failures: + print(f" - {f}") + return 1 + + print("\n=== PASS ===") + print(" - tool_calls cleared on final AIMessage") + print(" - tool node never invoked (no ToolMessage on stream)") + print(" - safety_termination custom event emitted") + print(" - observability stamp written to additional_kwargs") + print(" - response_metadata.finish_reason preserved for downstream SSE") + return 0 + finally: + lead_agent_module.create_chat_model = originals["lead"] + client_module.create_chat_model = originals["client"] + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index 7ac4b97e6..a12a754c2 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch): ) assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares) - # verify the custom middleware is injected correctly - assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock) + # verify the custom middleware is injected correctly. 
+ # Chain tail order after the custom middleware is: + # ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware + # so the custom mock sits at index [-3]. + assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock) def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch): diff --git a/backend/tests/test_safety_finish_reason_graph_integration.py b/backend/tests/test_safety_finish_reason_graph_integration.py new file mode 100644 index 000000000..f26a7be90 --- /dev/null +++ b/backend/tests/test_safety_finish_reason_graph_integration.py @@ -0,0 +1,225 @@ +"""End-to-end graph integration test for SafetyFinishReasonMiddleware. + +Unit tests prove ``_apply`` does the right thing on a synthetic state. +This test goes one level up: it builds a real ``langchain.agents.create_agent`` +graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model +that returns ``finish_reason='content_filter'`` + tool_calls, and asserts: + + 1. The tool node is **not** invoked (the dangerous truncated tool call + is suppressed). + 2. The final AIMessage in graph state has ``tool_calls == []``. + 3. The observability ``safety_termination`` record is attached. + 4. The user-facing explanation is appended to the message content. + +This is the closest we can get to the issue's failure mode without a live +Moonshot key, and it proves the middleware actually gates LangChain's +tool router — rather than just rewriting state in isolation. 
+""" + +from __future__ import annotations + +from typing import Any + +from langchain.agents import create_agent +from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.tools import tool + +from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware + +_TOOL_INVOCATIONS: list[dict[str, Any]] = [] + + +@tool +def write_file(path: str, content: str) -> str: + """Pretend to write *content* to *path*. Records the call for assertion.""" + _TOOL_INVOCATIONS.append({"path": path, "content": content}) + return f"wrote {len(content)} bytes to {path}" + + +class _ContentFilteredModel(BaseChatModel): + """Fake chat model that mimics OpenAI/Moonshot's content_filter response. + + First call returns finish_reason='content_filter' + a tool_call whose + arguments are visibly truncated. Second call (if reached) returns a + normal text completion so the agent can terminate cleanly. + """ + + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-content-filtered" + + def bind_tools(self, tools, **kwargs): + # create_agent binds tools onto the model; we don't actually need + # to bind anything since responses are hard-coded, but the method + # must not raise. 
+ return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self.call_count += 1 + if self.call_count == 1: + message = AIMessage( + content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—", + tool_calls=[ + { + "id": "call_truncated_1", + "name": "write_file", + "args": { + "path": "/mnt/user-data/outputs/report.md", + "content": "# Weekly Politics\n- Meeting time: 2026-05-12—", + }, + } + ], + response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"}, + ) + else: + message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"}) + return ChatResult(generations=[ChatGeneration(message=message)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + +class _InspectMiddleware(AgentMiddleware): + """Captures the messages list at every model entry so we can assert + no synthetic tool result was injected back into the conversation.""" + + def __init__(self) -> None: + super().__init__() + self.observed: list[list[Any]] = [] + + def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse: + self.observed.append(list(request.messages)) + return handler(request) + + +def test_content_filter_with_tool_calls_does_not_invoke_tool_node(): + _TOOL_INVOCATIONS.clear() + inspector = _InspectMiddleware() + + agent = create_agent( + model=_ContentFilteredModel(), + tools=[write_file], + # Inspector first so its after_model is registered; Safety last in + # the list so it executes first under LIFO (matches production wiring). + middleware=[inspector, SafetyFinishReasonMiddleware()], + ) + + result = agent.invoke({"messages": [HumanMessage(content="write me a report")]}) + + # Critical assertion: the dangerous truncated tool call must NOT have + # been executed. This is the entire point of the middleware. 
+ assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}" + + # Final AIMessage has no tool calls left. + final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage)) + assert final_ai.tool_calls == [] + + # Observability stamp is present. + record = final_ai.additional_kwargs.get("safety_termination") + assert record is not None + assert record["detector"] == "openai_compatible_content_filter" + assert record["reason_field"] == "finish_reason" + assert record["reason_value"] == "content_filter" + assert record["suppressed_tool_call_count"] == 1 + assert record["suppressed_tool_call_names"] == ["write_file"] + + # User-facing explanation is appended. + assert "safety-related signal" in final_ai.content + # Original partial text preserved (we don't throw away what the user + # already saw in the stream — see middleware docstring). + assert "Weekly Politics" in final_ai.content + + # finish_reason on response_metadata is preserved (so SSE / converters + # downstream still see the real provider reason). 
+ assert final_ai.response_metadata.get("finish_reason") == "content_filter" + + +def test_content_filter_without_tool_calls_passes_through_unchanged(): + """No tool calls => issue scope says don't intervene; the partial + response should be delivered as-is so the user sees what they got.""" + _TOOL_INVOCATIONS.clear() + + class _NoToolModel(BaseChatModel): + @property + def _llm_type(self) -> str: + return "fake-no-tool" + + def bind_tools(self, tools, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + msg = AIMessage( + content="Partial answer truncated by safety filter", + response_metadata={"finish_reason": "content_filter"}, + ) + return ChatResult(generations=[ChatGeneration(message=msg)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + agent = create_agent( + model=_NoToolModel(), + tools=[write_file], + middleware=[SafetyFinishReasonMiddleware()], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage)) + + # Content untouched. + assert final_ai.content == "Partial answer truncated by safety filter" + # No safety_termination stamp because we didn't intervene. + assert "safety_termination" not in final_ai.additional_kwargs + # tool node never ran (there were no tool calls in the first place). + assert _TOOL_INVOCATIONS == [] + + +def test_normal_tool_call_round_trip_is_not_affected(): + """Regression: a healthy finish_reason='tool_calls' response must still + execute the tool. 
The middleware must not over-fire.""" + _TOOL_INVOCATIONS.clear() + + class _HealthyToolModel(BaseChatModel): + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-healthy" + + def bind_tools(self, tools, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self.call_count += 1 + if self.call_count == 1: + msg = AIMessage( + content="", + tool_calls=[ + { + "id": "call_ok", + "name": "write_file", + "args": {"path": "/tmp/ok", "content": "complete content"}, + } + ], + response_metadata={"finish_reason": "tool_calls"}, + ) + else: + msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"}) + return ChatResult(generations=[ChatGeneration(message=msg)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + agent = create_agent( + model=_HealthyToolModel(), + tools=[write_file], + middleware=[SafetyFinishReasonMiddleware()], + ) + agent.invoke({"messages": [HumanMessage(content="write")]}) + + assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}] diff --git a/backend/tests/test_safety_finish_reason_middleware.py b/backend/tests/test_safety_finish_reason_middleware.py new file mode 100644 index 000000000..14c6226dd --- /dev/null +++ b/backend/tests/test_safety_finish_reason_middleware.py @@ -0,0 +1,651 @@ +"""Unit tests for SafetyFinishReasonMiddleware.""" + +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware +from deerflow.agents.middlewares.safety_termination_detectors import ( + SafetyTermination, +) +from deerflow.config.safety_finish_reason_config import ( + SafetyDetectorConfig, + SafetyFinishReasonConfig, +) + + +def _runtime(thread_id="t-1"): + runtime = 
MagicMock() + runtime.context = {"thread_id": thread_id} + return runtime + + +def _ai( + *, + content="", + tool_calls=None, + response_metadata=None, + additional_kwargs=None, +): + return AIMessage( + content=content, + tool_calls=tool_calls or [], + response_metadata=response_metadata or {}, + additional_kwargs=additional_kwargs or {}, + ) + + +def _write_call(idx=1, content_text="半截"): + return { + "id": f"call_write_{idx}", + "name": "write_file", + "args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text}, + } + + +class AlwaysHitDetector: + """Test fixture: always reports the given termination.""" + + name = "always_hit" + + def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None): + self.reason_field = reason_field + self.reason_value = reason_value + self.extras = extras or {} + + def detect(self, message): + return SafetyTermination( + detector=self.name, + reason_field=self.reason_field, + reason_value=self.reason_value, + extras=self.extras, + ) + + +class NeverHitDetector: + name = "never_hit" + + def detect(self, message): + return None + + +class RaisingDetector: + name = "raising" + + def detect(self, message): + raise RuntimeError("boom") + + +# --------------------------------------------------------------------------- +# Core trigger behaviour +# --------------------------------------------------------------------------- + + +class TestTriggerCriteria: + def test_content_filter_with_tool_calls_triggers(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="partial", + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + result = mw._apply(state, _runtime()) + assert result is not None + patched = result["messages"][0] + assert patched.tool_calls == [] + + def test_content_filter_without_tool_calls_passes_through(self): + """issue scope: when there are no tool calls the partial text is a + legitimate final 
response and should not be rewritten.""" + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="partial response", + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_normal_tool_calls_pass_through(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "tool_calls"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_normal_stop_with_tool_calls_pass_through(self): + # Some providers report finish_reason='stop' for tool-call messages. + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "stop"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_empty_message_list_passes_through(self): + mw = SafetyFinishReasonMiddleware() + assert mw._apply({"messages": []}, _runtime()) is None + + def test_non_ai_last_message_passes_through(self): + mw = SafetyFinishReasonMiddleware() + state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]} + assert mw._apply(state, _runtime()) is None + + def test_anthropic_refusal_with_tool_calls_triggers(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"stop_reason": "refusal"}, + ) + ] + } + result = mw._apply(state, _runtime()) + assert result is not None + assert result["messages"][0].tool_calls == [] + + def test_gemini_safety_with_tool_calls_triggers(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "SAFETY"}, + ) + ] + } + result = mw._apply(state, _runtime()) + assert result is not None + assert result["messages"][0].tool_calls == [] + + +# --------------------------------------------------------------------------- 
+# Message rewriting +# --------------------------------------------------------------------------- + + +class TestMessageRewrite: + def test_clears_structured_tool_calls(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call(1), _write_call(2)], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + result = mw._apply(state, _runtime()) + patched = result["messages"][0] + assert patched.tool_calls == [] + + def test_clears_raw_additional_kwargs_tool_calls(self): + """Critical defence-in-depth: DanglingToolCallMiddleware will recover + tool calls from additional_kwargs.tool_calls if we forget them, which + would re-emit a synthetic ToolMessage downstream and confuse the + model. We must wipe both.""" + mw = SafetyFinishReasonMiddleware() + raw_tool_calls = [ + { + "id": "call_write_1", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path": "/x"}'}, + } + ] + state = { + "messages": [ + _ai( + tool_calls=[_write_call(1)], + response_metadata={"finish_reason": "content_filter"}, + additional_kwargs={ + "tool_calls": raw_tool_calls, + "function_call": {"name": "write_file", "arguments": "{}"}, + }, + ) + ] + } + result = mw._apply(state, _runtime()) + patched = result["messages"][0] + assert "tool_calls" not in patched.additional_kwargs + assert "function_call" not in patched.additional_kwargs + + def test_preserves_other_additional_kwargs(self): + # vLLM puts reasoning under additional_kwargs.reasoning; Anthropic + # may carry other provider-specific keys. They must not be wiped. 
+ mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + additional_kwargs={ + "reasoning": "thinking text", + "custom_provider_field": {"x": 1}, + }, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.additional_kwargs["reasoning"] == "thinking text" + assert patched.additional_kwargs["custom_provider_field"] == {"x": 1} + + def test_writes_observability_field(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call(1), _write_call(2)], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + record = patched.additional_kwargs["safety_termination"] + assert record["detector"] == "openai_compatible_content_filter" + assert record["reason_field"] == "finish_reason" + assert record["reason_value"] == "content_filter" + assert record["suppressed_tool_call_count"] == 2 + assert record["suppressed_tool_call_names"] == ["write_file", "write_file"] + + def test_preserves_response_metadata_finish_reason(self): + """Downstream SSE converters read response_metadata.finish_reason — + we want them to see the *real* provider reason, not 'stop'.""" + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.response_metadata["finish_reason"] == "content_filter" + assert patched.response_metadata["model_name"] == "kimi-k2" + + def test_appends_user_facing_explanation_to_str_content(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="some partial text", + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, 
_runtime())["messages"][0] + assert isinstance(patched.content, str) + assert patched.content.startswith("some partial text") + assert "safety-related signal" in patched.content + + def test_handles_empty_content(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="", + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert isinstance(patched.content, str) + assert "safety-related signal" in patched.content + + def test_handles_list_content_thinking_blocks(self): + """Anthropic thinking / vLLM reasoning models emit content blocks. + Naively concatenating a string would raise TypeError.""" + mw = SafetyFinishReasonMiddleware() + thinking_blocks = [ + {"type": "thinking", "text": "let me consider..."}, + {"type": "text", "text": "partial answer"}, + ] + state = { + "messages": [ + _ai( + content=thinking_blocks, + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert isinstance(patched.content, list) + assert patched.content[:2] == thinking_blocks + assert patched.content[-1]["type"] == "text" + assert "safety-related signal" in patched.content[-1]["text"] + + def test_idempotent_on_already_cleared_message(self): + # Re-running the middleware on a message we already cleared must not + # re-trigger (tool_calls is now empty → fast passthrough). + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + first = mw._apply(state, _runtime()) + state2 = {"messages": [first["messages"][0]]} + second = mw._apply(state2, _runtime()) + assert second is None + + def test_preserves_message_id_for_add_messages_replacement(self): + """LangGraph's add_messages reducer treats same-id messages as + replacements. 
model_copy keeps id by default.""" + mw = SafetyFinishReasonMiddleware() + original = _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + # AIMessage auto-generates id; capture it + original_id = original.id + state = {"messages": [original]} + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.id == original_id + + +# --------------------------------------------------------------------------- +# Detector wiring +# --------------------------------------------------------------------------- + + +class TestDetectorWiring: + def test_iterates_detectors_in_order(self): + first = AlwaysHitDetector(reason_value="first") + second = AlwaysHitDetector(reason_value="second") + mw = SafetyFinishReasonMiddleware(detectors=[first, second]) + state = {"messages": [_ai(tool_calls=[_write_call()])]} + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first" + + def test_returns_none_when_no_detector_matches(self): + mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()]) + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_buggy_detector_does_not_break_run(self): + mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()]) + state = {"messages": [_ai(tool_calls=[_write_call()])]} + result = mw._apply(state, _runtime()) + assert result is not None + assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit" + + def test_constructor_copies_detectors(self): + """Caller mutation after construction must not leak into us.""" + detectors = [AlwaysHitDetector()] + mw = SafetyFinishReasonMiddleware(detectors=detectors) + detectors.clear() + state = {"messages": [_ai(tool_calls=[_write_call()])]} + assert 
mw._apply(state, _runtime()) is not None + + +# --------------------------------------------------------------------------- +# from_config +# --------------------------------------------------------------------------- + + +class TestFromConfig: + def test_default_config_uses_builtin_detectors(self): + mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig()) + assert len(mw._detectors) == 3 + names = {d.name for d in mw._detectors} + assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"} + + def test_custom_detectors_loaded_via_reflection(self): + cfg = SafetyFinishReasonConfig( + detectors=[ + SafetyDetectorConfig( + use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector", + config={"finish_reasons": ["custom_filter"]}, + ), + ] + ) + mw = SafetyFinishReasonMiddleware.from_config(cfg) + assert len(mw._detectors) == 1 + # Confirm the kwargs propagated. + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "custom_filter"}, + ) + ] + } + assert mw._apply(state, _runtime()) is not None + # Default token no longer matches. 
+ state2 = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + assert mw._apply(state2, _runtime()) is None + + def test_empty_detector_list_rejected(self): + cfg = SafetyFinishReasonConfig(detectors=[]) + with pytest.raises(ValueError, match="enabled=false"): + SafetyFinishReasonMiddleware.from_config(cfg) + + def test_non_detector_class_rejected(self): + cfg = SafetyFinishReasonConfig( + detectors=[SafetyDetectorConfig(use="builtins:dict")], + ) + with pytest.raises(TypeError): + SafetyFinishReasonMiddleware.from_config(cfg) + + +# --------------------------------------------------------------------------- +# Audit event and stream event +# --------------------------------------------------------------------------- + + +class TestAuditEvent: + """Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination` + audit event via RunJournal.record_middleware when the run-scoped journal is + exposed under runtime.context["__run_journal"]. + + Background: review on PR #3035 — SSE custom event handles live consumers, + but post-run audit needs a row in run_events that can be queried with one + SQL statement (no JOIN against message body). + """ + + def _runtime_with_journal(self, journal): + runtime = MagicMock() + runtime.context = {"thread_id": "t-audit", "__run_journal": journal} + return runtime + + def test_records_audit_event_when_journal_present(self): + journal = MagicMock() + mw = SafetyFinishReasonMiddleware() + tc = _write_call(1) + state = { + "messages": [ + _ai( + content="partial", + tool_calls=[tc], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + result = mw._apply(state, self._runtime_with_journal(journal)) + assert result is not None + + journal.record_middleware.assert_called_once() + call = journal.record_middleware.call_args + # tag is positional or kwarg depending on call style; we use kwargs. 
+ assert call.kwargs["tag"] == "safety_termination" + assert call.kwargs["name"] == "SafetyFinishReasonMiddleware" + assert call.kwargs["hook"] == "after_model" + assert call.kwargs["action"] == "suppress_tool_calls" + + changes = call.kwargs["changes"] + assert changes["detector"] == "openai_compatible_content_filter" + assert changes["reason_field"] == "finish_reason" + assert changes["reason_value"] == "content_filter" + assert changes["suppressed_tool_call_count"] == 1 + assert changes["suppressed_tool_call_names"] == ["write_file"] + assert changes["suppressed_tool_call_ids"] == ["call_write_1"] + assert "message_id" in changes + assert isinstance(changes["extras"], dict) + + def test_audit_event_never_carries_tool_arguments(self): + """PR #3035 review IMPORTANT: tool args are the filtered content itself + and must NOT be persisted to run_events under any circumstance.""" + journal = MagicMock() + mw = SafetyFinishReasonMiddleware() + sensitive_tc = { + "id": "call_x", + "name": "write_file", + "args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"}, + } + state = { + "messages": [ + _ai( + tool_calls=[sensitive_tc], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + mw._apply(state, self._runtime_with_journal(journal)) + flat = repr(journal.record_middleware.call_args) + assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event" + assert "args" not in journal.record_middleware.call_args.kwargs["changes"] + + def test_no_journal_in_runtime_context_is_silently_skipped(self): + """Subagent runtime / unit tests / no-event-store paths have no journal. 
+ Middleware must still intervene and clear tool_calls — only the audit + event is skipped.""" + mw = SafetyFinishReasonMiddleware() + runtime = MagicMock() + runtime.context = {"thread_id": "t-noj"} # no __run_journal + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + # Should not raise; should still clear tool_calls. + result = mw._apply(state, runtime) + assert result is not None + assert result["messages"][0].tool_calls == [] + + def test_journal_record_exception_does_not_break_run(self): + """Buggy journal must never propagate an exception into the agent loop.""" + journal = MagicMock() + journal.record_middleware.side_effect = RuntimeError("db down") + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + # Must not raise. + result = mw._apply(state, self._runtime_with_journal(journal)) + assert result is not None + assert result["messages"][0].tool_calls == [] + + def test_no_record_when_passthrough(self): + """When the middleware does NOT intervene, no audit event is written.""" + journal = MagicMock() + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "tool_calls"}, # healthy + ) + ] + } + assert mw._apply(state, self._runtime_with_journal(journal)) is None + journal.record_middleware.assert_not_called() + + +class TestStreamEvent: + def test_emits_event_when_writer_available(self, monkeypatch): + captured: list = [] + + def fake_writer(payload): + captured.append(payload) + + # Patch get_stream_writer at the symbol-resolution site. 
+ import langgraph.config + + monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer) + + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + mw._apply(state, _runtime("t-stream")) + + assert len(captured) == 1 + payload = captured[0] + assert payload["type"] == "safety_termination" + assert payload["detector"] == "openai_compatible_content_filter" + assert payload["reason_field"] == "finish_reason" + assert payload["reason_value"] == "content_filter" + assert payload["suppressed_tool_call_count"] == 1 + assert payload["suppressed_tool_call_names"] == ["write_file"] + assert payload["thread_id"] == "t-stream" + + def test_writer_unavailable_does_not_break(self, monkeypatch): + import langgraph.config + + def boom(): + raise LookupError("not in a stream context") + + monkeypatch.setattr(langgraph.config, "get_stream_writer", boom) + + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + # Should not raise. 
+ result = mw._apply(state, _runtime()) + assert result is not None diff --git a/backend/tests/test_safety_termination_detectors.py b/backend/tests/test_safety_termination_detectors.py new file mode 100644 index 000000000..0679aed0e --- /dev/null +++ b/backend/tests/test_safety_termination_detectors.py @@ -0,0 +1,176 @@ +"""Unit tests for SafetyTerminationDetector built-ins.""" + +from langchain_core.messages import AIMessage + +from deerflow.agents.middlewares.safety_termination_detectors import ( + AnthropicRefusalDetector, + GeminiSafetyDetector, + OpenAICompatibleContentFilterDetector, + SafetyTermination, + SafetyTerminationDetector, + default_detectors, +) + + +def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage: + return AIMessage( + content=content, + tool_calls=tool_calls or [], + response_metadata=response_metadata or {}, + additional_kwargs=additional_kwargs or {}, + ) + + +class TestOpenAICompatibleContentFilterDetector: + def test_default_matches_content_filter(self): + d = OpenAICompatibleContentFilterDetector() + hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) + assert hit is not None + assert hit.detector == "openai_compatible_content_filter" + assert hit.reason_field == "finish_reason" + assert hit.reason_value == "content_filter" + + def test_case_insensitive_match(self): + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None + + def test_other_finish_reasons_pass_through(self): + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None + assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None + assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None + + def test_missing_metadata_passes_through(self): + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai()) is None + + def 
test_non_string_finish_reason_passes_through(self): + # Some adapters may stash an enum or dict — must not raise. + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None + assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None + + def test_falls_back_to_additional_kwargs(self): + # Legacy adapters surface finish_reason via additional_kwargs. + d = OpenAICompatibleContentFilterDetector() + hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"})) + assert hit is not None + + def test_configurable_extra_values(self): + # Chinese providers sometimes use bespoke tokens. + d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"]) + assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None + assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None + # Original token still matches. + assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None + + def test_carries_azure_content_filter_results(self): + d = OpenAICompatibleContentFilterDetector() + filter_results = {"hate": {"filtered": True, "severity": "high"}} + hit = d.detect( + _ai( + response_metadata={ + "finish_reason": "content_filter", + "content_filter_results": filter_results, + }, + ) + ) + assert hit is not None + assert hit.extras["content_filter_results"] == filter_results + + +class TestAnthropicRefusalDetector: + def test_default_matches_refusal(self): + hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"})) + assert hit is not None + assert hit.reason_field == "stop_reason" + assert hit.reason_value == "refusal" + + def test_other_stop_reasons_pass_through(self): + d = AnthropicRefusalDetector() + assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None + assert d.detect(_ai(response_metadata={"stop_reason": 
"tool_use"})) is None + assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None + + def test_anthropic_does_not_steal_finish_reason(self): + # An OpenAI message must not accidentally trip the Anthropic detector. + assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None + + +class TestGeminiSafetyDetector: + def test_default_set_covers_documented_reasons(self): + d = GeminiSafetyDetector() + for reason in ( + # text safety + "SAFETY", + "BLOCKLIST", + "PROHIBITED_CONTENT", + "SPII", + "RECITATION", + # image safety + "IMAGE_SAFETY", + "IMAGE_PROHIBITED_CONTENT", + "IMAGE_RECITATION", + ): + assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason + + def test_normal_termination_passes_through(self): + d = GeminiSafetyDetector() + assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None + # MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER / + # MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally + # excluded from the default set — they are either normal termination, + # capability mismatches, too broad (OTHER), or tool-call protocol + # errors. See GeminiSafetyDetector docstring. 
+ for reason in ( + "MAX_TOKENS", + "LANGUAGE", + "NO_IMAGE", + "OTHER", + "IMAGE_OTHER", + "MALFORMED_FUNCTION_CALL", + "UNEXPECTED_TOOL_CALL", + "FINISH_REASON_UNSPECIFIED", + ): + assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason + + def test_carries_safety_ratings(self): + ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}] + hit = GeminiSafetyDetector().detect( + _ai( + response_metadata={ + "finish_reason": "SAFETY", + "safety_ratings": ratings, + }, + ) + ) + assert hit is not None + assert hit.extras["safety_ratings"] == ratings + + +class TestDefaultDetectorSet: + def test_default_set_returns_three_detectors(self): + dets = default_detectors() + names = {d.name for d in dets} + assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"} + + def test_default_set_returns_fresh_list(self): + # Caller mutation must not affect later calls. + first = default_detectors() + first.clear() + second = default_detectors() + assert len(second) == 3 + + +class TestProtocolConformance: + def test_builtins_satisfy_protocol(self): + for d in default_detectors(): + assert isinstance(d, SafetyTerminationDetector) + + def test_safety_termination_is_frozen(self): + t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter") + try: + t.detector = "y" # type: ignore[misc] + except Exception: + return + raise AssertionError("SafetyTermination should be frozen") diff --git a/backend/tests/test_tool_error_handling_middleware.py b/backend/tests/test_tool_error_handling_middleware.py index 2c28dac35..28c59a9ad 100644 --- a/backend/tests/test_tool_error_handling_middleware.py +++ b/backend/tests/test_tool_error_handling_middleware.py @@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False) assert captured["app_config"] is 
app_config - assert len(middlewares) == 6 - assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware) + # 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling, + # SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware + # (enabled by default — see SafetyFinishReasonConfig). + from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware + + assert len(middlewares) == 7 + assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares) + assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware) def test_wrap_tool_call_passthrough_on_success(): diff --git a/config.example.yaml b/config.example.yaml index 9ea4e4c08..8e289fac9 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -15,7 +15,7 @@ # ============================================================================ # Bump this number when the config schema changes. # Run `make config-upgrade` to merge new fields into your local config.yaml. -config_version: 9 +config_version: 10 # ============================================================================ # Logging @@ -535,6 +535,41 @@ loop_detection: # warn: 150 # hard_limit: 300 +# ============================================================================ +# Provider Safety Termination Configuration +# ============================================================================ +# Intercept AIMessages where the provider stopped generation for safety reasons +# (e.g. OpenAI finish_reason='content_filter', Anthropic stop_reason='refusal', +# Gemini finish_reason='SAFETY') while still returning tool_calls. The +# tool_calls in such responses are typically truncated/unreliable and must +# not be executed. See issue #3028 for the full failure mode. +# +# Detectors are loaded by class path via reflection (same pattern as +# guardrails / models / tools). 
The built-in set covers OpenAI-compatible +# content_filter, Anthropic refusal, and Gemini SAFETY/BLOCKLIST/ +# PROHIBITED_CONTENT/SPII/RECITATION (plus IMAGE_SAFETY/IMAGE_PROHIBITED_CONTENT/IMAGE_RECITATION). + +safety_finish_reason: + enabled: true + # Leave `detectors` unset to use the built-in detector set. Set to a + # non-empty list to fully override (use `enabled: false` to disable instead + # of providing an empty list). + # + # Example — extend the OpenAI-compatible detector for a Chinese provider + # whose gateway uses a non-standard finish_reason token: + # detectors: + # - use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector + # config: + # finish_reasons: ["content_filter", "sensitive", "risk_control"] + # - use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector + # - use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector + # + # Example — add a custom detector for an in-house provider: + # detectors: + # - use: my_company.deerflow_ext:WenxinSafetyDetector + # config: + # error_codes: [336003, 17, 18] + # ============================================================================ # Sandbox Configuration # ============================================================================