mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
Merge branch 'main' into rayhpeng/fix-run-manager-store-atomicity
This commit is contained in:
@@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch):
|
||||
)
|
||||
|
||||
assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares)
|
||||
# verify the custom middleware is injected correctly
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock)
|
||||
# verify the custom middleware is injected correctly.
|
||||
# Chain tail order after the custom middleware is:
|
||||
# ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware
|
||||
# so the custom mock sits at index [-3].
|
||||
assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock)
|
||||
|
||||
|
||||
def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
"""End-to-end graph integration test for SafetyFinishReasonMiddleware.
|
||||
|
||||
Unit tests prove ``_apply`` does the right thing on a synthetic state.
|
||||
This test does one level up: builds a real ``langchain.agents.create_agent``
|
||||
graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model
|
||||
that returns ``finish_reason='content_filter'`` + tool_calls, and asserts:
|
||||
|
||||
1. The tool node is **not** invoked (the dangerous truncated tool call
|
||||
is suppressed).
|
||||
2. The final AIMessage in graph state has ``tool_calls == []``.
|
||||
3. The observability ``safety_termination`` record is attached.
|
||||
4. The user-facing explanation is appended to the message content.
|
||||
|
||||
This is the closest we can get to the issue's failure mode without a live
|
||||
Moonshot key, and it proves the middleware actually gates LangChain's
|
||||
tool router — not just rewrites state in isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
_TOOL_INVOCATIONS: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@tool
|
||||
def write_file(path: str, content: str) -> str:
|
||||
"""Pretend to write *content* to *path*. Records the call for assertion."""
|
||||
_TOOL_INVOCATIONS.append({"path": path, "content": content})
|
||||
return f"wrote {len(content)} bytes to {path}"
|
||||
|
||||
|
||||
class _ContentFilteredModel(BaseChatModel):
|
||||
"""Fake chat model that mimics OpenAI/Moonshot's content_filter response.
|
||||
|
||||
First call returns finish_reason='content_filter' + a tool_call whose
|
||||
arguments are visibly truncated. Second call (if reached) returns a
|
||||
normal text completion so the agent can terminate cleanly.
|
||||
"""
|
||||
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-content-filtered"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
# create_agent binds tools onto the model; we don't actually need
|
||||
# to bind anything since responses are hard-coded, but the method
|
||||
# must not raise.
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
message = AIMessage(
|
||||
content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_truncated_1",
|
||||
"name": "write_file",
|
||||
"args": {
|
||||
"path": "/mnt/user-data/outputs/report.md",
|
||||
"content": "# Weekly Politics\n- Meeting time: 2026-05-12—",
|
||||
},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"},
|
||||
)
|
||||
else:
|
||||
message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
class _InspectMiddleware(AgentMiddleware):
|
||||
"""Captures the messages list at every model entry so we can assert
|
||||
no synthetic tool result was injected back into the conversation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.observed: list[list[Any]] = []
|
||||
|
||||
def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
|
||||
self.observed.append(list(request.messages))
|
||||
return handler(request)
|
||||
|
||||
|
||||
def test_content_filter_with_tool_calls_does_not_invoke_tool_node():
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
inspector = _InspectMiddleware()
|
||||
|
||||
agent = create_agent(
|
||||
model=_ContentFilteredModel(),
|
||||
tools=[write_file],
|
||||
# Inspector first so its after_model is registered; Safety last in
|
||||
# the list so it executes first under LIFO (matches production wiring).
|
||||
middleware=[inspector, SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage(content="write me a report")]})
|
||||
|
||||
# Critical assertion: the dangerous truncated tool call must NOT have
|
||||
# been executed. This is the entire point of the middleware.
|
||||
assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}"
|
||||
|
||||
# Final AIMessage has no tool calls left.
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
assert final_ai.tool_calls == []
|
||||
|
||||
# Observability stamp is present.
|
||||
record = final_ai.additional_kwargs.get("safety_termination")
|
||||
assert record is not None
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 1
|
||||
assert record["suppressed_tool_call_names"] == ["write_file"]
|
||||
|
||||
# User-facing explanation is appended.
|
||||
assert "safety-related signal" in final_ai.content
|
||||
# Original partial text preserved (we don't throw away what the user
|
||||
# already saw in the stream — see middleware docstring).
|
||||
assert "Weekly Politics" in final_ai.content
|
||||
|
||||
# finish_reason on response_metadata is preserved (so SSE / converters
|
||||
# downstream still see the real provider reason).
|
||||
assert final_ai.response_metadata.get("finish_reason") == "content_filter"
|
||||
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through_unchanged():
|
||||
"""No tool calls => issue scope says don't intervene; the partial
|
||||
response should be delivered as-is so the user sees what they got."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _NoToolModel(BaseChatModel):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-no-tool"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
msg = AIMessage(
|
||||
content="Partial answer truncated by safety filter",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_NoToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
result = agent.invoke({"messages": [HumanMessage(content="hi")]})
|
||||
final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage))
|
||||
|
||||
# Content untouched.
|
||||
assert final_ai.content == "Partial answer truncated by safety filter"
|
||||
# No safety_termination stamp because we didn't intervene.
|
||||
assert "safety_termination" not in final_ai.additional_kwargs
|
||||
# tool node never ran (there were no tool calls in the first place).
|
||||
assert _TOOL_INVOCATIONS == []
|
||||
|
||||
|
||||
def test_normal_tool_call_round_trip_is_not_affected():
|
||||
"""Regression: a healthy finish_reason='tool_calls' response must still
|
||||
execute the tool. The middleware must not over-fire."""
|
||||
_TOOL_INVOCATIONS.clear()
|
||||
|
||||
class _HealthyToolModel(BaseChatModel):
|
||||
call_count: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-healthy"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self.call_count += 1
|
||||
if self.call_count == 1:
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_ok",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/tmp/ok", "content": "complete content"},
|
||||
}
|
||||
],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"})
|
||||
return ChatResult(generations=[ChatGeneration(message=msg)])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
agent = create_agent(
|
||||
model=_HealthyToolModel(),
|
||||
tools=[write_file],
|
||||
middleware=[SafetyFinishReasonMiddleware()],
|
||||
)
|
||||
agent.invoke({"messages": [HumanMessage(content="write")]})
|
||||
|
||||
assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}]
|
||||
@@ -0,0 +1,651 @@
|
||||
"""Unit tests for SafetyFinishReasonMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
SafetyTermination,
|
||||
)
|
||||
from deerflow.config.safety_finish_reason_config import (
|
||||
SafetyDetectorConfig,
|
||||
SafetyFinishReasonConfig,
|
||||
)
|
||||
|
||||
|
||||
def _runtime(thread_id="t-1"):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
return runtime
|
||||
|
||||
|
||||
def _ai(
|
||||
*,
|
||||
content="",
|
||||
tool_calls=None,
|
||||
response_metadata=None,
|
||||
additional_kwargs=None,
|
||||
):
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
def _write_call(idx=1, content_text="半截"):
|
||||
return {
|
||||
"id": f"call_write_{idx}",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text},
|
||||
}
|
||||
|
||||
|
||||
class AlwaysHitDetector:
|
||||
"""Test fixture: always reports the given termination."""
|
||||
|
||||
name = "always_hit"
|
||||
|
||||
def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None):
|
||||
self.reason_field = reason_field
|
||||
self.reason_value = reason_value
|
||||
self.extras = extras or {}
|
||||
|
||||
def detect(self, message):
|
||||
return SafetyTermination(
|
||||
detector=self.name,
|
||||
reason_field=self.reason_field,
|
||||
reason_value=self.reason_value,
|
||||
extras=self.extras,
|
||||
)
|
||||
|
||||
|
||||
class NeverHitDetector:
|
||||
name = "never_hit"
|
||||
|
||||
def detect(self, message):
|
||||
return None
|
||||
|
||||
|
||||
class RaisingDetector:
|
||||
name = "raising"
|
||||
|
||||
def detect(self, message):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core trigger behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriggerCriteria:
|
||||
def test_content_filter_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_content_filter_without_tool_calls_passes_through(self):
|
||||
"""issue scope: when there are no tool calls the partial text is a
|
||||
legitimate final response and should not be rewritten."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial response",
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_tool_calls_pass_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_normal_stop_with_tool_calls_pass_through(self):
|
||||
# Some providers report finish_reason='stop' for tool-call messages.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "stop"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_empty_message_list_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
assert mw._apply({"messages": []}, _runtime()) is None
|
||||
|
||||
def test_non_ai_last_message_passes_through(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_anthropic_refusal_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"stop_reason": "refusal"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_gemini_safety_with_tool_calls_triggers(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "SAFETY"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message rewriting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageRewrite:
|
||||
def test_clears_structured_tool_calls(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert patched.tool_calls == []
|
||||
|
||||
def test_clears_raw_additional_kwargs_tool_calls(self):
|
||||
"""Critical defence-in-depth: DanglingToolCallMiddleware will recover
|
||||
tool calls from additional_kwargs.tool_calls if we forget them, which
|
||||
would re-emit a synthetic ToolMessage downstream and confuse the
|
||||
model. We must wipe both."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
raw_tool_calls = [
|
||||
{
|
||||
"id": "call_write_1",
|
||||
"type": "function",
|
||||
"function": {"name": "write_file", "arguments": '{"path": "/x"}'},
|
||||
}
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"tool_calls": raw_tool_calls,
|
||||
"function_call": {"name": "write_file", "arguments": "{}"},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, _runtime())
|
||||
patched = result["messages"][0]
|
||||
assert "tool_calls" not in patched.additional_kwargs
|
||||
assert "function_call" not in patched.additional_kwargs
|
||||
|
||||
def test_preserves_other_additional_kwargs(self):
|
||||
# vLLM puts reasoning under additional_kwargs.reasoning; Anthropic
|
||||
# may carry other provider-specific keys. They must not be wiped.
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
additional_kwargs={
|
||||
"reasoning": "thinking text",
|
||||
"custom_provider_field": {"x": 1},
|
||||
},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["reasoning"] == "thinking text"
|
||||
assert patched.additional_kwargs["custom_provider_field"] == {"x": 1}
|
||||
|
||||
def test_writes_observability_field(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call(1), _write_call(2)],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
record = patched.additional_kwargs["safety_termination"]
|
||||
assert record["detector"] == "openai_compatible_content_filter"
|
||||
assert record["reason_field"] == "finish_reason"
|
||||
assert record["reason_value"] == "content_filter"
|
||||
assert record["suppressed_tool_call_count"] == 2
|
||||
assert record["suppressed_tool_call_names"] == ["write_file", "write_file"]
|
||||
|
||||
def test_preserves_response_metadata_finish_reason(self):
|
||||
"""Downstream SSE converters read response_metadata.finish_reason —
|
||||
we want them to see the *real* provider reason, not 'stop'."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.response_metadata["finish_reason"] == "content_filter"
|
||||
assert patched.response_metadata["model_name"] == "kimi-k2"
|
||||
|
||||
def test_appends_user_facing_explanation_to_str_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="some partial text",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert patched.content.startswith("some partial text")
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_empty_content(self):
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="",
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, str)
|
||||
assert "safety-related signal" in patched.content
|
||||
|
||||
def test_handles_list_content_thinking_blocks(self):
|
||||
"""Anthropic thinking / vLLM reasoning models emit content blocks.
|
||||
Naively concatenating a string would raise TypeError."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
thinking_blocks = [
|
||||
{"type": "thinking", "text": "let me consider..."},
|
||||
{"type": "text", "text": "partial answer"},
|
||||
]
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content=thinking_blocks,
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert isinstance(patched.content, list)
|
||||
assert patched.content[:2] == thinking_blocks
|
||||
assert patched.content[-1]["type"] == "text"
|
||||
assert "safety-related signal" in patched.content[-1]["text"]
|
||||
|
||||
def test_idempotent_on_already_cleared_message(self):
|
||||
# Re-running the middleware on a message we already cleared must not
|
||||
# re-trigger (tool_calls is now empty → fast passthrough).
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
first = mw._apply(state, _runtime())
|
||||
state2 = {"messages": [first["messages"][0]]}
|
||||
second = mw._apply(state2, _runtime())
|
||||
assert second is None
|
||||
|
||||
def test_preserves_message_id_for_add_messages_replacement(self):
|
||||
"""LangGraph's add_messages reducer treats same-id messages as
|
||||
replacements. model_copy keeps id by default."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
original = _ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
# AIMessage auto-generates id; capture it
|
||||
original_id = original.id
|
||||
state = {"messages": [original]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.id == original_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detector wiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectorWiring:
|
||||
def test_iterates_detectors_in_order(self):
|
||||
first = AlwaysHitDetector(reason_value="first")
|
||||
second = AlwaysHitDetector(reason_value="second")
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[first, second])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
patched = mw._apply(state, _runtime())["messages"][0]
|
||||
assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first"
|
||||
|
||||
def test_returns_none_when_no_detector_matches(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()])
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is None
|
||||
|
||||
def test_buggy_detector_does_not_break_run(self):
|
||||
mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()])
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit"
|
||||
|
||||
def test_constructor_copies_detectors(self):
|
||||
"""Caller mutation after construction must not leak into us."""
|
||||
detectors = [AlwaysHitDetector()]
|
||||
mw = SafetyFinishReasonMiddleware(detectors=detectors)
|
||||
detectors.clear()
|
||||
state = {"messages": [_ai(tool_calls=[_write_call()])]}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# from_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFromConfig:
|
||||
def test_default_config_uses_builtin_detectors(self):
|
||||
mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig())
|
||||
assert len(mw._detectors) == 3
|
||||
names = {d.name for d in mw._detectors}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_custom_detectors_loaded_via_reflection(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[
|
||||
SafetyDetectorConfig(
|
||||
use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector",
|
||||
config={"finish_reasons": ["custom_filter"]},
|
||||
),
|
||||
]
|
||||
)
|
||||
mw = SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
assert len(mw._detectors) == 1
|
||||
# Confirm the kwargs propagated.
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "custom_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, _runtime()) is not None
|
||||
# Default token no longer matches.
|
||||
state2 = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state2, _runtime()) is None
|
||||
|
||||
def test_empty_detector_list_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(detectors=[])
|
||||
with pytest.raises(ValueError, match="enabled=false"):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
def test_non_detector_class_rejected(self):
|
||||
cfg = SafetyFinishReasonConfig(
|
||||
detectors=[SafetyDetectorConfig(use="builtins:dict")],
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
SafetyFinishReasonMiddleware.from_config(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stream event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuditEvent:
|
||||
"""Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination`
|
||||
audit event via RunJournal.record_middleware when the run-scoped journal is
|
||||
exposed under runtime.context["__run_journal"].
|
||||
|
||||
Background: review on PR #3035 — SSE custom event handles live consumers,
|
||||
but post-run audit needs a row in run_events that can be queried with one
|
||||
SQL statement (no JOIN against message body).
|
||||
"""
|
||||
|
||||
def _runtime_with_journal(self, journal):
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-audit", "__run_journal": journal}
|
||||
return runtime
|
||||
|
||||
def test_records_audit_event_when_journal_present(self):
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
tc = _write_call(1)
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
content="partial",
|
||||
tool_calls=[tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
|
||||
journal.record_middleware.assert_called_once()
|
||||
call = journal.record_middleware.call_args
|
||||
# tag is positional or kwarg depending on call style; we use kwargs.
|
||||
assert call.kwargs["tag"] == "safety_termination"
|
||||
assert call.kwargs["name"] == "SafetyFinishReasonMiddleware"
|
||||
assert call.kwargs["hook"] == "after_model"
|
||||
assert call.kwargs["action"] == "suppress_tool_calls"
|
||||
|
||||
changes = call.kwargs["changes"]
|
||||
assert changes["detector"] == "openai_compatible_content_filter"
|
||||
assert changes["reason_field"] == "finish_reason"
|
||||
assert changes["reason_value"] == "content_filter"
|
||||
assert changes["suppressed_tool_call_count"] == 1
|
||||
assert changes["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert changes["suppressed_tool_call_ids"] == ["call_write_1"]
|
||||
assert "message_id" in changes
|
||||
assert isinstance(changes["extras"], dict)
|
||||
|
||||
def test_audit_event_never_carries_tool_arguments(self):
|
||||
"""PR #3035 review IMPORTANT: tool args are the filtered content itself
|
||||
and must NOT be persisted to run_events under any circumstance."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
sensitive_tc = {
|
||||
"id": "call_x",
|
||||
"name": "write_file",
|
||||
"args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"},
|
||||
}
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[sensitive_tc],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, self._runtime_with_journal(journal))
|
||||
flat = repr(journal.record_middleware.call_args)
|
||||
assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event"
|
||||
assert "args" not in journal.record_middleware.call_args.kwargs["changes"]
|
||||
|
||||
def test_no_journal_in_runtime_context_is_silently_skipped(self):
|
||||
"""Subagent runtime / unit tests / no-event-store paths have no journal.
|
||||
Middleware must still intervene and clear tool_calls — only the audit
|
||||
event is skipped."""
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "t-noj"} # no __run_journal
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise; should still clear tool_calls.
|
||||
result = mw._apply(state, runtime)
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_journal_record_exception_does_not_break_run(self):
|
||||
"""Buggy journal must never propagate an exception into the agent loop."""
|
||||
journal = MagicMock()
|
||||
journal.record_middleware.side_effect = RuntimeError("db down")
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Must not raise.
|
||||
result = mw._apply(state, self._runtime_with_journal(journal))
|
||||
assert result is not None
|
||||
assert result["messages"][0].tool_calls == []
|
||||
|
||||
def test_no_record_when_passthrough(self):
|
||||
"""When the middleware does NOT intervene, no audit event is written."""
|
||||
journal = MagicMock()
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "tool_calls"}, # healthy
|
||||
)
|
||||
]
|
||||
}
|
||||
assert mw._apply(state, self._runtime_with_journal(journal)) is None
|
||||
journal.record_middleware.assert_not_called()
|
||||
|
||||
|
||||
class TestStreamEvent:
|
||||
def test_emits_event_when_writer_available(self, monkeypatch):
|
||||
captured: list = []
|
||||
|
||||
def fake_writer(payload):
|
||||
captured.append(payload)
|
||||
|
||||
# Patch get_stream_writer at the symbol-resolution site.
|
||||
import langgraph.config
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
mw._apply(state, _runtime("t-stream"))
|
||||
|
||||
assert len(captured) == 1
|
||||
payload = captured[0]
|
||||
assert payload["type"] == "safety_termination"
|
||||
assert payload["detector"] == "openai_compatible_content_filter"
|
||||
assert payload["reason_field"] == "finish_reason"
|
||||
assert payload["reason_value"] == "content_filter"
|
||||
assert payload["suppressed_tool_call_count"] == 1
|
||||
assert payload["suppressed_tool_call_names"] == ["write_file"]
|
||||
assert payload["thread_id"] == "t-stream"
|
||||
|
||||
def test_writer_unavailable_does_not_break(self, monkeypatch):
|
||||
import langgraph.config
|
||||
|
||||
def boom():
|
||||
raise LookupError("not in a stream context")
|
||||
|
||||
monkeypatch.setattr(langgraph.config, "get_stream_writer", boom)
|
||||
|
||||
mw = SafetyFinishReasonMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai(
|
||||
tool_calls=[_write_call()],
|
||||
response_metadata={"finish_reason": "content_filter"},
|
||||
)
|
||||
]
|
||||
}
|
||||
# Should not raise.
|
||||
result = mw._apply(state, _runtime())
|
||||
assert result is not None
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Unit tests for SafetyTerminationDetector built-ins."""
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.safety_termination_detectors import (
|
||||
AnthropicRefusalDetector,
|
||||
GeminiSafetyDetector,
|
||||
OpenAICompatibleContentFilterDetector,
|
||||
SafetyTermination,
|
||||
SafetyTerminationDetector,
|
||||
default_detectors,
|
||||
)
|
||||
|
||||
|
||||
def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage:
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls or [],
|
||||
response_metadata=response_metadata or {},
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAICompatibleContentFilterDetector:
|
||||
def test_default_matches_content_filter(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
assert hit.detector == "openai_compatible_content_filter"
|
||||
assert hit.reason_field == "finish_reason"
|
||||
assert hit.reason_value == "content_filter"
|
||||
|
||||
def test_case_insensitive_match(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None
|
||||
|
||||
def test_other_finish_reasons_pass_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None
|
||||
|
||||
def test_missing_metadata_passes_through(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai()) is None
|
||||
|
||||
def test_non_string_finish_reason_passes_through(self):
|
||||
# Some adapters may stash an enum or dict — must not raise.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None
|
||||
|
||||
def test_falls_back_to_additional_kwargs(self):
|
||||
# Legacy adapters surface finish_reason via additional_kwargs.
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"}))
|
||||
assert hit is not None
|
||||
|
||||
def test_configurable_extra_values(self):
|
||||
# Chinese providers sometimes use bespoke tokens.
|
||||
d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"])
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None
|
||||
# Original token still matches.
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None
|
||||
|
||||
def test_carries_azure_content_filter_results(self):
|
||||
d = OpenAICompatibleContentFilterDetector()
|
||||
filter_results = {"hate": {"filtered": True, "severity": "high"}}
|
||||
hit = d.detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "content_filter",
|
||||
"content_filter_results": filter_results,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["content_filter_results"] == filter_results
|
||||
|
||||
|
||||
class TestAnthropicRefusalDetector:
|
||||
def test_default_matches_refusal(self):
|
||||
hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"}))
|
||||
assert hit is not None
|
||||
assert hit.reason_field == "stop_reason"
|
||||
assert hit.reason_value == "refusal"
|
||||
|
||||
def test_other_stop_reasons_pass_through(self):
|
||||
d = AnthropicRefusalDetector()
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None
|
||||
assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None
|
||||
|
||||
def test_anthropic_does_not_steal_finish_reason(self):
|
||||
# An OpenAI message must not accidentally trip the Anthropic detector.
|
||||
assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None
|
||||
|
||||
|
||||
class TestGeminiSafetyDetector:
|
||||
def test_default_set_covers_documented_reasons(self):
|
||||
d = GeminiSafetyDetector()
|
||||
for reason in (
|
||||
# text safety
|
||||
"SAFETY",
|
||||
"BLOCKLIST",
|
||||
"PROHIBITED_CONTENT",
|
||||
"SPII",
|
||||
"RECITATION",
|
||||
# image safety
|
||||
"IMAGE_SAFETY",
|
||||
"IMAGE_PROHIBITED_CONTENT",
|
||||
"IMAGE_RECITATION",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason
|
||||
|
||||
def test_normal_termination_passes_through(self):
|
||||
d = GeminiSafetyDetector()
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None
|
||||
# MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER /
|
||||
# MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally
|
||||
# excluded from the default set — they are either normal termination,
|
||||
# capability mismatches, too broad (OTHER), or tool-call protocol
|
||||
# errors. See GeminiSafetyDetector docstring.
|
||||
for reason in (
|
||||
"MAX_TOKENS",
|
||||
"LANGUAGE",
|
||||
"NO_IMAGE",
|
||||
"OTHER",
|
||||
"IMAGE_OTHER",
|
||||
"MALFORMED_FUNCTION_CALL",
|
||||
"UNEXPECTED_TOOL_CALL",
|
||||
"FINISH_REASON_UNSPECIFIED",
|
||||
):
|
||||
assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason
|
||||
|
||||
def test_carries_safety_ratings(self):
|
||||
ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}]
|
||||
hit = GeminiSafetyDetector().detect(
|
||||
_ai(
|
||||
response_metadata={
|
||||
"finish_reason": "SAFETY",
|
||||
"safety_ratings": ratings,
|
||||
},
|
||||
)
|
||||
)
|
||||
assert hit is not None
|
||||
assert hit.extras["safety_ratings"] == ratings
|
||||
|
||||
|
||||
class TestDefaultDetectorSet:
|
||||
def test_default_set_returns_three_detectors(self):
|
||||
dets = default_detectors()
|
||||
names = {d.name for d in dets}
|
||||
assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"}
|
||||
|
||||
def test_default_set_returns_fresh_list(self):
|
||||
# Caller mutation must not affect later calls.
|
||||
first = default_detectors()
|
||||
first.clear()
|
||||
second = default_detectors()
|
||||
assert len(second) == 3
|
||||
|
||||
|
||||
class TestProtocolConformance:
|
||||
def test_builtins_satisfy_protocol(self):
|
||||
for d in default_detectors():
|
||||
assert isinstance(d, SafetyTerminationDetector)
|
||||
|
||||
def test_safety_termination_is_frozen(self):
|
||||
t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter")
|
||||
try:
|
||||
t.detector = "y" # type: ignore[misc]
|
||||
except Exception:
|
||||
return
|
||||
raise AssertionError("SafetyTermination should be frozen")
|
||||
@@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False)
|
||||
|
||||
assert captured["app_config"] is app_config
|
||||
assert len(middlewares) == 6
|
||||
assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware)
|
||||
# 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling,
|
||||
# SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware
|
||||
# (enabled by default — see SafetyFinishReasonConfig).
|
||||
from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware
|
||||
|
||||
assert len(middlewares) == 7
|
||||
assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares)
|
||||
assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware)
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_on_success():
|
||||
|
||||
Reference in New Issue
Block a user