feat(loop-detection): defer warning injection (#2752)

* fix(loop-detection): defer warn injection to wrap_model_call

The warn branch in LoopDetectionMiddleware injected a HumanMessage
into state from after_model. The tools node had not yet produced
ToolMessage responses to the previous AIMessage(tool_calls=...), so
the new HumanMessage landed *between* the assistant's tool_calls and
their responses. OpenAI/Moonshot reject the next request with
"tool_call_ids did not have response messages" because their
validators require tool_calls to be followed immediately by tool
messages.

Detection now runs in after_model as before, but only enqueues the
warning into a per-thread list. Injection happens in wrap_model_call,
where every prior ToolMessage is already present in request.messages.
The warning is appended at the end as HumanMessage(name="loop_warning")
— pairing intact, AIMessage semantics untouched, no SystemMessage
issues for Anthropic.

Closes #2029, addresses #2255 #2293 #2304 #2511.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(channels): remove loop warning display filter

* feat(loop-detection): scope pending warnings by run

* docs(loop-detection): update docs

* test(loop-detection): assert deferred warnings are queued

* fix(loop-detection): cap transient warning state

* docs: update docs

* add async awrap_model_call test coverage

* docs(loop-detection): document transient warnings

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Nan Gao
2026-05-21 08:36:07 +02:00
committed by GitHub
parent 7ec8d3a6e7
commit dcc6f1e678
7 changed files with 696 additions and 221 deletions
+412 -82
View File
@@ -1,24 +1,94 @@
"""Tests for LoopDetectionMiddleware."""
import copy
from collections import OrderedDict
from typing import Any
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, SystemMessage
import pytest
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
from pydantic import PrivateAttr
from deerflow.agents.middlewares.loop_detection_middleware import (
_HARD_STOP_MSG,
_MAX_PENDING_WARNINGS_PER_RUN,
LoopDetectionMiddleware,
_hash_tool_calls,
)
def _make_runtime(thread_id="test-thread"):
def _make_runtime(thread_id="test-thread", run_id="test-run"):
"""Build a minimal Runtime mock with context."""
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
runtime.context = {"thread_id": thread_id, "run_id": run_id}
return runtime
def _pending_key(thread_id="test-thread", run_id="test-run"):
return (thread_id, run_id)
def _make_request(messages, runtime):
"""Build a minimal ModelRequest stand-in for wrap_model_call tests."""
request = MagicMock()
request.messages = list(messages)
request.runtime = runtime
request.override = lambda **updates: _override_request(request, updates)
return request
def _override_request(request, updates):
"""Mimic ModelRequest.override(): return a copy with fields replaced."""
new = MagicMock()
new.messages = updates.get("messages", request.messages)
new.runtime = updates.get("runtime", request.runtime)
new.override = lambda **u: _override_request(new, u)
return new
def _capture_handler():
"""Build a sync handler that records the request it was called with."""
captured: list = []
def handler(req):
captured.append(req)
return MagicMock()
return captured, handler
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
"""Fake chat model that records each model request's messages."""
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
@property
def seen_messages(self) -> list[list[Any]]:
return self._seen_messages
def bind_tools(
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self._seen_messages.append(list(messages))
return super()._generate(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
def _make_state(tool_calls=None, content=""):
"""Build a minimal AgentState dict with an AIMessage.
@@ -138,7 +208,15 @@ class TestLoopDetection:
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
def test_warn_at_threshold(self):
def test_warn_at_threshold_queues_but_does_not_mutate_state(self):
"""At warn threshold, ``after_model`` enqueues but returns None.
Detection observes the just-emitted AIMessage(tool_calls=...). The
tools node hasn't run yet, so injecting any non-tool message here
would split the assistant's tool_calls from their ToolMessage
responses and break OpenAI/Moonshot pairing. The warning is
delivered later from ``wrap_model_call``.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
runtime = _make_runtime()
call = [_bash_call("ls")]
@@ -146,44 +224,150 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third identical call triggers warning. The warning is appended to
# the AIMessage content (tool_calls preserved) — never inserted as a
# separate HumanMessage between the AIMessage(tool_calls) and its
# ToolMessage responses, which would break OpenAI/Moonshot strict
# tool-call pairing validation.
# Third identical call triggers warning detection.
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
assert "LOOP DETECTED" in msgs[0].content
# Detection must not mutate state — the AIMessage with tool_calls is
# left untouched so the tools node runs normally.
assert result is None
# ...but a warning is queued for the next model call.
assert mw._pending_warnings[_pending_key()]
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key()][0]
def test_warn_does_not_break_tool_call_pairing(self):
"""Regression: the warn branch must NOT inject a non-tool message
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
request with 'tool_call_ids did not have response messages' if any
non-tool message is wedged between the AIMessage and its ToolMessage
responses. See #2029.
def test_warn_injected_at_next_model_call(self):
"""``wrap_model_call`` appends a HumanMessage(loop_warning) to the
outgoing messages — *after* every existing message — so that the
AIMessage(tool_calls=...) -> ToolMessage(...) pairing stays intact.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(2):
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
# Build the messages the agent runtime would assemble for the next
# turn: prior AIMessage(tool_calls), its ToolMessage responses, ...
ai_msg = AIMessage(content="", tool_calls=call)
tool_msg = ToolMessage(content="ok", tool_call_id=call[0]["id"], name="bash")
request = _make_request([ai_msg, tool_msg], runtime)
def test_warn_only_injected_once(self):
"""Warning for the same hash should only be injected once per thread."""
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
sent = captured[0].messages
# AIMessage and ToolMessage stay in order, untouched.
assert sent[0] is ai_msg
assert sent[1] is tool_msg
# HumanMessage(warning) appears AFTER the ToolMessage — pairing intact.
assert isinstance(sent[2], HumanMessage)
assert sent[2].name == "loop_warning"
assert "LOOP DETECTED" in sent[2].content
def test_warn_queue_drained_after_injection(self):
"""A queued warning must be emitted exactly once per detection event."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
# First call: warning is appended.
mw.wrap_model_call(request, handler)
first = captured[0].messages
assert any(isinstance(m, HumanMessage) for m in first)
# Subsequent call without new detection: no warning re-emitted.
request2 = _make_request([AIMessage(content="hi")], runtime)
mw.wrap_model_call(request2, handler)
second = captured[1].messages
assert not any(isinstance(m, HumanMessage) for m in second)
def test_warn_queue_scoped_by_run_id(self):
"""A warning queued for one run must not be injected into another run."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime_a = _make_runtime(run_id="run-A")
runtime_b = _make_runtime(run_id="run-B")
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime_a)
request_b = _make_request([AIMessage(content="hi")], runtime_b)
captured, handler = _capture_handler()
mw.wrap_model_call(request_b, handler)
assert not any(isinstance(m, HumanMessage) for m in captured[0].messages)
assert mw._pending_warnings.get(_pending_key(run_id="run-A"))
request_a = _make_request([AIMessage(content="hi")], runtime_a)
mw.wrap_model_call(request_a, handler)
assert any(isinstance(message, HumanMessage) and message.name == "loop_warning" for message in captured[1].messages)
def test_missing_run_id_uses_default_pending_scope(self):
"""When runtime has no run_id, warning handling falls back to the default run scope."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
assert mw._pending_warnings.get(_pending_key(run_id="default"))
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert len(loop_warnings) == 1
assert "LOOP DETECTED" in loop_warnings[0].content
assert not mw._pending_warnings.get(_pending_key(run_id="default"))
def test_before_agent_clears_stale_pending_warnings_for_thread(self):
"""Starting a new run drops stale warnings from prior runs in the same thread."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime_a = _make_runtime(run_id="run-A")
runtime_b = _make_runtime(run_id="run-B")
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime_a)
assert mw._pending_warnings.get(_pending_key(run_id="run-A"))
mw.before_agent({"messages": []}, runtime_b)
assert not mw._pending_warnings.get(_pending_key(run_id="run-A"))
def test_after_agent_clears_current_run_pending_warnings(self):
"""Run cleanup should drop warnings that never reached wrap_model_call."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
assert mw._pending_warnings.get(_pending_key())
mw.after_agent({"messages": []}, runtime)
assert not mw._pending_warnings.get(_pending_key())
def test_multiple_pending_warnings_are_merged_into_one_message(self):
"""Edge-case drains should produce one loop_warning prompt message."""
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
mw._pending_warnings[_pending_key()] = ["first warning", "second warning", "first warning"]
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert len(loop_warnings) == 1
assert loop_warnings[0].content == "first warning\n\nsecond warning"
def test_warn_only_queued_once_per_hash(self):
"""Same hash repeated past the threshold should warn only once."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
@@ -192,14 +376,13 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third — warning injected
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Third — warning queued
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
# Fourth — warning already injected, should return None
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
# Fourth — already warned for this hash, no additional enqueue.
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
def test_hard_stop_at_limit(self):
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
@@ -257,6 +440,7 @@ class TestLoopDetection:
mw.reset()
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
assert not mw._pending_warnings.get(_pending_key())
def test_non_ai_message_ignored(self):
mw = LoopDetectionMiddleware()
@@ -283,15 +467,16 @@ class TestLoopDetection:
# One call on thread B
mw._apply(_make_state(tool_calls=call), runtime_b)
# Second call on thread A — triggers warning (2 >= warn_threshold)
result = mw._apply(_make_state(tool_calls=call), runtime_a)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Second call on thread A — queues warning under thread-A only.
mw._apply(_make_state(tool_calls=call), runtime_a)
assert mw._pending_warnings.get(_pending_key("thread-A"))
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
assert not mw._pending_warnings.get(_pending_key("thread-B"))
# Second call on thread B — also triggers (independent tracking)
result = mw._apply(_make_state(tool_calls=call), runtime_b)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Second call on thread B — independent queue.
mw._apply(_make_state(tool_calls=call), runtime_b)
assert mw._pending_warnings.get(_pending_key("thread-B"))
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
def test_lru_eviction(self):
"""Old threads should be evicted when max_tracked_threads is exceeded."""
@@ -313,6 +498,55 @@ class TestLoopDetection:
assert "thread-new" in mw._history
assert len(mw._history) == 3
def test_warned_hashes_are_pruned_to_sliding_window(self):
"""A long-lived thread should not keep every historical warned hash."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=100, window_size=4)
runtime = _make_runtime()
for i in range(12):
call = [_bash_call(f"cmd_{i}")]
mw._apply(_make_state(tool_calls=call), runtime)
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._history["test-thread"]) <= 4
assert set(mw._warned["test-thread"]).issubset(set(mw._history["test-thread"]))
assert len(mw._warned["test-thread"]) <= 4
def test_pending_warning_keys_are_capped(self):
"""Abnormal same-thread runs cannot grow pending-warning keys forever."""
mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=2)
for i in range(10):
runtime = _make_runtime(thread_id="same-thread", run_id=f"run-{i}")
mw._queue_pending_warning(runtime, f"warning-{i}")
assert len(mw._pending_warnings) == mw._max_pending_warning_keys
assert len(mw._pending_warning_touch_order) == mw._max_pending_warning_keys
assert _pending_key("same-thread", "run-9") in mw._pending_warnings
def test_pending_warning_list_is_capped_and_deduped(self):
"""One run cannot accumulate an unbounded warning list."""
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
for i in range(_MAX_PENDING_WARNINGS_PER_RUN + 4):
mw._queue_pending_warning(runtime, f"warning-{i}")
mw._queue_pending_warning(runtime, f"warning-{_MAX_PENDING_WARNINGS_PER_RUN + 3}")
warnings = mw._pending_warnings[_pending_key()]
assert len(warnings) == _MAX_PENDING_WARNINGS_PER_RUN
assert warnings == [f"warning-{i}" for i in range(4, _MAX_PENDING_WARNINGS_PER_RUN + 4)]
def test_pending_warning_touch_order_cleared_with_pending_key(self):
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
mw._queue_pending_warning(runtime, "warning")
mw.after_agent({"messages": []}, runtime)
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
def test_thread_safe_mutations(self):
"""Verify lock is used for mutations (basic structural test)."""
mw = LoopDetectionMiddleware()
@@ -331,6 +565,99 @@ class TestLoopDetection:
assert "default" in mw._history
class TestLoopDetectionAgentGraphIntegration:
def test_loop_warning_is_transient_in_real_agent_graph(self):
"""after_model queues the warning; wrap_model_call injects it request-only."""
@as_tool
def bash(command: str) -> str:
"""Run a fake shell command."""
return f"ran: {command}"
repeated_calls = [[{"name": "bash", "id": f"call_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(content="", tool_calls=repeated_calls[0]),
AIMessage(content="", tool_calls=repeated_calls[1]),
AIMessage(content="", tool_calls=repeated_calls[2]),
AIMessage(content="final answer"),
],
)
graph = create_agent(model=model, tools=[bash], middleware=[mw])
result = graph.invoke(
{"messages": [("user", "inspect the directory")]},
context={"thread_id": "integration-thread", "run_id": "integration-run"},
config={"recursion_limit": 20},
)
assert len(model.seen_messages) == 4
loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages]
assert loop_warnings_by_call[0] == []
assert loop_warnings_by_call[1] == []
assert loop_warnings_by_call[2] == []
assert len(loop_warnings_by_call[3]) == 1
assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content
fourth_request = model.seen_messages[3]
assert isinstance(fourth_request[-2], ToolMessage)
assert fourth_request[-2].tool_call_id == "call_ls_2"
assert fourth_request[-1] is loop_warnings_by_call[3][0]
persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert persisted_loop_warnings == []
assert result["messages"][-1].content == "final answer"
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
@pytest.mark.asyncio
async def test_loop_warning_is_transient_in_async_agent_graph(self):
"""awrap_model_call injects loop_warning request-only in async graph runs."""
@as_tool
async def bash(command: str) -> str:
"""Run a fake shell command."""
return f"ran: {command}"
repeated_calls = [[{"name": "bash", "id": f"call_async_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(content="", tool_calls=repeated_calls[0]),
AIMessage(content="", tool_calls=repeated_calls[1]),
AIMessage(content="", tool_calls=repeated_calls[2]),
AIMessage(content="async final answer"),
],
)
graph = create_agent(model=model, tools=[bash], middleware=[mw])
result = await graph.ainvoke(
{"messages": [("user", "inspect the directory asynchronously")]},
context={"thread_id": "async-integration-thread", "run_id": "async-integration-run"},
config={"recursion_limit": 20},
)
assert len(model.seen_messages) == 4
loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages]
assert loop_warnings_by_call[0] == []
assert loop_warnings_by_call[1] == []
assert loop_warnings_by_call[2] == []
assert len(loop_warnings_by_call[3]) == 1
assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content
fourth_request = model.seen_messages[3]
assert isinstance(fourth_request[-2], ToolMessage)
assert fourth_request[-2].tool_call_id == "call_async_ls_2"
assert fourth_request[-1] is loop_warnings_by_call[3][0]
persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert persisted_loop_warnings == []
assert result["messages"][-1].content == "async final answer"
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
class TestAppendText:
"""Unit tests for LoopDetectionMiddleware._append_text."""
@@ -507,33 +834,29 @@ class TestToolFrequencyDetection:
for i in range(4):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 5th call to read_file (different file each time) triggers freq warning
# 5th call queues a per-tool-type frequency warning; state untouched.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
assert result is not None
msg = result["messages"][0]
# Warning is appended to the AIMessage content; tool_calls preserved
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
# validation does not break.
assert isinstance(msg, AIMessage)
assert msg.tool_calls
assert "read_file" in msg.content
assert "LOOP DETECTED" in msg.content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "read_file" in queued[0]
assert "LOOP DETECTED" in queued[0]
def test_freq_warn_only_injected_once(self):
def test_freq_warn_only_queued_once(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd triggers warning
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# 3rd queues a frequency warning.
mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
# 4th should not re-warn (already warned for read_file)
# 4th: same tool name, no additional enqueue.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime)
assert result is None
assert len(mw._pending_warnings[_pending_key()]) == 1
def test_freq_hard_stop_at_limit(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6)
@@ -565,10 +888,10 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None
# 3rd read_file triggers (read_file count = 3)
# 3rd read_file triggers — warning is queued (state unchanged).
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
assert "read_file" in mw._pending_warnings[_pending_key()][0]
def test_freq_reset_clears_state(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
@@ -600,10 +923,10 @@ class TestToolFrequencyDetection:
assert "thread-A" not in mw._tool_freq
assert "thread-A" not in mw._tool_freq_warned
# thread-B state should still be intact — 3rd call triggers warn
# thread-B state should still be intact — 3rd call queues a warn.
result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
# thread-A restarted from 0 — should not trigger
result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a)
@@ -623,10 +946,11 @@ class TestToolFrequencyDetection:
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b)
# 3rd call on thread A — triggers (count=3 for thread A only)
# 3rd call on thread A — queues a warning (count=3 for thread A only).
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
assert not mw._pending_warnings.get(_pending_key("thread-B"))
def test_multi_tool_single_response_counted(self):
"""When a single response has multiple tool calls, each is counted."""
@@ -643,10 +967,10 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
# Response 3: 1 more → count = 5 → triggers warn
# Response 3: 1 more → count = 5 → queues warn.
result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
assert "read_file" in mw._pending_warnings[_pending_key()][0]
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
@@ -674,10 +998,14 @@ class TestToolFrequencyDetection:
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
# 3rd read_file call hits global warn=3 (read_file has no override).
# Warning delivery is deferred to wrap_model_call so the just-emitted
# AIMessage(tool_calls=...) is not mutated before ToolMessages exist.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "read_file" in queued[0]
def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls."""
@@ -736,11 +1064,13 @@ class TestFromConfig:
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
def test_constructed_middleware_queues_loop_warning(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "LOOP DETECTED" in queued[0]