deer-flow/backend/scripts/e2e_safety_termination_demo.py
Xinmin Zeng be0eae9825 fix(runtime): suppress tool execution when provider safety-terminates with tool_calls (#3035)
* fix(runtime): suppress tool execution when provider safety-terminates with tool_calls

When a provider stops generation for safety reasons (OpenAI/Moonshot
finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini
finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/
IMAGE_SAFETY/...), the response may still carry truncated tool_calls.
LangChain's tool router treats any non-empty tool_calls as executable,
so partial arguments (e.g. write_file with a half-finished markdown)
get dispatched and the agent loops on retry.

Add SafetyFinishReasonMiddleware at after_model: detect safety
termination via a pluggable detector registry, clear both structured
tool_calls and raw additional_kwargs.tool_calls / function_call,
preserve response_metadata.finish_reason for downstream observers,
stamp additional_kwargs.safety_termination for traces, append a
user-facing explanation to message content (list-aware for thinking
blocks), and emit a safety_termination custom stream event so SSE
consumers can reconcile any "tool starting..." UI.
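The rewrite step described above can be sketched roughly as follows. This is only an illustration under stated assumptions: it uses plain dicts in place of LangChain's AIMessage, and the function name `suppress_tool_calls` and the explanation text are hypothetical, not taken from the middleware's source.

```python
from copy import deepcopy

# Hypothetical user-facing note; the real wording lives in the middleware.
EXPLANATION = "[Generation was stopped by the provider's safety filter; pending tool calls were not executed.]"


def suppress_tool_calls(message: dict, reason_field: str, reason_value: str) -> dict:
    """Return a copy of `message` with tool calls cleared and a trace stamp added."""
    out = deepcopy(message)
    suppressed = out.get("tool_calls") or []

    # Clear the structured tool calls and the raw provider payloads.
    out["tool_calls"] = []
    extra = out.setdefault("additional_kwargs", {})
    extra.pop("tool_calls", None)
    extra.pop("function_call", None)

    # Stamp for traces; response_metadata is deliberately left untouched.
    extra["safety_termination"] = {
        "reason_field": reason_field,
        "reason_value": reason_value,
        "suppressed_tool_call_count": len(suppressed),
    }

    # Content may be a plain string or a list of blocks (e.g. thinking blocks).
    content = out.get("content", "")
    if isinstance(content, list):
        out["content"] = content + [{"type": "text", "text": EXPLANATION}]
    else:
        out["content"] = f"{content}\n\n{EXPLANATION}" if content else EXPLANATION
    return out


original = {
    "content": "# Weekly report\n- partial",
    "tool_calls": [{"id": "call_1", "name": "write_file", "args": {"path": "/tmp/x.md"}}],
    "additional_kwargs": {"tool_calls": [{"id": "call_1"}]},
    "response_metadata": {"finish_reason": "content_filter"},
}
cleaned = suppress_tool_calls(original, "finish_reason", "content_filter")
print(cleaned["tool_calls"], cleaned["additional_kwargs"]["safety_termination"])
```

Working on a copy keeps the input message intact, which makes the before/after state easy to compare in a test.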

Default detectors cover OpenAI-compatible content_filter, Anthropic
refusal, and Gemini safety enums (text + image). Custom providers are
added via reflection (same pattern as guardrails). Wired into both
lead-agent and subagent runtimes.
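A pluggable detector registry along these lines might look like the sketch below. The names `REGISTRY` and `detect`, the detector keys, and the tuple return shape are assumptions for illustration; the Gemini set contains only the values listed in this commit message and is not exhaustive. The real code loads custom detectors by reflection, whereas this sketch registers them directly.

```python
from typing import Callable, Optional

# A detector inspects response metadata and returns (field, value) on a match.
Detector = Callable[[dict], Optional[tuple]]

# Gemini finish reasons named in the commit message (not an exhaustive list).
GEMINI_SAFETY = {"SAFETY", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII", "RECITATION", "IMAGE_SAFETY"}


def openai_compatible(meta: dict) -> Optional[tuple]:
    if meta.get("finish_reason") == "content_filter":
        return ("finish_reason", "content_filter")
    return None


def anthropic(meta: dict) -> Optional[tuple]:
    if meta.get("stop_reason") == "refusal":
        return ("stop_reason", "refusal")
    return None


def gemini(meta: dict) -> Optional[tuple]:
    value = meta.get("finish_reason")
    if isinstance(value, str) and value.upper() in GEMINI_SAFETY:
        return ("finish_reason", value)
    return None


# Insertion order is the evaluation order; custom detectors can be added later.
REGISTRY: dict = {
    "openai_compatible": openai_compatible,
    "anthropic": anthropic,
    "gemini": gemini,
}


def detect(meta: dict) -> Optional[tuple]:
    """Return (detector_name, field, value) for the first matching detector."""
    for name, detector in REGISTRY.items():
        hit = detector(meta)
        if hit is not None:
            return (name, *hit)
    return None


print(detect({"finish_reason": "content_filter"}))
print(detect({"finish_reason": "stop"}))
```

Returning the field name alongside the value lets downstream code record which metadata key triggered the match, since providers disagree on `finish_reason` versus `stop_reason`.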

Closes #3028

* fix(runtime): persist safety_termination as a middleware audit event

Address review on #3035: the SSE custom event is great for live
consumers but invisible to post-run audit. RunEventStore should carry
its own row so operators can answer "which runs were safety-suppressed
today?" from a single SQL query without joining the message body.

Worker now exposes the run-scoped RunJournal via
runtime.context["__run_journal"] (sentinel key, internal channel).
SafetyFinishReasonMiddleware calls the previously-unused
RunJournal.record_middleware, which emits

  event_type = "middleware:safety_termination"
  category   = "middleware"
  content    = {name, hook, action, changes={
                  detector, reason_field, reason_value,
                  suppressed_tool_call_count,
                  suppressed_tool_call_names,
                  suppressed_tool_call_ids,
                  message_id, extras}}

Tool *arguments* are deliberately excluded — those are the very content
the provider filtered and persisting them would defeat the purpose of
the safety filter (per review note in #3035).
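As a rough illustration of that payload, the sketch below builds the `changes` dict from the suppressed tool calls while dropping their arguments. The helper name `build_changes` and its parameters are hypothetical; only the field names come from the event description above.

```python
import json
from typing import Optional


def build_changes(
    detector: str,
    reason_field: str,
    reason_value: str,
    tool_calls: list,
    message_id: Optional[str] = None,
    extras: Optional[dict] = None,
) -> dict:
    """Summarise suppressed tool calls without copying their arguments."""
    return {
        "detector": detector,
        "reason_field": reason_field,
        "reason_value": reason_value,
        "suppressed_tool_call_count": len(tool_calls),
        # Only names and ids are kept; "args" is never read.
        "suppressed_tool_call_names": [tc.get("name") for tc in tool_calls],
        "suppressed_tool_call_ids": [tc.get("id") for tc in tool_calls],
        "message_id": message_id,
        "extras": extras or {},
    }


calls = [{"id": "call_1", "name": "write_file", "args": {"content": "FILTERED TEXT"}}]
changes = build_changes("openai_compatible", "finish_reason", "content_filter", calls, "msg-1")
print(json.dumps(changes, indent=2))
```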

Graceful skips when journal is absent (subagent runtime, unit tests,
no-event-store local dev). Journal exceptions never propagate into the
agent loop.

Refs #3028

* fix(runtime): satisfy ruff format + address Copilot review

- ruff format on safety_finish_reason_config.py and e2e demo (CI lint
  failed on ruff format --check; backend Makefile lint target runs
  ruff check AND ruff format --check).
- Docstring on SafetyFinishReasonConfig now says resolve_variable to
  match the actual loader used in from_config (the wording was
  resolve_class previously; behavior is unchanged — resolve_variable
  mirrors how guardrails.provider is loaded).
- Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply
  from getattr(last, "type") == "ai" to isinstance(last, AIMessage),
  matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware
  / SummarizationMiddleware which are the dominant pattern.

Refs #3028
2026-05-22 21:20:28 +08:00

207 lines
9.1 KiB
Python

"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent.

What it proves
--------------
- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full
  18-middleware chain, sandbox, tools, etc.).
- A model that returns ``finish_reason='content_filter'`` + ``tool_calls``
  triggers SafetyFinishReasonMiddleware.
- LangChain's tool router never invokes ``write_file`` — the truncated
  arguments do **not** reach the sandbox.
- A ``safety_termination`` custom event is emitted on the stream and the
  final AIMessage carries the observability stamp.

Run from backend/ directory:
    PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py
"""

from __future__ import annotations

import sys
from typing import Any

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult

# ---------------------------------------------------------------------------
# Fake provider that mimics Moonshot's content_filter behaviour
# ---------------------------------------------------------------------------


class _ContentFilteredFakeModel(BaseChatModel):
    """First call returns finish_reason=content_filter + truncated write_file
    tool_call. Subsequent calls return a normal stop response so the agent
    can terminate (the middleware should make a second call unnecessary by
    clearing tool_calls, but we keep this safety net in case loop-detection
    or anything else triggers another model invocation)."""

    call_count: int = 0

    @property
    def _llm_type(self) -> str:
        return "fake-content-filtered"

    def bind_tools(self, tools, **kwargs):
        return self

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        self.call_count += 1
        if self.call_count == 1:
            msg = AIMessage(
                content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
                tool_calls=[
                    {
                        "id": "call_truncated_write",
                        "name": "write_file",
                        "args": {
                            "path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
                            "content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与",
                        },
                    }
                ],
                response_metadata={
                    "finish_reason": "content_filter",
                    "model_name": "kimi-k2.6",
                    "model_provider": "openai",
                },
            )
        else:
            msg = AIMessage(
                content="(secondary call, should not be needed)",
                response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"},
            )
        return ChatResult(generations=[ChatGeneration(message=msg)])

    async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
        return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)

# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------


def main() -> int:
    # Inject the fake model BEFORE constructing the client. Both the
    # client module and the lead-agent module bind ``create_chat_model``
    # at import time via ``from deerflow.models import create_chat_model``,
    # so we patch both attribute slots — the source-of-truth patch on
    # ``factory.create_chat_model`` doesn't propagate back into already-
    # imported names.
    import deerflow.agents.lead_agent.agent as lead_agent_module
    import deerflow.client as client_module

    fake = _ContentFilteredFakeModel()
    originals = {
        "lead": lead_agent_module.create_chat_model,
        "client": client_module.create_chat_model,
    }

    def fake_create_chat_model(*args, **kwargs):
        return fake

    lead_agent_module.create_chat_model = fake_create_chat_model
    client_module.create_chat_model = fake_create_chat_model

    from deerflow.client import DeerFlowClient

    try:
        client = DeerFlowClient()
        print("\n=== Streaming a turn through the real lead-agent ===")
        events: list[dict[str, Any]] = []
        for event in client.stream(
            "帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md",
            thread_id="e2e-safety-1",
        ):
            events.append({"type": event.type, "data": event.data})

        # ---- Assertions ----
        safety_event = next(
            (e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"),
            None,
        )
        final_values = next(
            (e for e in reversed(events) if e["type"] == "values"),
            None,
        )
        tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"]
        ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")]

        print(f"\n[stats] total stream events: {len(events)}")
        print(f"[stats] model call count: {fake.call_count}")
        print(f"[stats] tool messages on stream: {len(tool_messages)}")
        print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}")

        print("\n[event] safety_termination custom event:")
        if safety_event is None:
            print("  *** NOT FOUND ***")
            return 1
        for k, v in safety_event["data"].items():
            print(f"  {k}: {v}")

        print("\n[state] final AIMessage from last values snapshot:")
        if final_values is None:
            print("  *** no values snapshot ***")
            return 1
        # `values` event carries `_serialize_message` dicts, not Message objects.
        final_messages = final_values["data"].get("messages") or []
        last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None)
        if last_ai is None:
            print("  *** no AIMessage in final state ***")
            print(f"  message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}")
            return 1

        tool_calls = last_ai.get("tool_calls") or []
        additional_kwargs = last_ai.get("additional_kwargs") or {}
        response_metadata = last_ai.get("response_metadata") or {}
        content = last_ai.get("content")
        print(f"  tool_calls (must be empty): {tool_calls}")
        print(f"  additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}")
        content_preview = (content if isinstance(content, str) else str(content))[:200]
        print(f"  content[:200]: {content_preview!r}")
        print(f"  response_metadata.finish_reason: {response_metadata.get('finish_reason')}")

        # NOTE: `client._serialize_message` does not include `response_metadata`
        # in the values-event payload (client-layer behaviour, unrelated to the
        # middleware). The middleware *does* preserve finish_reason on the
        # AIMessage object — see test_safety_finish_reason_middleware.py::
        # TestMessageRewrite::test_preserves_response_metadata_finish_reason.
        # Here we assert on the observability stamp, which carries the same
        # evidence and is in the serialized payload.
        stamp = additional_kwargs.get("safety_termination") or {}
        failures = []
        if tool_calls:
            failures.append("final AIMessage still has tool_calls — middleware did NOT clear them")
        if not stamp:
            failures.append("final AIMessage missing safety_termination observability stamp")
        if tool_messages:
            failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream")
        if stamp.get("reason_value") != "content_filter":
            failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'")
        if safety_event is None:
            failures.append("safety_termination custom event was not emitted on the stream")

        if failures:
            print("\n=== FAIL ===")
            for f in failures:
                print(f"  - {f}")
            return 1

        print("\n=== PASS ===")
        print("  - tool_calls cleared on final AIMessage")
        print("  - tool node never invoked (no ToolMessage on stream)")
        print("  - safety_termination custom event emitted")
        print("  - observability stamp written to additional_kwargs")
        print("  - response_metadata.finish_reason preserved for downstream SSE")
        return 0
    finally:
        lead_agent_module.create_chat_model = originals["lead"]
        client_module.create_chat_model = originals["client"]


if __name__ == "__main__":
    sys.exit(main())