"""Tests for LoopDetectionMiddleware.""" import copy from collections import OrderedDict from typing import Any from unittest.mock import MagicMock import pytest from langchain.agents import create_agent from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import Runnable from langchain_core.tools import tool as as_tool from pydantic import PrivateAttr from deerflow.agents.middlewares.loop_detection_middleware import ( _HARD_STOP_MSG, _MAX_PENDING_WARNINGS_PER_RUN, LoopDetectionMiddleware, _hash_tool_calls, ) def _make_runtime(thread_id="test-thread", run_id="test-run"): """Build a minimal Runtime mock with context.""" runtime = MagicMock() runtime.context = {"thread_id": thread_id, "run_id": run_id} return runtime def _pending_key(thread_id="test-thread", run_id="test-run"): return (thread_id, run_id) def _make_request(messages, runtime): """Build a minimal ModelRequest stand-in for wrap_model_call tests.""" request = MagicMock() request.messages = list(messages) request.runtime = runtime request.override = lambda **updates: _override_request(request, updates) return request def _override_request(request, updates): """Mimic ModelRequest.override(): return a copy with fields replaced.""" new = MagicMock() new.messages = updates.get("messages", request.messages) new.runtime = updates.get("runtime", request.runtime) new.override = lambda **u: _override_request(new, u) return new def _capture_handler(): """Build a sync handler that records the request it was called with.""" captured: list = [] def handler(req): captured.append(req) return MagicMock() return captured, handler class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel): """Fake chat model that records each model request's messages.""" _seen_messages: list[list[Any]] = PrivateAttr(default_factory=list) @property def seen_messages(self) -> list[list[Any]]: return self._seen_messages def bind_tools( self, tools: Any, *, tool_choice: Any = None, **kwargs: Any, ) -> Runnable: return self def _generate(self, messages, stop=None, run_manager=None, **kwargs): self._seen_messages.append(list(messages)) return super()._generate( messages, stop=stop, run_manager=run_manager, **kwargs, ) def _make_state(tool_calls=None, content=""): """Build a minimal AgentState dict with an AIMessage. Deep-copies *content* when it is mutable (e.g. list) so that successive calls never share the same object reference. """ safe_content = copy.deepcopy(content) if isinstance(content, list) else content msg = AIMessage(content=safe_content, tool_calls=tool_calls or []) return {"messages": [msg]} def _bash_call(cmd="ls"): return {"name": "bash", "id": f"call_{cmd}", "args": {"command": cmd}} class TestHashToolCalls: def test_same_calls_same_hash(self): a = _hash_tool_calls([_bash_call("ls")]) b = _hash_tool_calls([_bash_call("ls")]) assert a == b def test_different_calls_different_hash(self): a = _hash_tool_calls([_bash_call("ls")]) b = _hash_tool_calls([_bash_call("pwd")]) assert a != b def test_order_independent(self): a = _hash_tool_calls([_bash_call("ls"), {"name": "read_file", "args": {"path": "/tmp"}}]) b = _hash_tool_calls([{"name": "read_file", "args": {"path": "/tmp"}}, _bash_call("ls")]) assert a == b def test_empty_calls(self): h = _hash_tool_calls([]) assert isinstance(h, str) assert len(h) > 0 def test_stringified_dict_args_match_dict_args(self): dict_call = { "name": "read_file", "args": {"path": "/tmp/demo.py", "start_line": "1", "end_line": "150"}, } string_call = { "name": "read_file", "args": '{"path":"/tmp/demo.py","start_line":"1","end_line":"150"}', } assert _hash_tool_calls([dict_call]) == _hash_tool_calls([string_call]) def test_reversed_read_file_range_matches_forward_range(self): forward_call = { "name": "read_file", "args": {"path": "/tmp/demo.py", "start_line": 10, "end_line": 300}, } reversed_call = { "name": "read_file", "args": {"path": "/tmp/demo.py", "start_line": 300, "end_line": 10}, } assert _hash_tool_calls([forward_call]) == _hash_tool_calls([reversed_call]) def test_stringified_non_dict_args_do_not_crash(self): non_dict_json_call = {"name": "bash", "args": '"echo hello"'} plain_string_call = {"name": "bash", "args": "echo hello"} json_hash = _hash_tool_calls([non_dict_json_call]) plain_hash = _hash_tool_calls([plain_string_call]) assert isinstance(json_hash, str) assert isinstance(plain_hash, str) assert json_hash assert plain_hash def test_grep_pattern_affects_hash(self): grep_foo = {"name": "grep", "args": {"path": "/tmp", "pattern": "foo"}} grep_bar = {"name": "grep", "args": {"path": "/tmp", "pattern": "bar"}} assert _hash_tool_calls([grep_foo]) != _hash_tool_calls([grep_bar]) def test_glob_pattern_affects_hash(self): glob_py = {"name": "glob", "args": {"path": "/tmp", "pattern": "*.py"}} glob_ts = {"name": "glob", "args": {"path": "/tmp", "pattern": "*.ts"}} assert _hash_tool_calls([glob_py]) != _hash_tool_calls([glob_ts]) def test_write_file_content_affects_hash(self): v1 = {"name": "write_file", "args": {"path": "/tmp/a.py", "content": "v1"}} v2 = {"name": "write_file", "args": {"path": "/tmp/a.py", "content": "v2"}} assert _hash_tool_calls([v1]) != _hash_tool_calls([v2]) def test_str_replace_content_affects_hash(self): a = { "name": "str_replace", "args": {"path": "/tmp/a.py", "old_str": "foo", "new_str": "bar"}, } b = { "name": "str_replace", "args": {"path": "/tmp/a.py", "old_str": "foo", "new_str": "baz"}, } assert _hash_tool_calls([a]) != _hash_tool_calls([b]) class TestLoopDetection: def test_no_tool_calls_returns_none(self): mw = LoopDetectionMiddleware() runtime = _make_runtime() state = {"messages": [AIMessage(content="hello")]} result = mw._apply(state, runtime) assert result is None def test_below_threshold_returns_none(self): mw = LoopDetectionMiddleware(warn_threshold=3) runtime = _make_runtime() call = [_bash_call("ls")] # First two identical calls — no warning for _ in range(2): result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None def test_warn_at_threshold_queues_but_does_not_mutate_state(self): """At warn threshold, ``after_model`` enqueues but returns None. Detection observes the just-emitted AIMessage(tool_calls=...). The tools node hasn't run yet, so injecting any non-tool message here would split the assistant's tool_calls from their ToolMessage responses and break OpenAI/Moonshot pairing. The warning is delivered later from ``wrap_model_call``. """ mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5) runtime = _make_runtime() call = [_bash_call("ls")] for _ in range(2): mw._apply(_make_state(tool_calls=call), runtime) # Third identical call triggers warning detection. result = mw._apply(_make_state(tool_calls=call), runtime) # Detection must not mutate state — the AIMessage with tool_calls is # left untouched so the tools node runs normally. assert result is None # ...but a warning is queued for the next model call. assert mw._pending_warnings[_pending_key()] assert "LOOP DETECTED" in mw._pending_warnings[_pending_key()][0] def test_warn_injected_at_next_model_call(self): """``wrap_model_call`` appends a HumanMessage(loop_warning) to the outgoing messages — *after* every existing message — so that the AIMessage(tool_calls=...) -> ToolMessage(...) pairing stays intact. """ mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) # Build the messages the agent runtime would assemble for the next # turn: prior AIMessage(tool_calls), its ToolMessage responses, ... ai_msg = AIMessage(content="", tool_calls=call) tool_msg = ToolMessage(content="ok", tool_call_id=call[0]["id"], name="bash") request = _make_request([ai_msg, tool_msg], runtime) captured, handler = _capture_handler() mw.wrap_model_call(request, handler) sent = captured[0].messages # AIMessage and ToolMessage stay in order, untouched. assert sent[0] is ai_msg assert sent[1] is tool_msg # HumanMessage(warning) appears AFTER the ToolMessage — pairing intact. assert isinstance(sent[2], HumanMessage) assert sent[2].name == "loop_warning" assert "LOOP DETECTED" in sent[2].content def test_warn_queue_drained_after_injection(self): """A queued warning must be emitted exactly once per detection event.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) request = _make_request([AIMessage(content="hi")], runtime) captured, handler = _capture_handler() # First call: warning is appended. mw.wrap_model_call(request, handler) first = captured[0].messages assert any(isinstance(m, HumanMessage) for m in first) # Subsequent call without new detection: no warning re-emitted. request2 = _make_request([AIMessage(content="hi")], runtime) mw.wrap_model_call(request2, handler) second = captured[1].messages assert not any(isinstance(m, HumanMessage) for m in second) def test_warn_queue_scoped_by_run_id(self): """A warning queued for one run must not be injected into another run.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime_a = _make_runtime(run_id="run-A") runtime_b = _make_runtime(run_id="run-B") call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime_a) request_b = _make_request([AIMessage(content="hi")], runtime_b) captured, handler = _capture_handler() mw.wrap_model_call(request_b, handler) assert not any(isinstance(m, HumanMessage) for m in captured[0].messages) assert mw._pending_warnings.get(_pending_key(run_id="run-A")) request_a = _make_request([AIMessage(content="hi")], runtime_a) mw.wrap_model_call(request_a, handler) assert any(isinstance(message, HumanMessage) and message.name == "loop_warning" for message in captured[1].messages) def test_missing_run_id_uses_default_pending_scope(self): """When runtime has no run_id, warning handling falls back to the default run scope.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = MagicMock() runtime.context = {"thread_id": "test-thread"} call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) assert mw._pending_warnings.get(_pending_key(run_id="default")) request = _make_request([AIMessage(content="hi")], runtime) captured, handler = _capture_handler() mw.wrap_model_call(request, handler) loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] assert len(loop_warnings) == 1 assert "LOOP DETECTED" in loop_warnings[0].content assert not mw._pending_warnings.get(_pending_key(run_id="default")) def test_before_agent_clears_stale_pending_warnings_for_thread(self): """Starting a new run drops stale warnings from prior runs in the same thread.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime_a = _make_runtime(run_id="run-A") runtime_b = _make_runtime(run_id="run-B") call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime_a) assert mw._pending_warnings.get(_pending_key(run_id="run-A")) mw.before_agent({"messages": []}, runtime_b) assert not mw._pending_warnings.get(_pending_key(run_id="run-A")) def test_after_agent_clears_current_run_pending_warnings(self): """Run cleanup should drop warnings that never reached wrap_model_call.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) assert mw._pending_warnings.get(_pending_key()) mw.after_agent({"messages": []}, runtime) assert not mw._pending_warnings.get(_pending_key()) def test_multiple_pending_warnings_are_merged_into_one_message(self): """Edge-case drains should produce one loop_warning prompt message.""" mw = LoopDetectionMiddleware() runtime = _make_runtime() mw._pending_warnings[_pending_key()] = ["first warning", "second warning", "first warning"] request = _make_request([AIMessage(content="hi")], runtime) captured, handler = _capture_handler() mw.wrap_model_call(request, handler) loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] assert len(loop_warnings) == 1 assert loop_warnings[0].content == "first warning\n\nsecond warning" def test_warn_only_queued_once_per_hash(self): """Same hash repeated past the threshold should warn only once.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] # First two — no warning for _ in range(2): mw._apply(_make_state(tool_calls=call), runtime) # Third — warning queued mw._apply(_make_state(tool_calls=call), runtime) assert len(mw._pending_warnings[_pending_key()]) == 1 # Fourth — already warned for this hash, no additional enqueue. mw._apply(_make_state(tool_calls=call), runtime) assert len(mw._pending_warnings[_pending_key()]) == 1 def test_hard_stop_at_limit(self): mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) runtime = _make_runtime() call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) # Fourth call triggers hard stop result = mw._apply(_make_state(tool_calls=call), runtime) assert result is not None msgs = result["messages"] assert len(msgs) == 1 # Hard stop strips tool_calls assert isinstance(msgs[0], AIMessage) assert msgs[0].tool_calls == [] assert _HARD_STOP_MSG in msgs[0].content def test_different_calls_dont_trigger(self): mw = LoopDetectionMiddleware(warn_threshold=2) runtime = _make_runtime() # Each call is different for i in range(10): result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime) assert result is None def test_window_sliding(self): mw = LoopDetectionMiddleware(warn_threshold=3, window_size=5) runtime = _make_runtime() call = [_bash_call("ls")] # Fill with 2 identical calls mw._apply(_make_state(tool_calls=call), runtime) mw._apply(_make_state(tool_calls=call), runtime) # Push them out of the window with different calls for i in range(5): mw._apply(_make_state(tool_calls=[_bash_call(f"other_{i}")]), runtime) # Now the original call should be fresh again — no warning result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None def test_reset_clears_state(self): mw = LoopDetectionMiddleware(warn_threshold=2) runtime = _make_runtime() call = [_bash_call("ls")] mw._apply(_make_state(tool_calls=call), runtime) mw._apply(_make_state(tool_calls=call), runtime) # Would trigger warning, but reset first mw.reset() result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None assert not mw._pending_warnings.get(_pending_key()) def test_non_ai_message_ignored(self): mw = LoopDetectionMiddleware() runtime = _make_runtime() state = {"messages": [SystemMessage(content="hello")]} result = mw._apply(state, runtime) assert result is None def test_empty_messages_ignored(self): mw = LoopDetectionMiddleware() runtime = _make_runtime() result = mw._apply({"messages": []}, runtime) assert result is None def test_thread_id_from_runtime_context(self): """Thread ID should come from runtime.context, not state.""" mw = LoopDetectionMiddleware(warn_threshold=2) runtime_a = _make_runtime("thread-A") runtime_b = _make_runtime("thread-B") call = [_bash_call("ls")] # One call on thread A mw._apply(_make_state(tool_calls=call), runtime_a) # One call on thread B mw._apply(_make_state(tool_calls=call), runtime_b) # Second call on thread A — queues warning under thread-A only. mw._apply(_make_state(tool_calls=call), runtime_a) assert mw._pending_warnings.get(_pending_key("thread-A")) assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0] assert not mw._pending_warnings.get(_pending_key("thread-B")) # Second call on thread B — independent queue. mw._apply(_make_state(tool_calls=call), runtime_b) assert mw._pending_warnings.get(_pending_key("thread-B")) assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0] def test_lru_eviction(self): """Old threads should be evicted when max_tracked_threads is exceeded.""" mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=3) call = [_bash_call("ls")] # Fill up 3 threads for i in range(3): runtime = _make_runtime(f"thread-{i}") mw._apply(_make_state(tool_calls=call), runtime) # Add a 4th thread — should evict thread-0 runtime_new = _make_runtime("thread-new") mw._apply(_make_state(tool_calls=call), runtime_new) assert "thread-0" not in mw._history assert "thread-0" not in mw._tool_freq assert "thread-0" not in mw._tool_freq_warned assert "thread-new" in mw._history assert len(mw._history) == 3 def test_warned_hashes_are_pruned_to_sliding_window(self): """A long-lived thread should not keep every historical warned hash.""" mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=100, window_size=4) runtime = _make_runtime() for i in range(12): call = [_bash_call(f"cmd_{i}")] mw._apply(_make_state(tool_calls=call), runtime) mw._apply(_make_state(tool_calls=call), runtime) assert len(mw._history["test-thread"]) <= 4 assert set(mw._warned["test-thread"]).issubset(set(mw._history["test-thread"])) assert len(mw._warned["test-thread"]) <= 4 def test_pending_warning_keys_are_capped(self): """Abnormal same-thread runs cannot grow pending-warning keys forever.""" mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=2) for i in range(10): runtime = _make_runtime(thread_id="same-thread", run_id=f"run-{i}") mw._queue_pending_warning(runtime, f"warning-{i}") assert len(mw._pending_warnings) == mw._max_pending_warning_keys assert len(mw._pending_warning_touch_order) == mw._max_pending_warning_keys assert _pending_key("same-thread", "run-9") in mw._pending_warnings def test_pending_warning_list_is_capped_and_deduped(self): """One run cannot accumulate an unbounded warning list.""" mw = LoopDetectionMiddleware() runtime = _make_runtime() for i in range(_MAX_PENDING_WARNINGS_PER_RUN + 4): mw._queue_pending_warning(runtime, f"warning-{i}") mw._queue_pending_warning(runtime, f"warning-{_MAX_PENDING_WARNINGS_PER_RUN + 3}") warnings = mw._pending_warnings[_pending_key()] assert len(warnings) == _MAX_PENDING_WARNINGS_PER_RUN assert warnings == [f"warning-{i}" for i in range(4, _MAX_PENDING_WARNINGS_PER_RUN + 4)] def test_pending_warning_touch_order_cleared_with_pending_key(self): mw = LoopDetectionMiddleware() runtime = _make_runtime() mw._queue_pending_warning(runtime, "warning") mw.after_agent({"messages": []}, runtime) assert mw._pending_warnings == {} assert mw._pending_warning_touch_order == OrderedDict() def test_thread_safe_mutations(self): """Verify lock is used for mutations (basic structural test).""" mw = LoopDetectionMiddleware() # The middleware should have a lock attribute assert hasattr(mw, "_lock") assert isinstance(mw._lock, type(mw._lock)) def test_fallback_thread_id_when_missing(self): """When runtime context has no thread_id, should use 'default'.""" mw = LoopDetectionMiddleware(warn_threshold=2) runtime = MagicMock() runtime.context = {} call = [_bash_call("ls")] mw._apply(_make_state(tool_calls=call), runtime) assert "default" in mw._history class TestLoopDetectionAgentGraphIntegration: def test_loop_warning_is_transient_in_real_agent_graph(self): """after_model queues the warning; wrap_model_call injects it request-only.""" @as_tool def bash(command: str) -> str: """Run a fake shell command.""" return f"ran: {command}" repeated_calls = [[{"name": "bash", "id": f"call_ls_{i}", "args": {"command": "ls"}}] for i in range(3)] mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) model = _CapturingFakeMessagesListChatModel( responses=[ AIMessage(content="", tool_calls=repeated_calls[0]), AIMessage(content="", tool_calls=repeated_calls[1]), AIMessage(content="", tool_calls=repeated_calls[2]), AIMessage(content="final answer"), ], ) graph = create_agent(model=model, tools=[bash], middleware=[mw]) result = graph.invoke( {"messages": [("user", "inspect the directory")]}, context={"thread_id": "integration-thread", "run_id": "integration-run"}, config={"recursion_limit": 20}, ) assert len(model.seen_messages) == 4 loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages] assert loop_warnings_by_call[0] == [] assert loop_warnings_by_call[1] == [] assert loop_warnings_by_call[2] == [] assert len(loop_warnings_by_call[3]) == 1 assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content fourth_request = model.seen_messages[3] assert isinstance(fourth_request[-2], ToolMessage) assert fourth_request[-2].tool_call_id == "call_ls_2" assert fourth_request[-1] is loop_warnings_by_call[3][0] persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"] assert persisted_loop_warnings == [] assert result["messages"][-1].content == "final answer" assert mw._pending_warnings == {} assert mw._pending_warning_touch_order == OrderedDict() @pytest.mark.asyncio async def test_loop_warning_is_transient_in_async_agent_graph(self): """awrap_model_call injects loop_warning request-only in async graph runs.""" @as_tool async def bash(command: str) -> str: """Run a fake shell command.""" return f"ran: {command}" repeated_calls = [[{"name": "bash", "id": f"call_async_ls_{i}", "args": {"command": "ls"}}] for i in range(3)] mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) model = _CapturingFakeMessagesListChatModel( responses=[ AIMessage(content="", tool_calls=repeated_calls[0]), AIMessage(content="", tool_calls=repeated_calls[1]), AIMessage(content="", tool_calls=repeated_calls[2]), AIMessage(content="async final answer"), ], ) graph = create_agent(model=model, tools=[bash], middleware=[mw]) result = await graph.ainvoke( {"messages": [("user", "inspect the directory asynchronously")]}, context={"thread_id": "async-integration-thread", "run_id": "async-integration-run"}, config={"recursion_limit": 20}, ) assert len(model.seen_messages) == 4 loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages] assert loop_warnings_by_call[0] == [] assert loop_warnings_by_call[1] == [] assert loop_warnings_by_call[2] == [] assert len(loop_warnings_by_call[3]) == 1 assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content fourth_request = model.seen_messages[3] assert isinstance(fourth_request[-2], ToolMessage) assert fourth_request[-2].tool_call_id == "call_async_ls_2" assert fourth_request[-1] is loop_warnings_by_call[3][0] persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"] assert persisted_loop_warnings == [] assert result["messages"][-1].content == "async final answer" assert mw._pending_warnings == {} assert mw._pending_warning_touch_order == OrderedDict() class TestAppendText: """Unit tests for LoopDetectionMiddleware._append_text.""" def test_none_content_returns_text(self): result = LoopDetectionMiddleware._append_text(None, "hello") assert result == "hello" def test_str_content_concatenates(self): result = LoopDetectionMiddleware._append_text("existing", "appended") assert result == "existing\n\nappended" def test_empty_str_content_concatenates(self): result = LoopDetectionMiddleware._append_text("", "appended") assert result == "\n\nappended" def test_list_content_appends_text_block(self): """List content (e.g. Anthropic thinking mode) should get a new text block.""" content = [ {"type": "thinking", "text": "Let me think..."}, {"type": "text", "text": "Here is my answer"}, ] result = LoopDetectionMiddleware._append_text(content, "stop msg") assert isinstance(result, list) assert len(result) == 3 assert result[0] == content[0] assert result[1] == content[1] assert result[2] == {"type": "text", "text": "\n\nstop msg"} def test_empty_list_content_appends_text_block(self): result = LoopDetectionMiddleware._append_text([], "stop msg") assert isinstance(result, list) assert len(result) == 1 assert result[0] == {"type": "text", "text": "\n\nstop msg"} def test_unexpected_type_coerced_to_str(self): """Unexpected content types should be coerced to str as a fallback.""" result = LoopDetectionMiddleware._append_text(42, "stop msg") assert isinstance(result, str) assert result == "42\n\nstop msg" def test_list_content_not_mutated_in_place(self): """_append_text must not modify the original list.""" original = [{"type": "text", "text": "hello"}] result = LoopDetectionMiddleware._append_text(original, "appended") assert len(original) == 1 # original unchanged assert len(result) == 2 # new list has the appended block class TestHardStopWithListContent: """Regression tests: hard stop must not crash when AIMessage.content is a list.""" def test_hard_stop_with_list_content(self): """Hard stop on list content should not raise TypeError (regression).""" mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) runtime = _make_runtime() call = [_bash_call("ls")] # Build state with list content (e.g. Anthropic thinking mode) list_content = [ {"type": "thinking", "text": "Let me think..."}, {"type": "text", "text": "I'll run ls"}, ] for _ in range(3): mw._apply(_make_state(tool_calls=call, content=list_content), runtime) # Fourth call triggers hard stop — must not raise TypeError result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime) assert result is not None msg = result["messages"][0] assert isinstance(msg, AIMessage) assert msg.tool_calls == [] # Content should remain a list with the stop message appended assert isinstance(msg.content, list) assert len(msg.content) == 3 assert msg.content[2]["type"] == "text" assert _HARD_STOP_MSG in msg.content[2]["text"] def test_hard_stop_with_none_content(self): """Hard stop on None content should produce a plain string.""" mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) runtime = _make_runtime() call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) # Fourth call with default empty-string content result = mw._apply(_make_state(tool_calls=call), runtime) assert result is not None msg = result["messages"][0] assert isinstance(msg.content, str) assert _HARD_STOP_MSG in msg.content def test_hard_stop_with_str_content(self): """Hard stop on str content should concatenate the stop message.""" mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) runtime = _make_runtime() call = [_bash_call("ls")] for _ in range(3): mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime) result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime) assert result is not None msg = result["messages"][0] assert isinstance(msg.content, str) assert msg.content.startswith("thinking...") assert _HARD_STOP_MSG in msg.content def test_hard_stop_clears_raw_tool_call_metadata(self): """Forced-stop messages must not retain provider-level raw tool-call payloads.""" mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) runtime = _make_runtime() call = [_bash_call("ls")] def _make_provider_state(): return { "messages": [ AIMessage( content="thinking...", tool_calls=call, additional_kwargs={ "tool_calls": [ { "id": "call_ls", "type": "function", "function": {"name": "bash", "arguments": '{"command":"ls"}'}, "thought_signature": "sig-1", } ], "function_call": {"name": "bash", "arguments": '{"command":"ls"}'}, }, response_metadata={"finish_reason": "tool_calls"}, ) ] } for _ in range(3): mw._apply(_make_provider_state(), runtime) result = mw._apply(_make_provider_state(), runtime) assert result is not None msg = result["messages"][0] assert msg.tool_calls == [] assert "tool_calls" not in msg.additional_kwargs assert "function_call" not in msg.additional_kwargs assert msg.response_metadata["finish_reason"] == "stop" class TestToolFrequencyDetection: """Tests for per-tool-type frequency detection (Layer 2). This catches the case where an agent calls the same tool type many times with *different* arguments (e.g. read_file on 40 different files), which bypasses hash-based detection. """ def _read_call(self, path): return {"name": "read_file", "id": f"call_read_{path}", "args": {"path": path}} def test_below_freq_warn_returns_none(self): mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10) runtime = _make_runtime() for i in range(4): result = mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) assert result is None def test_freq_warn_at_threshold(self): mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10) runtime = _make_runtime() for i in range(4): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) # 5th call queues a per-tool-type frequency warning; state untouched. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime) assert result is None queued = mw._pending_warnings.get(_pending_key(), []) assert queued assert "read_file" in queued[0] assert "LOOP DETECTED" in queued[0] def test_freq_warn_only_queued_once(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) runtime = _make_runtime() for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) # 3rd queues a frequency warning. mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) assert len(mw._pending_warnings[_pending_key()]) == 1 # 4th: same tool name, no additional enqueue. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime) assert result is None assert len(mw._pending_warnings[_pending_key()]) == 1 def test_freq_hard_stop_at_limit(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6) runtime = _make_runtime() for i in range(5): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) # 6th call triggers hard stop result = mw._apply(_make_state(tool_calls=[self._read_call("/file_5.py")]), runtime) assert result is not None msg = result["messages"][0] assert isinstance(msg, AIMessage) assert msg.tool_calls == [] assert "FORCED STOP" in msg.content assert "read_file" in msg.content def test_different_tools_tracked_independently(self): """read_file and bash should have independent frequency counters.""" mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) runtime = _make_runtime() # 2 read_file calls for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) # 2 bash calls — should not trigger (bash count = 2, read_file count = 2) for i in range(2): result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime) assert result is None # 3rd read_file triggers — warning is queued (state unchanged). result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) assert result is None assert "read_file" in mw._pending_warnings[_pending_key()][0] def test_freq_reset_clears_state(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) runtime = _make_runtime() for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) mw.reset() # After reset, count restarts — should not trigger result = mw._apply(_make_state(tool_calls=[self._read_call("/file_new.py")]), runtime) assert result is None def test_freq_reset_per_thread_clears_only_target(self): """reset(thread_id=...) should clear frequency state for that thread only.""" mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) runtime_a = _make_runtime("thread-A") runtime_b = _make_runtime("thread-B") # 2 calls on each thread for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/a_{i}.py")]), runtime_a) mw._apply(_make_state(tool_calls=[self._read_call(f"/b_{i}.py")]), runtime_b) # Reset only thread-A mw.reset(thread_id="thread-A") assert "thread-A" not in mw._tool_freq assert "thread-A" not in mw._tool_freq_warned # thread-B state should still be intact — 3rd call queues a warn. result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b) assert result is None assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0] # thread-A restarted from 0 — should not trigger result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a) assert result is None def test_freq_per_thread_isolation(self): """Frequency counts should be independent per thread.""" mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) runtime_a = _make_runtime("thread-A") runtime_b = _make_runtime("thread-B") # 2 calls on thread A for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime_a) # 2 calls on thread B — should NOT push thread A over threshold for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b) # 3rd call on thread A — queues a warning (count=3 for thread A only). result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a) assert result is None assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0] assert not mw._pending_warnings.get(_pending_key("thread-B")) def test_multi_tool_single_response_counted(self): """When a single response has multiple tool calls, each is counted.""" mw = LoopDetectionMiddleware(tool_freq_warn=5, tool_freq_hard_limit=10) runtime = _make_runtime() # Response 1: 2 read_file calls → count = 2 call = [self._read_call("/a.py"), self._read_call("/b.py")] result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None # Response 2: 2 more → count = 4 call = [self._read_call("/c.py"), self._read_call("/d.py")] result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None # Response 3: 1 more → count = 5 → queues warn. result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime) assert result is None assert "read_file" in mw._pending_warnings[_pending_key()][0] def test_override_tool_uses_override_thresholds(self): """A tool in tool_freq_overrides uses its own thresholds, not the global ones.""" mw = LoopDetectionMiddleware( tool_freq_warn=5, tool_freq_hard_limit=10, tool_freq_overrides={"bash": (50, 100)}, ) runtime = _make_runtime() # 10 bash calls — would hit global hard_limit=10, but bash override is 100 for i in range(10): result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime) assert result is None, f"unexpected trigger on call {i + 1}" def test_non_override_tool_falls_back_to_global(self): """A tool NOT in tool_freq_overrides uses the global warn/hard_limit.""" mw = LoopDetectionMiddleware( tool_freq_warn=3, tool_freq_hard_limit=6, tool_freq_overrides={"bash": (50, 100)}, ) runtime = _make_runtime() for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) # 3rd read_file call hits global warn=3 (read_file has no override). # Warning delivery is deferred to wrap_model_call so the just-emitted # AIMessage(tool_calls=...) is not mutated before ToolMessages exist. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) assert result is None queued = mw._pending_warnings.get(_pending_key(), []) assert queued assert "read_file" in queued[0] def test_hash_detection_takes_priority(self): """Hash-based hard stop fires before frequency check for identical calls.""" mw = LoopDetectionMiddleware( warn_threshold=2, hard_limit=3, tool_freq_warn=100, tool_freq_hard_limit=200, ) runtime = _make_runtime() call = [self._read_call("/same_file.py")] for _ in range(2): mw._apply(_make_state(tool_calls=call), runtime) # 3rd identical call → hash hard_limit=3 fires (not freq) result = mw._apply(_make_state(tool_calls=call), runtime) assert result is not None msg = result["messages"][0] assert isinstance(msg, AIMessage) assert _HARD_STOP_MSG in msg.content class TestFromConfig: """Tests for LoopDetectionMiddleware.from_config — the sole validated construction path.""" @staticmethod def _config(**kwargs): from deerflow.config.loop_detection_config import LoopDetectionConfig return LoopDetectionConfig(**kwargs) def test_scalar_fields_mapped(self): config = self._config( warn_threshold=4, hard_limit=8, window_size=15, max_tracked_threads=50, tool_freq_warn=20, tool_freq_hard_limit=40, ) mw = LoopDetectionMiddleware.from_config(config) assert mw.warn_threshold == 4 assert mw.hard_limit == 8 assert mw.window_size == 15 assert mw.max_tracked_threads == 50 assert mw.tool_freq_warn == 20 assert mw.tool_freq_hard_limit == 40 def test_overrides_converted_to_tuples(self): config = self._config(tool_freq_overrides={"bash": {"warn": 50, "hard_limit": 100}}) mw = LoopDetectionMiddleware.from_config(config) assert mw._tool_freq_overrides == {"bash": (50, 100)} def test_empty_overrides(self): mw = LoopDetectionMiddleware.from_config(self._config()) assert mw._tool_freq_overrides == {} def test_constructed_middleware_queues_loop_warning(self): mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4)) runtime = _make_runtime() call = [_bash_call("ls")] mw._apply(_make_state(tool_calls=call), runtime) result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None queued = mw._pending_warnings.get(_pending_key(), []) assert queued assert "LOOP DETECTED" in queued[0]