feat(loop-detection): defer warning injection (#2752)

* fix(loop-detection): defer warn injection to wrap_model_call

The warn branch in LoopDetectionMiddleware injected a HumanMessage
into state from after_model. The tools node had not yet produced
ToolMessage responses to the previous AIMessage(tool_calls=...), so
the new HumanMessage landed *between* the assistant's tool_calls and
their responses. OpenAI/Moonshot reject the next request with
"tool_call_ids did not have response messages" because their
validators require tool_calls to be followed immediately by tool
messages.

Detection now runs in after_model as before, but only enqueues the
warning into a per-thread list. Injection happens in wrap_model_call,
where every prior ToolMessage is already present in request.messages.
The warning is appended at the end as HumanMessage(name="loop_warning")
— pairing intact, AIMessage semantics untouched, no SystemMessage
issues for Anthropic.
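
The ordering problem can be sketched with plain dicts. This is a simplified stand-in for the provider-side check, not the actual validator: `pairing_ok` and the message shapes are assumptions made for illustration.

```python
def pairing_ok(messages: list[dict]) -> bool:
    """Return True if every assistant tool_call is answered immediately by tool messages.

    Hypothetical helper: approximates the strict pairing rule described above.
    """
    i = 0
    while i < len(messages):
        msg = messages[i]
        i += 1
        if msg["role"] != "assistant":
            continue
        # IDs that must be answered before any non-tool message appears.
        pending = {tc["id"] for tc in msg.get("tool_calls", [])}
        while pending:
            if i >= len(messages) or messages[i]["role"] != "tool":
                return False
            pending.discard(messages[i]["tool_call_id"])
            i += 1
    return True


assistant = {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1", "name": "bash"}]}
tool = {"role": "tool", "tool_call_id": "call_1", "content": "ok"}
warning = {"role": "user", "name": "loop_warning", "content": "[LOOP DETECTED] ..."}

# Old behaviour: the warning is added in after_model, before the tools node runs.
injected_in_after_model = [assistant, warning, tool]
# New behaviour: the warning is appended in wrap_model_call, after the tool result.
deferred_to_wrap_model_call = [assistant, tool, warning]
```

Under this approximation the first list fails the check and the second passes, which matches the rejection described above.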

Closes #2029, addresses #2255 #2293 #2304 #2511.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(channels): remove loop warning display filter

* feat(loop-detection): scope pending warnings by run

* docs(loop-detection): update docs

* test(loop-detection): assert deferred warnings are queued

* fix(loop-detection): cap transient warning state

* docs: update docs

* add async awrap_model_call test coverage

* docs(loop-detection): document transient warnings

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Nan Gao
2026-05-21 08:36:07 +02:00
committed by GitHub
parent 7ec8d3a6e7
commit dcc6f1e678
7 changed files with 696 additions and 221 deletions
-15
@@ -146,13 +146,6 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
return normalized
def _strip_loop_warning_text(text: str) -> str:
"""Remove middleware-authored loop warning lines from display text."""
if "[LOOP DETECTED]" not in text:
return text
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
def _extract_response_text(result: dict | list) -> str:
"""Extract the last AI message text from a LangGraph runs.wait result.
@@ -162,7 +155,6 @@ def _extract_response_text(result: dict | list) -> str:
Handles special cases:
- Regular AI text responses
- Clarification interrupts (``ask_clarification`` tool messages)
- Strips loop-detection warnings attached to tool-call AI messages
"""
if isinstance(result, list):
messages = result
@@ -192,12 +184,7 @@ def _extract_response_text(result: dict | list) -> str:
# Regular AI message with text content
if msg_type == "ai":
content = msg.get("content", "")
has_tool_calls = bool(msg.get("tool_calls"))
if isinstance(content, str) and content:
if has_tool_calls:
content = _strip_loop_warning_text(content)
if not content:
continue
return content
# content can be a list of content blocks
if isinstance(content, list):
@@ -208,8 +195,6 @@ def _extract_response_text(result: dict | list) -> str:
elif isinstance(block, str):
parts.append(block)
text = "".join(parts)
if has_tool_calls:
text = _strip_loop_warning_text(text)
if text:
return text
return ""
+79 -65
@@ -4,22 +4,22 @@
The complete middleware chain that `create_deerflow_agent` assembles via `RuntimeFeatures` (with every feature enabled by default):
| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_tool_call` | Lead agent | Subagent | Source |
|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------|
| 0 | ThreadDataMiddleware | ✓ | | | | | ✓ | ✓ | `sandbox` |
| 1 | UploadsMiddleware | ✓ | | | | | ✓ | ✗ | `sandbox` |
| 2 | SandboxMiddleware | ✓ | | | ✓ | | ✓ | ✓ | `sandbox` |
| 3 | DanglingToolCallMiddleware | | | | | | ✓ | ✗ | Always on |
| 4 | GuardrailMiddleware | | | | | ✓ | ✓ | ✓ | *Added in Phase 2* |
| 5 | ToolErrorHandlingMiddleware | | | | | ✓ | ✓ | ✓ | Always on |
| 6 | SummarizationMiddleware | | | | | | ✓ | ✗ | `summarization` |
| 7 | TodoMiddleware | | | ✓ | | | ✓ | ✗ | `plan_mode` parameter |
| 8 | TitleMiddleware | | | ✓ | | | ✓ | ✗ | `auto_title` |
| 9 | MemoryMiddleware | | | | ✓ | | ✓ | ✗ | `memory` |
| 10 | ViewImageMiddleware | | ✓ | | | | ✓ | ✗ | `vision` |
| 11 | SubagentLimitMiddleware | | | ✓ | | | ✓ | ✗ | `subagent` |
| 12 | LoopDetectionMiddleware | | | ✓ | | | ✓ | ✗ | Always on |
| 13 | ClarificationMiddleware | | | | | | ✓ | ✗ | Always last |
| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_model_call` | `wrap_tool_call` | Lead agent | Subagent | Source |
|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------|
| 0 | ThreadDataMiddleware | ✓ | | | | | | ✓ | ✓ | `sandbox` |
| 1 | UploadsMiddleware | ✓ | | | | | | ✓ | ✗ | `sandbox` |
| 2 | SandboxMiddleware | ✓ | | | ✓ | | | ✓ | ✓ | `sandbox` |
| 3 | DanglingToolCallMiddleware | | | | | | | ✓ | ✗ | Always on |
| 4 | GuardrailMiddleware | | | | | | ✓ | ✓ | ✓ | *Added in Phase 2* |
| 5 | ToolErrorHandlingMiddleware | | | | | | ✓ | ✓ | ✓ | Always on |
| 6 | SummarizationMiddleware | | | | | | | ✓ | ✗ | `summarization` |
| 7 | TodoMiddleware | | | ✓ | | ✓ | | ✓ | ✗ | `plan_mode` parameter |
| 8 | TitleMiddleware | | | ✓ | | | | ✓ | ✗ | `auto_title` |
| 9 | MemoryMiddleware | | | | ✓ | | | ✓ | ✗ | `memory` |
| 10 | ViewImageMiddleware | | ✓ | | | | | ✓ | ✗ | `vision` |
| 11 | SubagentLimitMiddleware | | | ✓ | | | | ✓ | ✗ | `subagent` |
| 12 | LoopDetectionMiddleware | | | ✓ | ✓ | ✓ | | ✓ | ✗ | Always on |
| 13 | ClarificationMiddleware | | | | | | | ✓ | ✗ | Always last |
The lead agent has **14** middlewares (`make_lead_agent`); a subagent has **4** (ThreadData, Sandbox, Guardrail, ToolErrorHandling). Phase 1 of `create_deerflow_agent` implements **13** (Guardrail only supports a custom instance and has no built-in default).
@@ -35,7 +35,7 @@ graph TB
subgraph BA ["<b>before_agent</b> forward order 0→N"]
direction TB
TD["[0] ThreadData<br/>create thread directory"] --> UL["[1] Uploads<br/>scan uploaded files"] --> SB["[2] Sandbox<br/>acquire sandbox"]
TD["[0] ThreadData<br/>create thread directory"] --> UL["[1] Uploads<br/>scan uploaded files"] --> SB["[2] Sandbox<br/>acquire sandbox"] --> LD_BA["[12] LoopDetection<br/>clear stale warnings"]
end
subgraph BM ["<b>before_model</b> forward order 0→N"]
@@ -43,34 +43,42 @@ graph TB
VI["[10] ViewImage<br/>inject image base64"]
end
SB --> VI
VI --> M["<b>MODEL</b>"]
subgraph WM ["<b>wrap_model_call</b>"]
direction TB
DTC_WM["[3] DanglingToolCall<br/>patch dangling ToolMessage"] --> LD_WM["[12] LoopDetection<br/>inject current-run warnings"]
end
LD_BA --> VI
VI --> DTC_WM
LD_WM --> M["<b>MODEL</b>"]
subgraph AM ["<b>after_model</b> reverse order N→0"]
direction TB
CL["[13] Clarification<br/>intercept ask_clarification"] --> LD["[12] LoopDetection<br/>detect loops"] --> SL["[11] SubagentLimit<br/>truncate excess task calls"] --> TI["[8] Title<br/>generate title"] --> SM["[6] Summarization<br/>compress context"] --> DTC["[3] DanglingToolCall<br/>patch missing ToolMessage"]
LD["[12] LoopDetection<br/>detect loops / queue warnings"] --> SL["[11] SubagentLimit<br/>truncate excess task calls"] --> TI["[8] Title<br/>generate title"]
end
M --> CL
M --> LD
subgraph AA ["<b>after_agent</b> reverse order N→0"]
direction TB
SBR["[2] Sandbox<br/>release sandbox"] --> MEM["[9] Memory<br/>enqueue memory"]
LD_CLEAN["[12] LoopDetection<br/>clear pending warnings"] --> MEM["[9] Memory<br/>enqueue memory"] --> SBR["[2] Sandbox<br/>release sandbox"]
end
DTC --> SBR
MEM --> END(["response"])
TI --> LD_CLEAN
SBR --> END(["response"])
classDef beforeNode fill:#a0a8b5,stroke:#636b7a,color:#2d3239
classDef modelNode fill:#b5a8a0,stroke:#7a6b63,color:#2d3239
classDef wrapModelNode fill:#a8a0b5,stroke:#6b637a,color:#2d3239
classDef afterModelNode fill:#b5a0a8,stroke:#7a636b,color:#2d3239
classDef afterAgentNode fill:#a0b5a8,stroke:#637a6b,color:#2d3239
classDef terminalNode fill:#a8b5a0,stroke:#6b7a63,color:#2d3239
class TD,UL,SB,VI beforeNode
class TD,UL,SB,LD_BA,VI beforeNode
class DTC_WM,LD_WM wrapModelNode
class M modelNode
class CL,LD,SL,TI,SM,DTC afterModelNode
class SBR,MEM afterAgentNode
class LD,SL,TI afterModelNode
class LD_CLEAN,SBR,MEM afterAgentNode
class START,END terminalNode
```
@@ -82,13 +90,12 @@ sequenceDiagram
participant TD as ThreadDataMiddleware
participant UL as UploadsMiddleware
participant SB as SandboxMiddleware
participant LD as LoopDetectionMiddleware
participant VI as ViewImageMiddleware
participant DTC as DanglingToolCallMiddleware
participant M as MODEL
participant CL as ClarificationMiddleware
participant SL as SubagentLimitMiddleware
participant TI as TitleMiddleware
participant SM as SummarizationMiddleware
participant DTC as DanglingToolCallMiddleware
participant MEM as MemoryMiddleware
U ->> TD: invoke
@@ -103,19 +110,26 @@ sequenceDiagram
activate SB
Note right of SB: before_agent acquire sandbox
SB ->> VI: before_model
SB ->> LD: before_agent
activate LD
Note right of LD: before_agent clear pending warnings left by earlier runs in the same thread
LD ->> VI: before_model
activate VI
Note right of VI: before_model inject image base64
VI ->> M: messages + tools
VI ->> DTC: wrap_model_call
activate DTC
Note right of DTC: wrap_model_call patch dangling ToolMessage
DTC ->> LD: wrap_model_call
Note right of LD: wrap_model_call drain current-run warnings and append them at the end
LD ->> M: messages + tools
activate M
M -->> CL: AI response
M -->> LD: AI response
deactivate M
activate CL
Note right of CL: after_model intercept ask_clarification
CL -->> SL: after_model
deactivate CL
Note right of LD: after_model detect loops / queue warnings / hard-stop clears tool_calls
LD -->> SL: after_model
deactivate LD
activate SL
Note right of SL: after_model truncate excess task calls
@@ -124,22 +138,18 @@ sequenceDiagram
activate TI
Note right of TI: after_model generate title
TI -->> SM: after_model
TI -->> DTC: done
deactivate TI
activate SM
Note right of SM: after_model compress context
SM -->> DTC: after_model
deactivate SM
activate DTC
Note right of DTC: after_model patch missing ToolMessage
DTC -->> VI: done
deactivate DTC
VI -->> SB: done
deactivate VI
Note right of LD: after_agent clear warnings the current run did not consume
Note right of MEM: after_agent enqueue memory
Note right of SB: after_agent release sandbox
SB -->> UL: done
deactivate SB
@@ -147,8 +157,6 @@ sequenceDiagram
UL -->> TD: done
deactivate UL
Note right of MEM: after_agent enqueue memory
TD -->> U: response
deactivate TD
```
@@ -224,12 +232,12 @@ sequenceDiagram
participant TD as ThreadData
participant UL as Uploads
participant SB as Sandbox
participant LD as LoopDetection
participant VI as ViewImage
participant DTC as DanglingToolCall
participant M as MODEL
participant CL as Clarification
participant SL as SubagentLimit
participant TI as Title
participant SM as Summarization
participant MEM as Memory
U ->> TD: invoke
@@ -238,34 +246,40 @@ sequenceDiagram
Note right of UL: before_agent scan files
UL ->> SB: .
Note right of SB: before_agent acquire sandbox
SB ->> LD: .
Note right of LD: before_agent clear stale pending warnings
loop each turn (tool-call loop)
SB ->> VI: .
Note right of VI: before_model inject images
VI ->> M: messages + tools
M -->> CL: AI response
Note right of CL: after_model intercept ask_clarification
CL -->> SL: .
VI ->> DTC: .
Note right of DTC: wrap_model_call patch dangling tool results
DTC ->> LD: .
Note right of LD: wrap_model_call inject current-run warnings
LD ->> M: messages + tools
M -->> LD: AI response
Note right of LD: after_model detect loops / queue warnings
LD -->> SL: .
Note right of SL: after_model truncate excess task calls
SL -->> TI: .
Note right of TI: after_model generate title
TI -->> SM: .
Note right of SM: after_model compress context
end
Note right of SB: after_agent release sandbox
SB -->> MEM: .
Note right of LD: after_agent clear current-run pending warnings
LD -->> MEM: .
Note right of MEM: after_agent enqueue memory
MEM -->> U: response
MEM -->> SB: .
Note right of SB: after_agent release sandbox
SB -->> U: response
```
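> [!warning] Not an onion
> Of the 14 middlewares, only SandboxMiddleware is before/after symmetric (acquire/release). The rest are one-directional: they act only in `before_*` or only in `after_*`. `before_agent` / `after_agent` run once; `before_model` / `after_model` run on every loop iteration.
> Most middlewares use a single phase. SandboxMiddleware uses `before_agent`/`after_agent` for resource acquisition and release; LoopDetectionMiddleware uses the same two hooks, but to clear run-scoped pending warnings rather than for a symmetric resource lifecycle. `before_agent` / `after_agent` run once; `before_model` / `after_model` / `wrap_model_call` run on every loop iteration.
There are only 2 hard dependencies:
1. **ThreadData before Sandbox**: the sandbox needs the thread directory
2. **Clarification last in the list**: in reverse-order `after_model` it runs first, so it is the first to intercept `ask_clarification`
2. **Clarification last in the list**: it takes priority when `wrap_tool_call` handles `ask_clarification`, and it interrupts execution via `Command(goto=END)`
### Conclusion
@@ -273,19 +287,19 @@ sequenceDiagram
|---|---|---|
| Each middleware | before + after symmetric | Mostly uses a single hook |
| Activation bars | Nested (outer long, inner short) | Not nested (serial) |
| Purpose of reverse order | Pairs cleanup with initialization | Affects after_model execution priority |
| Purpose of reverse order | Pairs cleanup with initialization | Affects `after_model` / `after_agent` execution priority |
| Typical example | Auth: validate token / clean up context | ThreadData: only creates the directory, no cleanup |
## Key design points
### Why is ClarificationMiddleware last in the list?
Being last means its `after_model` runs first. It needs to be the **first** to see the model output and check for an `ask_clarification` tool call. If there is one, it interrupts immediately (`Command(goto=END)`), and the `after_model` hooks of the remaining middlewares do not run.
Being last lets it intercept `ask_clarification` first in the tool-call wrapping chain. On a match it returns `Command(goto=END)`, writes the formatted clarification question as a `ToolMessage`, and interrupts execution.
### SandboxMiddleware symmetry
`before_agent` (3rd in forward order) acquires the sandbox, and `after_agent` (1st in reverse order) releases it. Outer enter → outer exit: a natural onion symmetry.
### Most middlewares use only one hook
### Why LoopDetectionMiddleware uses several hooks
Of the 14 middlewares, only SandboxMiddleware uses both `before_agent` and `after_agent` (acquire/release). The rest run in a single phase. The reverse-order property of the onion model mainly affects execution order in the `after_model` phase.
`after_model` only detects: when repeated tool calls reach the warning threshold, it puts the warning into a pending queue scoped by `(thread_id, run_id)`. The actual injection happens at the next `wrap_model_call`: by then the `ToolMessage` responses to the previous `AIMessage(tool_calls)` are already in the request, and the warning is appended at the end, so it does not break OpenAI/Moonshot tool-call pairing. `before_agent` clears leftover warnings from earlier runs in the same thread, and `after_agent` clears warnings the current run never consumed.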
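
A minimal sketch of the run-scoped queue described above, in plain Python rather than the middleware's real classes. `PendingWarnings`, its method names, and `MAX_PER_RUN` are hypothetical names chosen for illustration; only the `(thread_id, run_id)` key, the per-run cap, and the enqueue/drain/clear behaviour are taken from the change itself.

```python
import threading
from collections import defaultdict

MAX_PER_RUN = 4  # assumed to mirror the per-run cap introduced by this change


class PendingWarnings:
    """Hypothetical stand-in for the middleware's pending-warning state."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._queues: dict[tuple[str, str], list[str]] = defaultdict(list)

    def enqueue(self, thread_id: str, run_id: str, warning: str) -> None:
        """after_model: record a warning without touching the message state."""
        with self._lock:
            queue = self._queues[(thread_id, run_id)]
            if warning not in queue:
                queue.append(warning)
            if len(queue) > MAX_PER_RUN:
                # Keep only the most recent warnings.
                del queue[: len(queue) - MAX_PER_RUN]

    def drain(self, thread_id: str, run_id: str) -> list[str]:
        """wrap_model_call: pop everything queued for this run."""
        with self._lock:
            return self._queues.pop((thread_id, run_id), [])

    def drop_other_runs(self, thread_id: str, run_id: str) -> None:
        """before_agent: discard warnings left behind by earlier runs."""
        with self._lock:
            for key in list(self._queues):
                if key[0] == thread_id and key[1] != run_id:
                    del self._queues[key]


queue = PendingWarnings()
queue.enqueue("t1", "run-1", "[LOOP DETECTED] bash ls")
queue.enqueue("t1", "run-1", "[LOOP DETECTED] bash ls")  # duplicate is ignored
first_drain = queue.drain("t1", "run-1")
second_drain = queue.drain("t1", "run-1")  # already consumed

queue.enqueue("t1", "run-1", "stale warning")
queue.drop_other_runs("t1", "run-2")  # a new run starts on the same thread
after_new_run = queue.drain("t1", "run-1")
```

Keying by run as well as thread is what keeps a warning from leaking into a later invocation on the same thread. A thread-only key would need some other signal to tell stale entries from live ones.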
@@ -6,10 +6,36 @@ arguments indefinitely until the recursion limit kills the run.
Detection strategy:
1. After each model response, hash the tool calls (name + args).
2. Track recent hashes in a sliding window.
3. If the same hash appears >= warn_threshold times, inject a
"you are repeating yourself — wrap up" system message (once per hash).
3. If the same hash appears >= warn_threshold times, queue a
"you are repeating yourself — wrap up" warning for the current
thread/run. The warning is **injected at the next model call** (in
``wrap_model_call``) as a ``HumanMessage`` appended to the message
list, *after* all ToolMessage responses to the previous
AIMessage(tool_calls).
4. If it appears >= hard_limit times, strip all tool_calls from the
response so the agent is forced to produce a final text answer.
Why the warning is injected at ``wrap_model_call`` instead of
``after_model``:
``after_model`` fires immediately after the model emits an
``AIMessage`` that may carry ``tool_calls``. The tools node has not
run yet, so no matching ``ToolMessage`` exists in the history. Any
message we add here lands *between* the assistant's tool_calls and
their responses. OpenAI/Moonshot reject the next request with
``"tool_call_ids did not have response messages"`` because their
validators require the assistant's tool_calls to be followed
immediately by tool messages. Anthropic also disallows mid-stream
``SystemMessage``. By deferring the warning to ``wrap_model_call``,
every prior ToolMessage is already present in the request's message
list and the warning is appended at the end — pairing intact, no
``AIMessage`` semantics are mutated.
Queued warnings are intentionally transient. If a run ends before the
next model request drains a queued warning, ``after_agent`` drops it
instead of carrying it into a later invocation for the same thread. The
hard-stop path still forces termination when the configured safety limit
is reached.
"""
from __future__ import annotations
@@ -19,11 +45,14 @@ import json
import logging
import threading
from collections import OrderedDict, defaultdict
from collections.abc import Awaitable, Callable
from copy import deepcopy
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
if TYPE_CHECKING:
@@ -38,6 +67,7 @@ _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
_MAX_PENDING_WARNINGS_PER_RUN = 4
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
@@ -195,6 +225,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._warned: dict[str, set[str]] = defaultdict(set)
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
# Per-thread/run queue of warnings to inject at the next model call.
# Populated by ``after_model`` (detection) and drained by
# ``wrap_model_call`` (injection); see module docstring.
self._pending_warnings: dict[tuple[str, str], list[str]] = defaultdict(list)
self._pending_warning_touch_order: OrderedDict[tuple[str, str], None] = OrderedDict()
self._max_pending_warning_keys = max(1, self.max_tracked_threads * 2)
@classmethod
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
@@ -213,9 +249,20 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
return str(thread_id)
return "default"
def _get_run_id(self, runtime: Runtime) -> str:
"""Extract run_id from runtime context for per-run warning scoping."""
run_id = runtime.context.get("run_id") if runtime.context else None
if run_id:
return str(run_id)
return "default"
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
"""Return the pending-warning key for the current thread/run."""
return self._get_thread_id(runtime), self._get_run_id(runtime)
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
@@ -226,8 +273,52 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._warned.pop(evicted_id, None)
self._tool_freq.pop(evicted_id, None)
self._tool_freq_warned.pop(evicted_id, None)
for key in list(self._pending_warnings):
if key[0] == evicted_id:
self._drop_pending_warning_key_locked(key)
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
def _drop_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
"""Drop all pending-warning bookkeeping for one thread/run key.
Must be called while holding self._lock.
"""
self._pending_warnings.pop(key, None)
self._pending_warning_touch_order.pop(key, None)
def _touch_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
"""Mark a pending-warning key as recently used.
Must be called while holding self._lock.
"""
self._pending_warning_touch_order[key] = None
self._pending_warning_touch_order.move_to_end(key)
def _prune_pending_warning_state_locked(self, protected_key: tuple[str, str]) -> None:
"""Cap pending-warning state across abnormal or concurrent runs.
Must be called while holding self._lock.
"""
overflow = len(self._pending_warning_touch_order) - self._max_pending_warning_keys
if overflow <= 0:
return
candidates = [key for key in self._pending_warning_touch_order if key != protected_key]
for key in candidates[:overflow]:
self._drop_pending_warning_key_locked(key)
def _queue_pending_warning(self, runtime: Runtime, warning: str) -> None:
"""Queue one transient warning for the current thread/run with caps."""
pending_key = self._pending_key(runtime)
with self._lock:
warnings = self._pending_warnings[pending_key]
if warning not in warnings:
warnings.append(warning)
if len(warnings) > _MAX_PENDING_WARNINGS_PER_RUN:
del warnings[: len(warnings) - _MAX_PENDING_WARNINGS_PER_RUN]
self._touch_pending_warning_key_locked(pending_key)
self._prune_pending_warning_state_locked(protected_key=pending_key)
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
"""Track tool calls and check for loops.
@@ -268,6 +359,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
if len(history) > self.window_size:
history[:] = history[-self.window_size :]
warned_hashes = self._warned.get(thread_id)
if warned_hashes is not None:
warned_hashes.intersection_update(history)
if not warned_hashes:
self._warned.pop(thread_id, None)
count = history.count(call_hash)
tool_names = [tc.get("name", "?") for tc in tool_calls]
@@ -381,7 +478,10 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
warning, hard_stop = self._track_and_check(state, runtime)
if hard_stop:
# Strip tool_calls from the last AIMessage to force text output
# Strip tool_calls from the last AIMessage to force text output.
# Once tool_calls are stripped, the AIMessage no longer requires
# matching ToolMessage responses, so mutating it in place here
# is safe for OpenAI/Moonshot pairing validators.
messages = state.get("messages", [])
last_msg = messages[-1]
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
@@ -389,33 +489,48 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return {"messages": [stripped_msg]}
if warning:
# WORKAROUND for v2.0-m1 — see #2724.
#
# Append the warning to the AIMessage content instead of
# injecting a separate HumanMessage. Inserting any non-tool
# message between an AIMessage(tool_calls=...) and its
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
# validation ("tool_call_ids did not have response messages")
# because the tools node has not run yet at after_model time.
# tool_calls are preserved so the tools node still executes.
#
# This is a temporary mitigation: mutating an existing
# AIMessage to carry framework-authored text leaks loop-warning
# text into downstream consumers (MemoryMiddleware fact
# extraction, TitleMiddleware, telemetry, model replay) as if
# the model said it. The proper fix is to defer warning
# injection from after_model to wrap_model_call so every prior
# ToolMessage is already in the request — see RFC #2517 (which
# lists "loop intervention does not leave invalid
# tool-call/tool-message state" as acceptance criteria) and
# the prototype on `fix/loop-detection-tool-call-pairing`.
messages = state.get("messages", [])
last_msg = messages[-1]
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
return {"messages": [patched_msg]}
# Defer injection to the next model call. We must NOT alter the
# AIMessage(tool_calls=...) here (would put framework words in
# the model's mouth, polluting downstream consumers like
# MemoryMiddleware), nor insert a separate non-tool message
# (would break OpenAI/Moonshot tool-call pairing because the
# tools node has not produced ToolMessage responses yet). The
# warning is delivered via ``wrap_model_call`` below.
self._queue_pending_warning(runtime, warning)
return None
return None
def _clear_other_run_pending_warnings(self, runtime: Runtime) -> None:
"""Drop stale pending warnings for previous runs in this thread."""
thread_id, current_run_id = self._pending_key(runtime)
with self._lock:
for key in list(self._pending_warnings):
if key[0] == thread_id and key[1] != current_run_id:
self._drop_pending_warning_key_locked(key)
def _clear_current_run_pending_warnings(self, runtime: Runtime) -> None:
"""Drop pending warnings owned by the current thread/run."""
pending_key = self._pending_key(runtime)
with self._lock:
self._drop_pending_warning_key_locked(pending_key)
@staticmethod
def _format_warning_message(warnings: list[str]) -> str:
"""Merge pending warnings into one prompt message."""
deduped = list(dict.fromkeys(warnings))
return "\n\n".join(deduped)
@override
def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
self._clear_other_run_pending_warnings(runtime)
return None
@override
async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
self._clear_other_run_pending_warnings(runtime)
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@@ -424,6 +539,59 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
def after_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
self._clear_current_run_pending_warnings(runtime)
return None
@override
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
self._clear_current_run_pending_warnings(runtime)
return None
def _drain_pending_warnings(self, runtime: Runtime) -> list[str]:
"""Pop and return all queued warnings for *runtime*'s thread/run."""
pending_key = self._pending_key(runtime)
with self._lock:
warnings = self._pending_warnings.pop(pending_key, [])
self._pending_warning_touch_order.pop(pending_key, None)
return warnings
def _augment_request(self, request: ModelRequest) -> ModelRequest:
"""Append queued loop warnings (if any) to the outgoing message list.
The warning is placed *after* every existing message, including the
ToolMessage responses to the previous AIMessage(tool_calls). This
keeps ``assistant tool_calls -> tool_messages`` pairing intact for
OpenAI/Moonshot, avoids the Anthropic mid-stream SystemMessage
restriction (we use HumanMessage), and never mutates an existing
AIMessage.
"""
warnings = self._drain_pending_warnings(request.runtime)
if not warnings:
return request
new_messages = [
*request.messages,
HumanMessage(content=self._format_warning_message(warnings), name="loop_warning"),
]
return request.override(messages=new_messages)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(self._augment_request(request))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._augment_request(request))
def reset(self, thread_id: str | None = None) -> None:
"""Clear tracking state. If thread_id given, clear only that thread."""
with self._lock:
@@ -432,8 +600,13 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._warned.pop(thread_id, None)
self._tool_freq.pop(thread_id, None)
self._tool_freq_warned.pop(thread_id, None)
for key in list(self._pending_warnings):
if key[0] == thread_id:
self._drop_pending_warning_key_locked(key)
else:
self._history.clear()
self._warned.clear()
self._tool_freq.clear()
self._tool_freq_warned.clear()
self._pending_warnings.clear()
self._pending_warning_touch_order.clear()
-31
@@ -372,37 +372,6 @@ class TestExtractResponseText:
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
assert _extract_response_text(result) == ""
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
"""Loop-detection warning text on a tool-calling AI message is middleware-authored."""
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "search the repo"},
{
"type": "ai",
"content": "[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == ""
def test_preserves_visible_text_when_stripping_loop_warning(self):
from app.channels.manager import _extract_response_text
result = {
"messages": [
{"type": "human", "content": "prepare the report"},
{
"type": "ai",
"content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
"tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
},
]
}
assert _extract_response_text(result) == "Here is the report."
# ---------------------------------------------------------------------------
# ChannelManager tests
+412 -82
@@ -1,24 +1,94 @@
"""Tests for LoopDetectionMiddleware."""
import copy
from collections import OrderedDict
from typing import Any
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, SystemMessage
import pytest
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
from pydantic import PrivateAttr
from deerflow.agents.middlewares.loop_detection_middleware import (
_HARD_STOP_MSG,
_MAX_PENDING_WARNINGS_PER_RUN,
LoopDetectionMiddleware,
_hash_tool_calls,
)
def _make_runtime(thread_id="test-thread"):
def _make_runtime(thread_id="test-thread", run_id="test-run"):
"""Build a minimal Runtime mock with context."""
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
runtime.context = {"thread_id": thread_id, "run_id": run_id}
return runtime
def _pending_key(thread_id="test-thread", run_id="test-run"):
return (thread_id, run_id)
def _make_request(messages, runtime):
"""Build a minimal ModelRequest stand-in for wrap_model_call tests."""
request = MagicMock()
request.messages = list(messages)
request.runtime = runtime
request.override = lambda **updates: _override_request(request, updates)
return request
def _override_request(request, updates):
"""Mimic ModelRequest.override(): return a copy with fields replaced."""
new = MagicMock()
new.messages = updates.get("messages", request.messages)
new.runtime = updates.get("runtime", request.runtime)
new.override = lambda **u: _override_request(new, u)
return new
def _capture_handler():
"""Build a sync handler that records the request it was called with."""
captured: list = []
def handler(req):
captured.append(req)
return MagicMock()
return captured, handler
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
"""Fake chat model that records each model request's messages."""
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
@property
def seen_messages(self) -> list[list[Any]]:
return self._seen_messages
def bind_tools(
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self._seen_messages.append(list(messages))
return super()._generate(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
def _make_state(tool_calls=None, content=""):
"""Build a minimal AgentState dict with an AIMessage.
@@ -138,7 +208,15 @@ class TestLoopDetection:
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
def test_warn_at_threshold(self):
def test_warn_at_threshold_queues_but_does_not_mutate_state(self):
"""At warn threshold, ``after_model`` enqueues but returns None.
Detection observes the just-emitted AIMessage(tool_calls=...). The
tools node hasn't run yet, so injecting any non-tool message here
would split the assistant's tool_calls from their ToolMessage
responses and break OpenAI/Moonshot pairing. The warning is
delivered later from ``wrap_model_call``.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
runtime = _make_runtime()
call = [_bash_call("ls")]
@@ -146,44 +224,150 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third identical call triggers warning. The warning is appended to
# the AIMessage content (tool_calls preserved) — never inserted as a
# separate HumanMessage between the AIMessage(tool_calls) and its
# ToolMessage responses, which would break OpenAI/Moonshot strict
# tool-call pairing validation.
# Third identical call triggers warning detection.
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
assert "LOOP DETECTED" in msgs[0].content
# Detection must not mutate state — the AIMessage with tool_calls is
# left untouched so the tools node runs normally.
assert result is None
# ...but a warning is queued for the next model call.
assert mw._pending_warnings[_pending_key()]
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key()][0]
def test_warn_does_not_break_tool_call_pairing(self):
"""Regression: the warn branch must NOT inject a non-tool message
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
request with 'tool_call_ids did not have response messages' if any
non-tool message is wedged between the AIMessage and its ToolMessage
responses. See #2029.
def test_warn_injected_at_next_model_call(self):
"""``wrap_model_call`` appends a HumanMessage(loop_warning) to the
outgoing messages — *after* every existing message — so that the
AIMessage(tool_calls=...) -> ToolMessage(...) pairing stays intact.
"""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(2):
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], AIMessage)
assert len(msgs[0].tool_calls) == len(call)
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
# Build the messages the agent runtime would assemble for the next
# turn: prior AIMessage(tool_calls), its ToolMessage responses, ...
ai_msg = AIMessage(content="", tool_calls=call)
tool_msg = ToolMessage(content="ok", tool_call_id=call[0]["id"], name="bash")
request = _make_request([ai_msg, tool_msg], runtime)
def test_warn_only_injected_once(self):
"""Warning for the same hash should only be injected once per thread."""
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
sent = captured[0].messages
# AIMessage and ToolMessage stay in order, untouched.
assert sent[0] is ai_msg
assert sent[1] is tool_msg
# HumanMessage(warning) appears AFTER the ToolMessage — pairing intact.
assert isinstance(sent[2], HumanMessage)
assert sent[2].name == "loop_warning"
assert "LOOP DETECTED" in sent[2].content
def test_warn_queue_drained_after_injection(self):
"""A queued warning must be emitted exactly once per detection event."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
# First call: warning is appended.
mw.wrap_model_call(request, handler)
first = captured[0].messages
assert any(isinstance(m, HumanMessage) for m in first)
# Subsequent call without new detection: no warning re-emitted.
request2 = _make_request([AIMessage(content="hi")], runtime)
mw.wrap_model_call(request2, handler)
second = captured[1].messages
assert not any(isinstance(m, HumanMessage) for m in second)
def test_warn_queue_scoped_by_run_id(self):
"""A warning queued for one run must not be injected into another run."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime_a = _make_runtime(run_id="run-A")
runtime_b = _make_runtime(run_id="run-B")
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime_a)
request_b = _make_request([AIMessage(content="hi")], runtime_b)
captured, handler = _capture_handler()
mw.wrap_model_call(request_b, handler)
assert not any(isinstance(m, HumanMessage) for m in captured[0].messages)
assert mw._pending_warnings.get(_pending_key(run_id="run-A"))
request_a = _make_request([AIMessage(content="hi")], runtime_a)
mw.wrap_model_call(request_a, handler)
assert any(isinstance(message, HumanMessage) and message.name == "loop_warning" for message in captured[1].messages)
def test_missing_run_id_uses_default_pending_scope(self):
"""When runtime has no run_id, warning handling falls back to the default run scope."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
assert mw._pending_warnings.get(_pending_key(run_id="default"))
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert len(loop_warnings) == 1
assert "LOOP DETECTED" in loop_warnings[0].content
assert not mw._pending_warnings.get(_pending_key(run_id="default"))
def test_before_agent_clears_stale_pending_warnings_for_thread(self):
"""Starting a new run drops stale warnings from prior runs in the same thread."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime_a = _make_runtime(run_id="run-A")
runtime_b = _make_runtime(run_id="run-B")
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime_a)
assert mw._pending_warnings.get(_pending_key(run_id="run-A"))
mw.before_agent({"messages": []}, runtime_b)
assert not mw._pending_warnings.get(_pending_key(run_id="run-A"))
def test_after_agent_clears_current_run_pending_warnings(self):
"""Run cleanup should drop warnings that never reached wrap_model_call."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
assert mw._pending_warnings.get(_pending_key())
mw.after_agent({"messages": []}, runtime)
assert not mw._pending_warnings.get(_pending_key())
def test_multiple_pending_warnings_are_merged_into_one_message(self):
"""Edge-case drains should produce one loop_warning prompt message."""
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
mw._pending_warnings[_pending_key()] = ["first warning", "second warning", "first warning"]
request = _make_request([AIMessage(content="hi")], runtime)
captured, handler = _capture_handler()
mw.wrap_model_call(request, handler)
loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert len(loop_warnings) == 1
assert loop_warnings[0].content == "first warning\n\nsecond warning"
def test_warn_only_queued_once_per_hash(self):
"""Same hash repeated past the threshold should warn only once."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
@@ -192,14 +376,13 @@ class TestLoopDetection:
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third — warning injected
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Third — warning queued
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
# Fourth — warning already injected, should return None
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
# Fourth — already warned for this hash, no additional enqueue.
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
def test_hard_stop_at_limit(self):
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
@@ -257,6 +440,7 @@ class TestLoopDetection:
mw.reset()
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
assert not mw._pending_warnings.get(_pending_key())
def test_non_ai_message_ignored(self):
mw = LoopDetectionMiddleware()
@@ -283,15 +467,16 @@ class TestLoopDetection:
# One call on thread B
mw._apply(_make_state(tool_calls=call), runtime_b)
# Second call on thread A — triggers warning (2 >= warn_threshold)
result = mw._apply(_make_state(tool_calls=call), runtime_a)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Second call on thread A — queues warning under thread-A only.
mw._apply(_make_state(tool_calls=call), runtime_a)
assert mw._pending_warnings.get(_pending_key("thread-A"))
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
assert not mw._pending_warnings.get(_pending_key("thread-B"))
# Second call on thread B — also triggers (independent tracking)
result = mw._apply(_make_state(tool_calls=call), runtime_b)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Second call on thread B — independent queue.
mw._apply(_make_state(tool_calls=call), runtime_b)
assert mw._pending_warnings.get(_pending_key("thread-B"))
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
def test_lru_eviction(self):
"""Old threads should be evicted when max_tracked_threads is exceeded."""
@@ -313,6 +498,55 @@ class TestLoopDetection:
assert "thread-new" in mw._history
assert len(mw._history) == 3
def test_warned_hashes_are_pruned_to_sliding_window(self):
"""A long-lived thread should not keep every historical warned hash."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=100, window_size=4)
runtime = _make_runtime()
for i in range(12):
call = [_bash_call(f"cmd_{i}")]
mw._apply(_make_state(tool_calls=call), runtime)
mw._apply(_make_state(tool_calls=call), runtime)
assert len(mw._history["test-thread"]) <= 4
assert set(mw._warned["test-thread"]).issubset(set(mw._history["test-thread"]))
assert len(mw._warned["test-thread"]) <= 4
def test_pending_warning_keys_are_capped(self):
"""Abnormal same-thread runs cannot grow pending-warning keys forever."""
mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=2)
for i in range(10):
runtime = _make_runtime(thread_id="same-thread", run_id=f"run-{i}")
mw._queue_pending_warning(runtime, f"warning-{i}")
assert len(mw._pending_warnings) == mw._max_pending_warning_keys
assert len(mw._pending_warning_touch_order) == mw._max_pending_warning_keys
assert _pending_key("same-thread", "run-9") in mw._pending_warnings
def test_pending_warning_list_is_capped_and_deduped(self):
"""One run cannot accumulate an unbounded warning list."""
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
for i in range(_MAX_PENDING_WARNINGS_PER_RUN + 4):
mw._queue_pending_warning(runtime, f"warning-{i}")
mw._queue_pending_warning(runtime, f"warning-{_MAX_PENDING_WARNINGS_PER_RUN + 3}")
warnings = mw._pending_warnings[_pending_key()]
assert len(warnings) == _MAX_PENDING_WARNINGS_PER_RUN
assert warnings == [f"warning-{i}" for i in range(4, _MAX_PENDING_WARNINGS_PER_RUN + 4)]
def test_pending_warning_touch_order_cleared_with_pending_key(self):
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
mw._queue_pending_warning(runtime, "warning")
mw.after_agent({"messages": []}, runtime)
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
def test_thread_safe_mutations(self):
"""Verify lock is used for mutations (basic structural test)."""
mw = LoopDetectionMiddleware()
@@ -331,6 +565,99 @@ class TestLoopDetection:
assert "default" in mw._history
class TestLoopDetectionAgentGraphIntegration:
def test_loop_warning_is_transient_in_real_agent_graph(self):
"""after_model queues the warning; wrap_model_call injects it request-only."""
@as_tool
def bash(command: str) -> str:
"""Run a fake shell command."""
return f"ran: {command}"
repeated_calls = [[{"name": "bash", "id": f"call_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(content="", tool_calls=repeated_calls[0]),
AIMessage(content="", tool_calls=repeated_calls[1]),
AIMessage(content="", tool_calls=repeated_calls[2]),
AIMessage(content="final answer"),
],
)
graph = create_agent(model=model, tools=[bash], middleware=[mw])
result = graph.invoke(
{"messages": [("user", "inspect the directory")]},
context={"thread_id": "integration-thread", "run_id": "integration-run"},
config={"recursion_limit": 20},
)
assert len(model.seen_messages) == 4
loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages]
assert loop_warnings_by_call[0] == []
assert loop_warnings_by_call[1] == []
assert loop_warnings_by_call[2] == []
assert len(loop_warnings_by_call[3]) == 1
assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content
fourth_request = model.seen_messages[3]
assert isinstance(fourth_request[-2], ToolMessage)
assert fourth_request[-2].tool_call_id == "call_ls_2"
assert fourth_request[-1] is loop_warnings_by_call[3][0]
persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert persisted_loop_warnings == []
assert result["messages"][-1].content == "final answer"
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
@pytest.mark.asyncio
async def test_loop_warning_is_transient_in_async_agent_graph(self):
"""awrap_model_call injects loop_warning request-only in async graph runs."""
@as_tool
async def bash(command: str) -> str:
"""Run a fake shell command."""
return f"ran: {command}"
repeated_calls = [[{"name": "bash", "id": f"call_async_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(content="", tool_calls=repeated_calls[0]),
AIMessage(content="", tool_calls=repeated_calls[1]),
AIMessage(content="", tool_calls=repeated_calls[2]),
AIMessage(content="async final answer"),
],
)
graph = create_agent(model=model, tools=[bash], middleware=[mw])
result = await graph.ainvoke(
{"messages": [("user", "inspect the directory asynchronously")]},
context={"thread_id": "async-integration-thread", "run_id": "async-integration-run"},
config={"recursion_limit": 20},
)
assert len(model.seen_messages) == 4
loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages]
assert loop_warnings_by_call[0] == []
assert loop_warnings_by_call[1] == []
assert loop_warnings_by_call[2] == []
assert len(loop_warnings_by_call[3]) == 1
assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content
fourth_request = model.seen_messages[3]
assert isinstance(fourth_request[-2], ToolMessage)
assert fourth_request[-2].tool_call_id == "call_async_ls_2"
assert fourth_request[-1] is loop_warnings_by_call[3][0]
persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"]
assert persisted_loop_warnings == []
assert result["messages"][-1].content == "async final answer"
assert mw._pending_warnings == {}
assert mw._pending_warning_touch_order == OrderedDict()
class TestAppendText:
"""Unit tests for LoopDetectionMiddleware._append_text."""
@@ -507,33 +834,29 @@ class TestToolFrequencyDetection:
for i in range(4):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 5th call to read_file (different file each time) triggers freq warning
# 5th call queues a per-tool-type frequency warning; state untouched.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
assert result is not None
msg = result["messages"][0]
# Warning is appended to the AIMessage content; tool_calls preserved
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
# validation does not break.
assert isinstance(msg, AIMessage)
assert msg.tool_calls
assert "read_file" in msg.content
assert "LOOP DETECTED" in msg.content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "read_file" in queued[0]
assert "LOOP DETECTED" in queued[0]
def test_freq_warn_only_injected_once(self):
def test_freq_warn_only_queued_once(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
runtime = _make_runtime()
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd triggers warning
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# 3rd queues a frequency warning.
mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert len(mw._pending_warnings[_pending_key()]) == 1
# 4th should not re-warn (already warned for read_file)
# 4th: same tool name, no additional enqueue.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime)
assert result is None
assert len(mw._pending_warnings[_pending_key()]) == 1
def test_freq_hard_stop_at_limit(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6)
@@ -565,10 +888,10 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None
# 3rd read_file triggers (read_file count = 3)
# 3rd read_file triggers — warning is queued (state unchanged).
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
assert "read_file" in mw._pending_warnings[_pending_key()][0]
def test_freq_reset_clears_state(self):
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
@@ -600,10 +923,10 @@ class TestToolFrequencyDetection:
assert "thread-A" not in mw._tool_freq
assert "thread-A" not in mw._tool_freq_warned
# thread-B state should still be intact — 3rd call triggers warn
# thread-B state should still be intact — 3rd call queues a warn.
result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
# thread-A restarted from 0 — should not trigger
result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a)
@@ -623,10 +946,11 @@ class TestToolFrequencyDetection:
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b)
# 3rd call on thread A — triggers (count=3 for thread A only)
# 3rd call on thread A — queues a warning (count=3 for thread A only).
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
assert not mw._pending_warnings.get(_pending_key("thread-B"))
def test_multi_tool_single_response_counted(self):
"""When a single response has multiple tool calls, each is counted."""
@@ -643,10 +967,10 @@ class TestToolFrequencyDetection:
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
# Response 3: 1 more → count = 5 → triggers warn
# Response 3: 1 more → count = 5 → queues warn.
result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
assert "read_file" in mw._pending_warnings[_pending_key()][0]
def test_override_tool_uses_override_thresholds(self):
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
@@ -674,10 +998,14 @@ class TestToolFrequencyDetection:
for i in range(2):
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
# 3rd read_file call hits global warn=3 (read_file has no override)
# 3rd read_file call hits global warn=3 (read_file has no override).
# Warning delivery is deferred to wrap_model_call so the just-emitted
# AIMessage(tool_calls=...) is not mutated before ToolMessages exist.
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
assert result is not None
assert "read_file" in result["messages"][0].content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "read_file" in queued[0]
def test_hash_detection_takes_priority(self):
"""Hash-based hard stop fires before frequency check for identical calls."""
@@ -736,11 +1064,13 @@ class TestFromConfig:
mw = LoopDetectionMiddleware.from_config(self._config())
assert mw._tool_freq_overrides == {}
def test_constructed_middleware_detects_loops(self):
def test_constructed_middleware_queues_loop_warning(self):
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
assert result is None
queued = mw._pending_warnings.get(_pending_key(), [])
assert queued
assert "LOOP DETECTED" in queued[0]
@@ -50,6 +50,8 @@ Intercepts clarification tool calls and converts them into proper user-facing re
Detects when the agent is making the same tool call repeatedly without making progress. When a loop is detected, the middleware intervenes to break the cycle and prevents the agent from burning turns indefinitely.
Warning interventions are queued per thread and run, then drained on the next model call as a single hidden `HumanMessage(name="loop_warning")` appended after existing tool results. This keeps provider tool-call pairing valid. Run start/end hooks clear stale or undelivered warnings, and hard stops still strip tool calls before forcing a final text response.
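
The queue-then-drain flow described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the middleware's actual code: `PendingWarningQueue` and `with_loop_warning` are hypothetical names, and plain dicts stand in for the LangChain message classes.

```python
from typing import Dict, List, Optional, Tuple

PendingKey = Tuple[str, str]  # (thread_id, run_id)


class PendingWarningQueue:
    """Hypothetical stand-in for the per-thread, per-run warning queue."""

    def __init__(self) -> None:
        self._pending: Dict[PendingKey, List[str]] = {}

    def queue(self, thread_id: str, run_id: str, warning: str) -> None:
        # Detection step: record the warning only; agent state is not touched.
        self._pending.setdefault((thread_id, run_id), []).append(warning)

    def drain(self, thread_id: str, run_id: str) -> Optional[str]:
        # Delivery step: remove this run's warnings and merge them into one
        # text, dropping duplicates while keeping first-seen order.
        warnings = self._pending.pop((thread_id, run_id), [])
        if not warnings:
            return None
        return "\n\n".join(dict.fromkeys(warnings))


def with_loop_warning(messages: List[dict], warning: Optional[str]) -> List[dict]:
    """Return a request-only copy of messages with the warning appended last."""
    if warning is None:
        return list(messages)
    # Appending at the end keeps each assistant tool call directly followed
    # by its tool result, which strict providers validate.
    return [*messages, {"role": "user", "name": "loop_warning", "content": warning}]


queue = PendingWarningQueue()
queue.queue("thread-1", "run-1", "first warning")
queue.queue("thread-1", "run-1", "second warning")
queue.queue("thread-1", "run-1", "first warning")

history = [
    {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1", "name": "bash"}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "ok"},
]
outgoing = with_loop_warning(history, queue.drain("thread-1", "run-1"))
```

Because `with_loop_warning` builds a new list, the stored history is left unchanged and the warning reaches only the outgoing request. Draining removes the run's entry, so a second model call without a new detection carries no warning.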
**Configuration**: built-in, no user configuration.
---
@@ -50,6 +50,8 @@ import { Callout } from "nextra/components";
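Detects whether the Agent is repeatedly making the same tool call without making progress. When a loop is detected, the middleware intervenes to break it and keeps the Agent from consuming turns indefinitely.
Warning interventions are queued per thread and run, then merged on the next model call into a single hidden `HumanMessage(name="loop_warning")` appended after the existing tool results. This does not break provider validation of tool-call/tool-message pairing. Stale or undelivered warnings are cleared at run start and end; on a hard stop, tool calls are still stripped and a final text response is forced.
**Configuration**: built-in, no user configuration.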
---