mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
feat(loop-detection): defer warning injection (#2752)
* fix(loop-detection): defer warn injection to wrap_model_call The warn branch in LoopDetectionMiddleware injected a HumanMessage into state from after_model. The tools node had not yet produced ToolMessage responses to the previous AIMessage(tool_calls=...), so the new HumanMessage landed *between* the assistant's tool_calls and their responses. OpenAI/Moonshot reject the next request with "tool_call_ids did not have response messages" because their validators require tool_calls to be followed immediately by tool messages. Detection now runs in after_model as before, but only enqueues the warning into a per-thread list. Injection happens in wrap_model_call, where every prior ToolMessage is already present in request.messages. The warning is appended at the end as HumanMessage(name="loop_warning") — pairing intact, AIMessage semantics untouched, no SystemMessage issues for Anthropic. Closes #2029, addresses #2255 #2293 #2304 #2511. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix(channels): remove loop warning display filter * feat(loop-detection): scope pending warnings by run * docs(loop-detection): update docs * test(loop-detection): assert deferred warnings are queued * fix(loop-detection): cap transient warning state * docs: update docs * add async awrap_model_call test coverage * docs(loop-detection): document transient warnings --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -146,13 +146,6 @@ def _normalize_custom_agent_name(raw_value: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
def _strip_loop_warning_text(text: str) -> str:
|
||||
"""Remove middleware-authored loop warning lines from display text."""
|
||||
if "[LOOP DETECTED]" not in text:
|
||||
return text
|
||||
return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip()
|
||||
|
||||
|
||||
def _extract_response_text(result: dict | list) -> str:
|
||||
"""Extract the last AI message text from a LangGraph runs.wait result.
|
||||
|
||||
@@ -162,7 +155,6 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
Handles special cases:
|
||||
- Regular AI text responses
|
||||
- Clarification interrupts (``ask_clarification`` tool messages)
|
||||
- Strips loop-detection warnings attached to tool-call AI messages
|
||||
"""
|
||||
if isinstance(result, list):
|
||||
messages = result
|
||||
@@ -192,12 +184,7 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
# Regular AI message with text content
|
||||
if msg_type == "ai":
|
||||
content = msg.get("content", "")
|
||||
has_tool_calls = bool(msg.get("tool_calls"))
|
||||
if isinstance(content, str) and content:
|
||||
if has_tool_calls:
|
||||
content = _strip_loop_warning_text(content)
|
||||
if not content:
|
||||
continue
|
||||
return content
|
||||
# content can be a list of content blocks
|
||||
if isinstance(content, list):
|
||||
@@ -208,8 +195,6 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
text = "".join(parts)
|
||||
if has_tool_calls:
|
||||
text = _strip_loop_warning_text(text)
|
||||
if text:
|
||||
return text
|
||||
return ""
|
||||
|
||||
@@ -4,22 +4,22 @@
|
||||
|
||||
`create_deerflow_agent` 通过 `RuntimeFeatures` 组装的完整 middleware 链(默认全开时):
|
||||
|
||||
| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_tool_call` | 主 Agent | Subagent | 来源 |
|
||||
|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------|
|
||||
| 0 | ThreadDataMiddleware | ✓ | | | | | ✓ | ✓ | `sandbox` |
|
||||
| 1 | UploadsMiddleware | ✓ | | | | | ✓ | ✗ | `sandbox` |
|
||||
| 2 | SandboxMiddleware | ✓ | | | ✓ | | ✓ | ✓ | `sandbox` |
|
||||
| 3 | DanglingToolCallMiddleware | | | ✓ | | | ✓ | ✗ | 始终开启 |
|
||||
| 4 | GuardrailMiddleware | | | | | ✓ | ✓ | ✓ | *Phase 2 纳入* |
|
||||
| 5 | ToolErrorHandlingMiddleware | | | | | ✓ | ✓ | ✓ | 始终开启 |
|
||||
| 6 | SummarizationMiddleware | | | ✓ | | | ✓ | ✗ | `summarization` |
|
||||
| 7 | TodoMiddleware | | | ✓ | | | ✓ | ✗ | `plan_mode` 参数 |
|
||||
| 8 | TitleMiddleware | | | ✓ | | | ✓ | ✗ | `auto_title` |
|
||||
| 9 | MemoryMiddleware | | | | ✓ | | ✓ | ✗ | `memory` |
|
||||
| 10 | ViewImageMiddleware | | ✓ | | | | ✓ | ✗ | `vision` |
|
||||
| 11 | SubagentLimitMiddleware | | | ✓ | | | ✓ | ✗ | `subagent` |
|
||||
| 12 | LoopDetectionMiddleware | | | ✓ | | | ✓ | ✗ | 始终开启 |
|
||||
| 13 | ClarificationMiddleware | | | ✓ | | | ✓ | ✗ | 始终最后 |
|
||||
| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_model_call` | `wrap_tool_call` | 主 Agent | Subagent | 来源 |
|
||||
|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------|
|
||||
| 0 | ThreadDataMiddleware | ✓ | | | | | | ✓ | ✓ | `sandbox` |
|
||||
| 1 | UploadsMiddleware | ✓ | | | | | | ✓ | ✗ | `sandbox` |
|
||||
| 2 | SandboxMiddleware | ✓ | | | ✓ | | | ✓ | ✓ | `sandbox` |
|
||||
| 3 | DanglingToolCallMiddleware | | | | | ✓ | | ✓ | ✗ | 始终开启 |
|
||||
| 4 | GuardrailMiddleware | | | | | | ✓ | ✓ | ✓ | *Phase 2 纳入* |
|
||||
| 5 | ToolErrorHandlingMiddleware | | | | | | ✓ | ✓ | ✓ | 始终开启 |
|
||||
| 6 | SummarizationMiddleware | | ✓ | | | | | ✓ | ✗ | `summarization` |
|
||||
| 7 | TodoMiddleware | | ✓ | ✓ | | ✓ | | ✓ | ✗ | `plan_mode` 参数 |
|
||||
| 8 | TitleMiddleware | | | ✓ | | | | ✓ | ✗ | `auto_title` |
|
||||
| 9 | MemoryMiddleware | | | | ✓ | | | ✓ | ✗ | `memory` |
|
||||
| 10 | ViewImageMiddleware | | ✓ | | | | | ✓ | ✗ | `vision` |
|
||||
| 11 | SubagentLimitMiddleware | | | ✓ | | | | ✓ | ✗ | `subagent` |
|
||||
| 12 | LoopDetectionMiddleware | ✓ | | ✓ | ✓ | ✓ | | ✓ | ✗ | 始终开启 |
|
||||
| 13 | ClarificationMiddleware | | | | | | ✓ | ✓ | ✗ | 始终最后 |
|
||||
|
||||
主 agent **14 个** middleware(`make_lead_agent`),subagent **4 个**(ThreadData、Sandbox、Guardrail、ToolErrorHandling)。`create_deerflow_agent` Phase 1 实现 **13 个**(Guardrail 仅支持自定义实例,无内置默认)。
|
||||
|
||||
@@ -35,7 +35,7 @@ graph TB
|
||||
|
||||
subgraph BA ["<b>before_agent</b> 正序 0→N"]
|
||||
direction TB
|
||||
TD["[0] ThreadData<br/>创建线程目录"] --> UL["[1] Uploads<br/>扫描上传文件"] --> SB["[2] Sandbox<br/>获取沙箱"]
|
||||
TD["[0] ThreadData<br/>创建线程目录"] --> UL["[1] Uploads<br/>扫描上传文件"] --> SB["[2] Sandbox<br/>获取沙箱"] --> LD_BA["[12] LoopDetection<br/>清理 stale warning"]
|
||||
end
|
||||
|
||||
subgraph BM ["<b>before_model</b> 正序 0→N"]
|
||||
@@ -43,34 +43,42 @@ graph TB
|
||||
VI["[10] ViewImage<br/>注入图片 base64"]
|
||||
end
|
||||
|
||||
SB --> VI
|
||||
VI --> M["<b>MODEL</b>"]
|
||||
subgraph WM ["<b>wrap_model_call</b>"]
|
||||
direction TB
|
||||
DTC_WM["[3] DanglingToolCall<br/>补悬空 ToolMessage"] --> LD_WM["[12] LoopDetection<br/>注入当前 run warning"]
|
||||
end
|
||||
|
||||
LD_BA --> VI
|
||||
VI --> DTC_WM
|
||||
LD_WM --> M["<b>MODEL</b>"]
|
||||
|
||||
subgraph AM ["<b>after_model</b> 反序 N→0"]
|
||||
direction TB
|
||||
CL["[13] Clarification<br/>拦截 ask_clarification"] --> LD["[12] LoopDetection<br/>检测循环"] --> SL["[11] SubagentLimit<br/>截断多余 task"] --> TI["[8] Title<br/>生成标题"] --> SM["[6] Summarization<br/>上下文压缩"] --> DTC["[3] DanglingToolCall<br/>补缺失 ToolMessage"]
|
||||
LD["[12] LoopDetection<br/>检测循环/排队 warning"] --> SL["[11] SubagentLimit<br/>截断多余 task"] --> TI["[8] Title<br/>生成标题"]
|
||||
end
|
||||
|
||||
M --> CL
|
||||
M --> LD
|
||||
|
||||
subgraph AA ["<b>after_agent</b> 反序 N→0"]
|
||||
direction TB
|
||||
SBR["[2] Sandbox<br/>释放沙箱"] --> MEM["[9] Memory<br/>入队记忆"]
|
||||
LD_CLEAN["[12] LoopDetection<br/>清理 pending warning"] --> MEM["[9] Memory<br/>入队记忆"] --> SBR["[2] Sandbox<br/>释放沙箱"]
|
||||
end
|
||||
|
||||
DTC --> SBR
|
||||
MEM --> END(["response"])
|
||||
TI --> LD_CLEAN
|
||||
SBR --> END(["response"])
|
||||
|
||||
classDef beforeNode fill:#a0a8b5,stroke:#636b7a,color:#2d3239
|
||||
classDef modelNode fill:#b5a8a0,stroke:#7a6b63,color:#2d3239
|
||||
classDef wrapModelNode fill:#a8a0b5,stroke:#6b637a,color:#2d3239
|
||||
classDef afterModelNode fill:#b5a0a8,stroke:#7a636b,color:#2d3239
|
||||
classDef afterAgentNode fill:#a0b5a8,stroke:#637a6b,color:#2d3239
|
||||
classDef terminalNode fill:#a8b5a0,stroke:#6b7a63,color:#2d3239
|
||||
|
||||
class TD,UL,SB,VI beforeNode
|
||||
class TD,UL,SB,LD_BA,VI beforeNode
|
||||
class DTC_WM,LD_WM wrapModelNode
|
||||
class M modelNode
|
||||
class CL,LD,SL,TI,SM,DTC afterModelNode
|
||||
class SBR,MEM afterAgentNode
|
||||
class LD,SL,TI afterModelNode
|
||||
class LD_CLEAN,SBR,MEM afterAgentNode
|
||||
class START,END terminalNode
|
||||
```
|
||||
|
||||
@@ -82,13 +90,12 @@ sequenceDiagram
|
||||
participant TD as ThreadDataMiddleware
|
||||
participant UL as UploadsMiddleware
|
||||
participant SB as SandboxMiddleware
|
||||
participant LD as LoopDetectionMiddleware
|
||||
participant VI as ViewImageMiddleware
|
||||
participant DTC as DanglingToolCallMiddleware
|
||||
participant M as MODEL
|
||||
participant CL as ClarificationMiddleware
|
||||
participant SL as SubagentLimitMiddleware
|
||||
participant TI as TitleMiddleware
|
||||
participant SM as SummarizationMiddleware
|
||||
participant DTC as DanglingToolCallMiddleware
|
||||
participant MEM as MemoryMiddleware
|
||||
|
||||
U ->> TD: invoke
|
||||
@@ -103,19 +110,26 @@ sequenceDiagram
|
||||
activate SB
|
||||
Note right of SB: before_agent 获取沙箱
|
||||
|
||||
SB ->> VI: before_model
|
||||
SB ->> LD: before_agent
|
||||
activate LD
|
||||
Note right of LD: before_agent 清理同 thread 旧 run 的 pending warning
|
||||
LD ->> VI: before_model
|
||||
activate VI
|
||||
Note right of VI: before_model 注入图片 base64
|
||||
|
||||
VI ->> M: messages + tools
|
||||
VI ->> DTC: wrap_model_call
|
||||
activate DTC
|
||||
Note right of DTC: wrap_model_call 补悬空 ToolMessage
|
||||
DTC ->> LD: wrap_model_call
|
||||
Note right of LD: wrap_model_call drain 当前 run warning 并追加到末尾
|
||||
LD ->> M: messages + tools
|
||||
activate M
|
||||
M -->> CL: AI response
|
||||
M -->> LD: AI response
|
||||
deactivate M
|
||||
|
||||
activate CL
|
||||
Note right of CL: after_model 拦截 ask_clarification
|
||||
CL -->> SL: after_model
|
||||
deactivate CL
|
||||
Note right of LD: after_model 检测循环;warning 入队,hard-stop 清 tool_calls
|
||||
LD -->> SL: after_model
|
||||
deactivate LD
|
||||
|
||||
activate SL
|
||||
Note right of SL: after_model 截断多余 task
|
||||
@@ -124,22 +138,18 @@ sequenceDiagram
|
||||
|
||||
activate TI
|
||||
Note right of TI: after_model 生成标题
|
||||
TI -->> SM: after_model
|
||||
TI -->> DTC: done
|
||||
deactivate TI
|
||||
|
||||
activate SM
|
||||
Note right of SM: after_model 上下文压缩
|
||||
SM -->> DTC: after_model
|
||||
deactivate SM
|
||||
|
||||
activate DTC
|
||||
Note right of DTC: after_model 补缺失 ToolMessage
|
||||
DTC -->> VI: done
|
||||
deactivate DTC
|
||||
|
||||
VI -->> SB: done
|
||||
deactivate VI
|
||||
|
||||
Note right of LD: after_agent 清理当前 run 未消费 warning
|
||||
|
||||
Note right of MEM: after_agent 入队记忆
|
||||
|
||||
Note right of SB: after_agent 释放沙箱
|
||||
SB -->> UL: done
|
||||
deactivate SB
|
||||
@@ -147,8 +157,6 @@ sequenceDiagram
|
||||
UL -->> TD: done
|
||||
deactivate UL
|
||||
|
||||
Note right of MEM: after_agent 入队记忆
|
||||
|
||||
TD -->> U: response
|
||||
deactivate TD
|
||||
```
|
||||
@@ -224,12 +232,12 @@ sequenceDiagram
|
||||
participant TD as ThreadData
|
||||
participant UL as Uploads
|
||||
participant SB as Sandbox
|
||||
participant LD as LoopDetection
|
||||
participant VI as ViewImage
|
||||
participant DTC as DanglingToolCall
|
||||
participant M as MODEL
|
||||
participant CL as Clarification
|
||||
participant SL as SubagentLimit
|
||||
participant TI as Title
|
||||
participant SM as Summarization
|
||||
participant MEM as Memory
|
||||
|
||||
U ->> TD: invoke
|
||||
@@ -238,34 +246,40 @@ sequenceDiagram
|
||||
Note right of UL: before_agent 扫描文件
|
||||
UL ->> SB: .
|
||||
Note right of SB: before_agent 获取沙箱
|
||||
SB ->> LD: .
|
||||
Note right of LD: before_agent 清理 stale pending warning
|
||||
|
||||
loop 每轮对话(tool call 循环)
|
||||
SB ->> VI: .
|
||||
Note right of VI: before_model 注入图片
|
||||
VI ->> M: messages + tools
|
||||
M -->> CL: AI response
|
||||
Note right of CL: after_model 拦截 ask_clarification
|
||||
CL -->> SL: .
|
||||
VI ->> DTC: .
|
||||
Note right of DTC: wrap_model_call 补悬空工具结果
|
||||
DTC ->> LD: .
|
||||
Note right of LD: wrap_model_call 注入当前 run warning
|
||||
LD ->> M: messages + tools
|
||||
M -->> LD: AI response
|
||||
Note right of LD: after_model 检测循环/排队 warning
|
||||
LD -->> SL: .
|
||||
Note right of SL: after_model 截断多余 task
|
||||
SL -->> TI: .
|
||||
Note right of TI: after_model 生成标题
|
||||
TI -->> SM: .
|
||||
Note right of SM: after_model 上下文压缩
|
||||
end
|
||||
|
||||
Note right of SB: after_agent 释放沙箱
|
||||
SB -->> MEM: .
|
||||
Note right of LD: after_agent 清理当前 run pending warning
|
||||
LD -->> MEM: .
|
||||
Note right of MEM: after_agent 入队记忆
|
||||
MEM -->> U: response
|
||||
MEM -->> SB: .
|
||||
Note right of SB: after_agent 释放沙箱
|
||||
SB -->> U: response
|
||||
```
|
||||
|
||||
> [!warning] 不是洋葱
|
||||
> 14 个 middleware 中只有 SandboxMiddleware 有 before/after 对称(获取/释放)。其余都是单向的:要么只在 `before_*` 做事,要么只在 `after_*` 做事。`before_agent` / `after_agent` 只跑一次,`before_model` / `after_model` 每轮循环都跑。
|
||||
> 大部分 middleware 只用一个阶段。SandboxMiddleware 使用 `before_agent`/`after_agent` 做资源获取/释放;LoopDetectionMiddleware 也使用这两个钩子,但用途是清理 run-scoped pending warnings,不是资源生命周期对称。`before_agent` / `after_agent` 只跑一次,`before_model` / `after_model` / `wrap_model_call` 每轮循环都跑。
|
||||
|
||||
硬依赖只有 2 处:
|
||||
|
||||
1. **ThreadData 在 Sandbox 之前** — sandbox 需要线程目录
|
||||
2. **Clarification 在列表最后** — `after_model` 反序时最先执行,第一个拦截 `ask_clarification`
|
||||
2. **Clarification 在列表最后** — `wrap_tool_call` 处理 `ask_clarification` 时优先拦截,并通过 `Command(goto=END)` 中断执行
|
||||
|
||||
### 结论
|
||||
|
||||
@@ -273,19 +287,19 @@ sequenceDiagram
|
||||
|---|---|---|
|
||||
| 每个 middleware | before + after 对称 | 大多只用一个钩子 |
|
||||
| 激活条 | 嵌套(外长内短) | 不嵌套(串行) |
|
||||
| 反序的意义 | 清理与初始化配对 | 仅影响 after_model 的执行优先级 |
|
||||
| 反序的意义 | 清理与初始化配对 | 影响 `after_model` / `after_agent` 的执行优先级 |
|
||||
| 典型例子 | Auth: 校验 token / 清理上下文 | ThreadData: 只创建目录,没有清理 |
|
||||
|
||||
## 关键设计点
|
||||
|
||||
### ClarificationMiddleware 为什么在列表最后?
|
||||
|
||||
位置最后 = `after_model` 最先执行。它需要**第一个**看到 model 输出,检查是否有 `ask_clarification` tool call。如果有,立即中断(`Command(goto=END)`),后续 middleware 的 `after_model` 不再执行。
|
||||
位置最后使它在工具调用包装链中优先拦截 `ask_clarification`。如果命中,它返回 `Command(goto=END)`,把格式化后的澄清问题写成 `ToolMessage` 并中断执行。
|
||||
|
||||
### SandboxMiddleware 的对称性
|
||||
|
||||
`before_agent`(正序第 3 个)获取沙箱,`after_agent`(反序第 1 个)释放沙箱。外层进入 → 外层退出,天然的洋葱对称。
|
||||
|
||||
### 大部分 middleware 只用一个钩子
|
||||
### LoopDetectionMiddleware 为什么同时用多个钩子?
|
||||
|
||||
14 个 middleware 中,只有 SandboxMiddleware 同时用了 `before_agent` + `after_agent`(获取/释放)。其余都只在一个阶段执行。洋葱模型的反序特性主要影响 `after_model` 阶段的执行顺序。
|
||||
`after_model` 只做检测:重复工具调用达到 warning 阈值时,把 warning 放入 `(thread_id, run_id)` 作用域的 pending 队列。真正注入发生在下一次 `wrap_model_call`:此时上一轮 `AIMessage(tool_calls)` 对应的 `ToolMessage` 已经在请求里,warning 追加在末尾,不会破坏 OpenAI/Moonshot 的 tool-call pairing。`before_agent` 清理同一 thread 下旧 run 的残留 warning,`after_agent` 清理当前 run 没被消费的 warning。
|
||||
|
||||
+201
-28
@@ -6,10 +6,36 @@ arguments indefinitely until the recursion limit kills the run.
|
||||
Detection strategy:
|
||||
1. After each model response, hash the tool calls (name + args).
|
||||
2. Track recent hashes in a sliding window.
|
||||
3. If the same hash appears >= warn_threshold times, inject a
|
||||
"you are repeating yourself — wrap up" system message (once per hash).
|
||||
3. If the same hash appears >= warn_threshold times, queue a
|
||||
"you are repeating yourself — wrap up" warning for the current
|
||||
thread/run. The warning is **injected at the next model call** (in
|
||||
``wrap_model_call``) as a ``HumanMessage`` appended to the message
|
||||
list, *after* all ToolMessage responses to the previous
|
||||
AIMessage(tool_calls).
|
||||
4. If it appears >= hard_limit times, strip all tool_calls from the
|
||||
response so the agent is forced to produce a final text answer.
|
||||
|
||||
Why the warning is injected at ``wrap_model_call`` instead of
|
||||
``after_model``:
|
||||
|
||||
``after_model`` fires immediately after the model emits an
|
||||
``AIMessage`` that may carry ``tool_calls``. The tools node has not
|
||||
run yet, so no matching ``ToolMessage`` exists in the history. Any
|
||||
message we add here lands *between* the assistant's tool_calls and
|
||||
their responses. OpenAI/Moonshot reject the next request with
|
||||
``"tool_call_ids did not have response messages"`` because their
|
||||
validators require the assistant's tool_calls to be followed
|
||||
immediately by tool messages. Anthropic also disallows mid-stream
|
||||
``SystemMessage``. By deferring the warning to ``wrap_model_call``,
|
||||
every prior ToolMessage is already present in the request's message
|
||||
list and the warning is appended at the end — pairing intact, no
|
||||
``AIMessage`` semantics are mutated.
|
||||
|
||||
Queued warnings are intentionally transient. If a run ends before the
|
||||
next model request drains a queued warning, ``after_agent`` drops it
|
||||
instead of carrying it into a later invocation for the same thread. The
|
||||
hard-stop path still forces termination when the configured safety limit
|
||||
is reached.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -19,11 +45,14 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -38,6 +67,7 @@ _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
|
||||
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
|
||||
_MAX_PENDING_WARNINGS_PER_RUN = 4
|
||||
|
||||
|
||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||
@@ -195,6 +225,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread/run queue of warnings to inject at the next model call.
|
||||
# Populated by ``after_model`` (detection) and drained by
|
||||
# ``wrap_model_call`` (injection); see module docstring.
|
||||
self._pending_warnings: dict[tuple[str, str], list[str]] = defaultdict(list)
|
||||
self._pending_warning_touch_order: OrderedDict[tuple[str, str], None] = OrderedDict()
|
||||
self._max_pending_warning_keys = max(1, self.max_tracked_threads * 2)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware:
|
||||
@@ -213,9 +249,20 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
return thread_id
|
||||
return str(thread_id)
|
||||
return "default"
|
||||
|
||||
def _get_run_id(self, runtime: Runtime) -> str:
|
||||
"""Extract run_id from runtime context for per-run warning scoping."""
|
||||
run_id = runtime.context.get("run_id") if runtime.context else None
|
||||
if run_id:
|
||||
return str(run_id)
|
||||
return "default"
|
||||
|
||||
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
||||
"""Return the pending-warning key for the current thread/run."""
|
||||
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict least recently used threads if over the limit.
|
||||
|
||||
@@ -226,8 +273,52 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self._warned.pop(evicted_id, None)
|
||||
self._tool_freq.pop(evicted_id, None)
|
||||
self._tool_freq_warned.pop(evicted_id, None)
|
||||
for key in list(self._pending_warnings):
|
||||
if key[0] == evicted_id:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
||||
|
||||
def _drop_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
|
||||
"""Drop all pending-warning bookkeeping for one thread/run key.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
self._pending_warnings.pop(key, None)
|
||||
self._pending_warning_touch_order.pop(key, None)
|
||||
|
||||
def _touch_pending_warning_key_locked(self, key: tuple[str, str]) -> None:
|
||||
"""Mark a pending-warning key as recently used.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
self._pending_warning_touch_order[key] = None
|
||||
self._pending_warning_touch_order.move_to_end(key)
|
||||
|
||||
def _prune_pending_warning_state_locked(self, protected_key: tuple[str, str]) -> None:
|
||||
"""Cap pending-warning state across abnormal or concurrent runs.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
overflow = len(self._pending_warning_touch_order) - self._max_pending_warning_keys
|
||||
if overflow <= 0:
|
||||
return
|
||||
|
||||
candidates = [key for key in self._pending_warning_touch_order if key != protected_key]
|
||||
for key in candidates[:overflow]:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
|
||||
def _queue_pending_warning(self, runtime: Runtime, warning: str) -> None:
    """Queue one transient warning for the current thread/run with caps.

    The warning is appended to the queue for this runtime's pending key
    unless an identical warning is already queued. The queue is trimmed
    from the front to at most ``_MAX_PENDING_WARNINGS_PER_RUN`` entries.
    The key is then marked as recently used and global pending state is
    pruned, with this key protected from eviction.
    """
    pending_key = self._pending_key(runtime)
    with self._lock:
        queue = self._pending_warnings[pending_key]
        if warning not in queue:
            queue.append(warning)
        excess = len(queue) - _MAX_PENDING_WARNINGS_PER_RUN
        if excess > 0:
            # Keep the newest warnings; discard the oldest ones.
            del queue[:excess]
        # Both helpers require the lock, so they stay inside the with-block.
        self._touch_pending_warning_key_locked(pending_key)
        self._prune_pending_warning_state_locked(protected_key=pending_key)
|
||||
|
||||
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
||||
"""Track tool calls and check for loops.
|
||||
|
||||
@@ -268,6 +359,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
if len(history) > self.window_size:
|
||||
history[:] = history[-self.window_size :]
|
||||
|
||||
warned_hashes = self._warned.get(thread_id)
|
||||
if warned_hashes is not None:
|
||||
warned_hashes.intersection_update(history)
|
||||
if not warned_hashes:
|
||||
self._warned.pop(thread_id, None)
|
||||
|
||||
count = history.count(call_hash)
|
||||
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
||||
|
||||
@@ -381,7 +478,10 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
if hard_stop:
|
||||
# Strip tool_calls from the last AIMessage to force text output
|
||||
# Strip tool_calls from the last AIMessage to force text output.
|
||||
# Once tool_calls are stripped, the AIMessage no longer requires
|
||||
# matching ToolMessage responses, so mutating it in place here
|
||||
# is safe for OpenAI/Moonshot pairing validators.
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
|
||||
@@ -389,33 +489,48 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
# WORKAROUND for v2.0-m1 — see #2724.
|
||||
#
|
||||
# Append the warning to the AIMessage content instead of
|
||||
# injecting a separate HumanMessage. Inserting any non-tool
|
||||
# message between an AIMessage(tool_calls=...) and its
|
||||
# ToolMessage responses breaks OpenAI/Moonshot strict pairing
|
||||
# validation ("tool_call_ids did not have response messages")
|
||||
# because the tools node has not run yet at after_model time.
|
||||
# tool_calls are preserved so the tools node still executes.
|
||||
#
|
||||
# This is a temporary mitigation: mutating an existing
|
||||
# AIMessage to carry framework-authored text leaks loop-warning
|
||||
# text into downstream consumers (MemoryMiddleware fact
|
||||
# extraction, TitleMiddleware, telemetry, model replay) as if
|
||||
# the model said it. The proper fix is to defer warning
|
||||
# injection from after_model to wrap_model_call so every prior
|
||||
# ToolMessage is already in the request — see RFC #2517 (which
|
||||
# lists "loop intervention does not leave invalid
|
||||
# tool-call/tool-message state" as acceptance criteria) and
|
||||
# the prototype on `fix/loop-detection-tool-call-pairing`.
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)})
|
||||
return {"messages": [patched_msg]}
|
||||
# Defer injection to the next model call. We must NOT alter the
|
||||
# AIMessage(tool_calls=...) here (would put framework words in
|
||||
# the model's mouth, polluting downstream consumers like
|
||||
# MemoryMiddleware), nor insert a separate non-tool message
|
||||
# (would break OpenAI/Moonshot tool-call pairing because the
|
||||
# tools node has not produced ToolMessage responses yet). The
|
||||
# warning is delivered via ``wrap_model_call`` below.
|
||||
self._queue_pending_warning(runtime, warning)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _clear_other_run_pending_warnings(self, runtime: Runtime) -> None:
|
||||
"""Drop stale pending warnings for previous runs in this thread."""
|
||||
thread_id, current_run_id = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
for key in list(self._pending_warnings):
|
||||
if key[0] == thread_id and key[1] != current_run_id:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
|
||||
def _clear_current_run_pending_warnings(self, runtime: Runtime) -> None:
|
||||
"""Drop pending warnings owned by the current thread/run."""
|
||||
pending_key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._drop_pending_warning_key_locked(pending_key)
|
||||
|
||||
@staticmethod
|
||||
def _format_warning_message(warnings: list[str]) -> str:
|
||||
"""Merge pending warnings into one prompt message."""
|
||||
deduped = list(dict.fromkeys(warnings))
|
||||
return "\n\n".join(deduped)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Clear pending warnings left by other runs of this thread.

    Produces no state update; always returns ``None``.
    """
    self._clear_other_run_pending_warnings(runtime)
    return None
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Async counterpart of ``before_agent``.

    Clears pending warnings left by other runs of this thread and returns
    ``None`` (no state update).
    """
    self._clear_other_run_pending_warnings(runtime)
    return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Run loop detection on the latest model output.

    Delegates to ``self._apply`` and returns its result unchanged.
    """
    return self._apply(state, runtime)
|
||||
@@ -424,6 +539,59 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Async counterpart of ``after_model``.

    Delegates to the synchronous ``self._apply`` and returns its result
    unchanged.
    """
    return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
def after_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Discard warnings queued by this run that were never injected.

    Produces no state update; always returns ``None``.
    """
    self._clear_current_run_pending_warnings(runtime)
    return None
|
||||
|
||||
@override
|
||||
async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict | None:
    """Async counterpart of ``after_agent``.

    Discards warnings queued by this run that were never injected and
    returns ``None`` (no state update).
    """
    self._clear_current_run_pending_warnings(runtime)
    return None
|
||||
|
||||
def _drain_pending_warnings(self, runtime: Runtime) -> list[str]:
|
||||
"""Pop and return all queued warnings for *runtime*'s thread/run."""
|
||||
pending_key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
warnings = self._pending_warnings.pop(pending_key, [])
|
||||
self._pending_warning_touch_order.pop(pending_key, None)
|
||||
return warnings
|
||||
|
||||
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
||||
"""Append queued loop warnings (if any) to the outgoing message list.
|
||||
|
||||
The warning is placed *after* every existing message, including the
|
||||
ToolMessage responses to the previous AIMessage(tool_calls). This
|
||||
keeps ``assistant tool_calls -> tool_messages`` pairing intact for
|
||||
OpenAI/Moonshot, avoids the Anthropic mid-stream SystemMessage
|
||||
restriction (we use HumanMessage), and never mutates an existing
|
||||
AIMessage.
|
||||
"""
|
||||
warnings = self._drain_pending_warnings(request.runtime)
|
||||
if not warnings:
|
||||
return request
|
||||
new_messages = [
|
||||
*request.messages,
|
||||
HumanMessage(content=self._format_warning_message(warnings), name="loop_warning"),
|
||||
]
|
||||
return request.override(messages=new_messages)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
    self,
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
    """Inject queued loop warnings, then invoke the model handler.

    The request is passed through ``self._augment_request`` and the
    handler's result is returned unchanged.
    """
    augmented = self._augment_request(request)
    return handler(augmented)
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
    self,
    request: ModelRequest,
    handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
    """Async counterpart of ``wrap_model_call``.

    The request is passed through the synchronous
    ``self._augment_request``; the awaited handler result is returned
    unchanged.
    """
    augmented = self._augment_request(request)
    return await handler(augmented)
|
||||
|
||||
def reset(self, thread_id: str | None = None) -> None:
|
||||
"""Clear tracking state. If thread_id given, clear only that thread."""
|
||||
with self._lock:
|
||||
@@ -432,8 +600,13 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self._warned.pop(thread_id, None)
|
||||
self._tool_freq.pop(thread_id, None)
|
||||
self._tool_freq_warned.pop(thread_id, None)
|
||||
for key in list(self._pending_warnings):
|
||||
if key[0] == thread_id:
|
||||
self._drop_pending_warning_key_locked(key)
|
||||
else:
|
||||
self._history.clear()
|
||||
self._warned.clear()
|
||||
self._tool_freq.clear()
|
||||
self._tool_freq_warned.clear()
|
||||
self._pending_warnings.clear()
|
||||
self._pending_warning_touch_order.clear()
|
||||
|
||||
@@ -372,37 +372,6 @@ class TestExtractResponseText:
|
||||
# Should return "" (no text in current turn), NOT "Hi there!" from previous turn
|
||||
assert _extract_response_text(result) == ""
|
||||
|
||||
def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self):
    """Loop-detection warning text on a tool-calling AI message is middleware-authored."""
    from app.channels.manager import _extract_response_text

    # The only AI message carries tool_calls and nothing but warning text,
    # so no user-visible text is expected.
    result = {
        "messages": [
            {"type": "human", "content": "search the repo"},
            {
                "type": "ai",
                "content": "[LOOP DETECTED] You are repeating the same tool calls.",
                "tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}],
            },
        ]
    }
    assert _extract_response_text(result) == ""
|
||||
|
||||
def test_preserves_visible_text_when_stripping_loop_warning(self):
    """Model-authored text survives when the appended loop warning is stripped."""
    from app.channels.manager import _extract_response_text

    # The AI message mixes real text with a trailing warning line; only the
    # real text is expected back.
    result = {
        "messages": [
            {"type": "human", "content": "prepare the report"},
            {
                "type": "ai",
                "content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.",
                "tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}],
            },
        ]
    }
    assert _extract_response_text(result) == "Here is the report."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager tests
|
||||
|
||||
@@ -1,24 +1,94 @@
|
||||
"""Tests for LoopDetectionMiddleware."""
|
||||
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import tool as as_tool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import (
|
||||
_HARD_STOP_MSG,
|
||||
_MAX_PENDING_WARNINGS_PER_RUN,
|
||||
LoopDetectionMiddleware,
|
||||
_hash_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime(thread_id="test-thread"):
|
||||
def _make_runtime(thread_id="test-thread", run_id="test-run"):
|
||||
"""Build a minimal Runtime mock with context."""
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
runtime.context = {"thread_id": thread_id, "run_id": run_id}
|
||||
return runtime
|
||||
|
||||
|
||||
def _pending_key(thread_id="test-thread", run_id="test-run"):
|
||||
return (thread_id, run_id)
|
||||
|
||||
|
||||
def _make_request(messages, runtime):
|
||||
"""Build a minimal ModelRequest stand-in for wrap_model_call tests."""
|
||||
request = MagicMock()
|
||||
request.messages = list(messages)
|
||||
request.runtime = runtime
|
||||
request.override = lambda **updates: _override_request(request, updates)
|
||||
return request
|
||||
|
||||
|
||||
def _override_request(request, updates):
|
||||
"""Mimic ModelRequest.override(): return a copy with fields replaced."""
|
||||
new = MagicMock()
|
||||
new.messages = updates.get("messages", request.messages)
|
||||
new.runtime = updates.get("runtime", request.runtime)
|
||||
new.override = lambda **u: _override_request(new, u)
|
||||
return new
|
||||
|
||||
|
||||
def _capture_handler():
|
||||
"""Build a sync handler that records the request it was called with."""
|
||||
captured: list = []
|
||||
|
||||
def handler(req):
|
||||
captured.append(req)
|
||||
return MagicMock()
|
||||
|
||||
return captured, handler
|
||||
|
||||
|
||||
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
    """Fake chat model that records each model request's messages."""

    # One entry per `_generate` call, in call order; each entry is a copy of
    # the message list that call received.
    _seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)

    @property
    def seen_messages(self) -> list[list[Any]]:
        """Message lists received so far, one per ``_generate`` call."""
        return self._seen_messages

    def bind_tools(
        self,
        tools: Any,
        *,
        tool_choice: Any = None,
        **kwargs: Any,
    ) -> Runnable:
        """Return this model unchanged; ``tools`` and all options are ignored."""
        return self

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        """Record a copy of ``messages``, then delegate to the parent implementation."""
        # Copy so later changes to the caller's list cannot alter the recorded request.
        self._seen_messages.append(list(messages))
        return super()._generate(
            messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
|
||||
|
||||
|
||||
def _make_state(tool_calls=None, content=""):
|
||||
"""Build a minimal AgentState dict with an AIMessage.
|
||||
|
||||
@@ -138,7 +208,15 @@ class TestLoopDetection:
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
|
||||
def test_warn_at_threshold(self):
|
||||
def test_warn_at_threshold_queues_but_does_not_mutate_state(self):
|
||||
"""At warn threshold, ``after_model`` enqueues but returns None.
|
||||
|
||||
Detection observes the just-emitted AIMessage(tool_calls=...). The
|
||||
tools node hasn't run yet, so injecting any non-tool message here
|
||||
would split the assistant's tool_calls from their ToolMessage
|
||||
responses and break OpenAI/Moonshot pairing. The warning is
|
||||
delivered later from ``wrap_model_call``.
|
||||
"""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
@@ -146,44 +224,150 @@ class TestLoopDetection:
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Third identical call triggers warning. The warning is appended to
|
||||
# the AIMessage content (tool_calls preserved) — never inserted as a
|
||||
# separate HumanMessage between the AIMessage(tool_calls) and its
|
||||
# ToolMessage responses, which would break OpenAI/Moonshot strict
|
||||
# tool-call pairing validation.
|
||||
# Third identical call triggers warning detection.
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
assert isinstance(msgs[0], AIMessage)
|
||||
assert len(msgs[0].tool_calls) == len(call)
|
||||
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
|
||||
assert "LOOP DETECTED" in msgs[0].content
|
||||
# Detection must not mutate state — the AIMessage with tool_calls is
|
||||
# left untouched so the tools node runs normally.
|
||||
assert result is None
|
||||
# ...but a warning is queued for the next model call.
|
||||
assert mw._pending_warnings[_pending_key()]
|
||||
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key()][0]
|
||||
|
||||
def test_warn_does_not_break_tool_call_pairing(self):
|
||||
"""Regression: the warn branch must NOT inject a non-tool message
|
||||
after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next
|
||||
request with 'tool_call_ids did not have response messages' if any
|
||||
non-tool message is wedged between the AIMessage and its ToolMessage
|
||||
responses. See #2029.
|
||||
def test_warn_injected_at_next_model_call(self):
|
||||
"""``wrap_model_call`` appends a HumanMessage(loop_warning) to the
|
||||
outgoing messages — *after* every existing message — so that the
|
||||
AIMessage(tool_calls=...) -> ToolMessage(...) pairing stays intact.
|
||||
"""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(2):
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
assert isinstance(msgs[0], AIMessage)
|
||||
assert len(msgs[0].tool_calls) == len(call)
|
||||
assert msgs[0].tool_calls[0]["id"] == call[0]["id"]
|
||||
# Build the messages the agent runtime would assemble for the next
|
||||
# turn: prior AIMessage(tool_calls), its ToolMessage responses, ...
|
||||
ai_msg = AIMessage(content="", tool_calls=call)
|
||||
tool_msg = ToolMessage(content="ok", tool_call_id=call[0]["id"], name="bash")
|
||||
request = _make_request([ai_msg, tool_msg], runtime)
|
||||
|
||||
def test_warn_only_injected_once(self):
|
||||
"""Warning for the same hash should only be injected once per thread."""
|
||||
captured, handler = _capture_handler()
|
||||
mw.wrap_model_call(request, handler)
|
||||
|
||||
sent = captured[0].messages
|
||||
# AIMessage and ToolMessage stay in order, untouched.
|
||||
assert sent[0] is ai_msg
|
||||
assert sent[1] is tool_msg
|
||||
# HumanMessage(warning) appears AFTER the ToolMessage — pairing intact.
|
||||
assert isinstance(sent[2], HumanMessage)
|
||||
assert sent[2].name == "loop_warning"
|
||||
assert "LOOP DETECTED" in sent[2].content
|
||||
|
||||
def test_warn_queue_drained_after_injection(self):
    """A queued warning must be emitted exactly once per detection event."""
    mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
    runtime = _make_runtime()
    call = [_bash_call("ls")]
    # Three identical calls reach warn_threshold=3 and queue a warning.
    for _ in range(3):
        mw._apply(_make_state(tool_calls=call), runtime)

    request = _make_request([AIMessage(content="hi")], runtime)
    captured, handler = _capture_handler()

    # First call: warning is appended.
    mw.wrap_model_call(request, handler)
    first = captured[0].messages
    assert any(isinstance(m, HumanMessage) for m in first)

    # Subsequent call without new detection: no warning re-emitted.
    request2 = _make_request([AIMessage(content="hi")], runtime)
    mw.wrap_model_call(request2, handler)
    second = captured[1].messages
    assert not any(isinstance(m, HumanMessage) for m in second)
|
||||
|
||||
def test_warn_queue_scoped_by_run_id(self):
    """A warning queued for one run must not be injected into another run."""
    mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
    # Both runtimes share the default thread id and differ only in run id.
    runtime_a = _make_runtime(run_id="run-A")
    runtime_b = _make_runtime(run_id="run-B")
    call = [_bash_call("ls")]

    for _ in range(3):
        mw._apply(_make_state(tool_calls=call), runtime_a)

    # A model call on run-B must not receive (or consume) run-A's warning.
    request_b = _make_request([AIMessage(content="hi")], runtime_b)
    captured, handler = _capture_handler()
    mw.wrap_model_call(request_b, handler)
    assert not any(isinstance(m, HumanMessage) for m in captured[0].messages)
    assert mw._pending_warnings.get(_pending_key(run_id="run-A"))

    # The warning is still delivered once run-A makes its own model call.
    request_a = _make_request([AIMessage(content="hi")], runtime_a)
    mw.wrap_model_call(request_a, handler)
    assert any(isinstance(message, HumanMessage) and message.name == "loop_warning" for message in captured[1].messages)
|
||||
|
||||
def test_missing_run_id_uses_default_pending_scope(self):
    """When runtime has no run_id, warning handling falls back to the default run scope."""
    mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
    # Built by hand rather than via _make_runtime so the context lacks "run_id".
    runtime = MagicMock()
    runtime.context = {"thread_id": "test-thread"}
    call = [_bash_call("ls")]

    for _ in range(3):
        mw._apply(_make_state(tool_calls=call), runtime)

    # The warning is queued under the ("test-thread", "default") key.
    assert mw._pending_warnings.get(_pending_key(run_id="default"))

    request = _make_request([AIMessage(content="hi")], runtime)
    captured, handler = _capture_handler()
    mw.wrap_model_call(request, handler)

    loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
    assert len(loop_warnings) == 1
    assert "LOOP DETECTED" in loop_warnings[0].content
    # Delivery drains the default-scope queue.
    assert not mw._pending_warnings.get(_pending_key(run_id="default"))
|
||||
|
||||
def test_before_agent_clears_stale_pending_warnings_for_thread(self):
    """Starting a new run drops stale warnings from prior runs in the same thread."""
    mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
    runtime_a = _make_runtime(run_id="run-A")
    runtime_b = _make_runtime(run_id="run-B")
    call = [_bash_call("ls")]

    # Queue a warning for run-A that is never delivered.
    for _ in range(3):
        mw._apply(_make_state(tool_calls=call), runtime_a)

    assert mw._pending_warnings.get(_pending_key(run_id="run-A"))
    # run-B starting on the same thread must discard run-A's leftover warning.
    mw.before_agent({"messages": []}, runtime_b)
    assert not mw._pending_warnings.get(_pending_key(run_id="run-A"))
|
||||
|
||||
def test_after_agent_clears_current_run_pending_warnings(self):
    """Run cleanup should drop warnings that never reached wrap_model_call."""
    mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
    runtime = _make_runtime()
    call = [_bash_call("ls")]

    # Queue a warning, but never make the model call that would deliver it.
    for _ in range(3):
        mw._apply(_make_state(tool_calls=call), runtime)

    assert mw._pending_warnings.get(_pending_key())
    mw.after_agent({"messages": []}, runtime)
    assert not mw._pending_warnings.get(_pending_key())
|
||||
|
||||
def test_multiple_pending_warnings_are_merged_into_one_message(self):
    """Edge-case drains should produce one loop_warning prompt message."""
    mw = LoopDetectionMiddleware()
    runtime = _make_runtime()
    # Seed the queue directly, including a duplicate, to exercise the merge path.
    mw._pending_warnings[_pending_key()] = ["first warning", "second warning", "first warning"]
    request = _make_request([AIMessage(content="hi")], runtime)
    captured, handler = _capture_handler()

    mw.wrap_model_call(request, handler)

    loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]
    assert len(loop_warnings) == 1
    # The duplicate is dropped and the remaining warnings are joined by a blank line.
    assert loop_warnings[0].content == "first warning\n\nsecond warning"
|
||||
|
||||
def test_warn_only_queued_once_per_hash(self):
|
||||
"""Same hash repeated past the threshold should warn only once."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
@@ -192,14 +376,13 @@ class TestLoopDetection:
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Third — warning injected
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
# Third — warning queued
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert len(mw._pending_warnings[_pending_key()]) == 1
|
||||
|
||||
# Fourth — warning already injected, should return None
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
# Fourth — already warned for this hash, no additional enqueue.
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert len(mw._pending_warnings[_pending_key()]) == 1
|
||||
|
||||
def test_hard_stop_at_limit(self):
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
@@ -257,6 +440,7 @@ class TestLoopDetection:
|
||||
mw.reset()
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
assert not mw._pending_warnings.get(_pending_key())
|
||||
|
||||
def test_non_ai_message_ignored(self):
|
||||
mw = LoopDetectionMiddleware()
|
||||
@@ -283,15 +467,16 @@ class TestLoopDetection:
|
||||
# One call on thread B
|
||||
mw._apply(_make_state(tool_calls=call), runtime_b)
|
||||
|
||||
# Second call on thread A — triggers warning (2 >= warn_threshold)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime_a)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
# Second call on thread A — queues warning under thread-A only.
|
||||
mw._apply(_make_state(tool_calls=call), runtime_a)
|
||||
assert mw._pending_warnings.get(_pending_key("thread-A"))
|
||||
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
|
||||
assert not mw._pending_warnings.get(_pending_key("thread-B"))
|
||||
|
||||
# Second call on thread B — also triggers (independent tracking)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime_b)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
# Second call on thread B — independent queue.
|
||||
mw._apply(_make_state(tool_calls=call), runtime_b)
|
||||
assert mw._pending_warnings.get(_pending_key("thread-B"))
|
||||
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
|
||||
|
||||
def test_lru_eviction(self):
|
||||
"""Old threads should be evicted when max_tracked_threads is exceeded."""
|
||||
@@ -313,6 +498,55 @@ class TestLoopDetection:
|
||||
assert "thread-new" in mw._history
|
||||
assert len(mw._history) == 3
|
||||
|
||||
def test_warned_hashes_are_pruned_to_sliding_window(self):
    """A long-lived thread should not keep every historical warned hash."""
    mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=100, window_size=4)
    runtime = _make_runtime()

    # Twelve distinct commands, each issued twice so each one hits warn_threshold=2.
    for i in range(12):
        call = [_bash_call(f"cmd_{i}")]
        mw._apply(_make_state(tool_calls=call), runtime)
        mw._apply(_make_state(tool_calls=call), runtime)

    assert len(mw._history["test-thread"]) <= 4
    # Warned hashes must be confined to what is still inside the window.
    assert set(mw._warned["test-thread"]).issubset(set(mw._history["test-thread"]))
    assert len(mw._warned["test-thread"]) <= 4
|
||||
|
||||
def test_pending_warning_keys_are_capped(self):
    """Abnormal same-thread runs cannot grow pending-warning keys forever."""
    mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=2)

    # Ten different runs on one thread, each queueing a warning that is never drained.
    for i in range(10):
        runtime = _make_runtime(thread_id="same-thread", run_id=f"run-{i}")
        mw._queue_pending_warning(runtime, f"warning-{i}")

    assert len(mw._pending_warnings) == mw._max_pending_warning_keys
    assert len(mw._pending_warning_touch_order) == mw._max_pending_warning_keys
    # The most recently queued run must survive the cap.
    assert _pending_key("same-thread", "run-9") in mw._pending_warnings
|
||||
|
||||
def test_pending_warning_list_is_capped_and_deduped(self):
    """One run cannot accumulate an unbounded warning list."""
    mw = LoopDetectionMiddleware()
    runtime = _make_runtime()

    # Overfill the per-run list by four entries.
    for i in range(_MAX_PENDING_WARNINGS_PER_RUN + 4):
        mw._queue_pending_warning(runtime, f"warning-{i}")
    # Re-queue the last warning once more; as a duplicate it must not add an entry.
    # NOTE(review): placed after the loop to match the expected list below — confirm.
    mw._queue_pending_warning(runtime, f"warning-{_MAX_PENDING_WARNINGS_PER_RUN + 3}")

    warnings = mw._pending_warnings[_pending_key()]
    assert len(warnings) == _MAX_PENDING_WARNINGS_PER_RUN
    # The four oldest entries are dropped; the newest ones remain in order.
    assert warnings == [f"warning-{i}" for i in range(4, _MAX_PENDING_WARNINGS_PER_RUN + 4)]
|
||||
|
||||
def test_pending_warning_touch_order_cleared_with_pending_key(self):
    """Dropping a run's pending warnings must also drop its touch-order entry."""
    mw = LoopDetectionMiddleware()
    runtime = _make_runtime()
    mw._queue_pending_warning(runtime, "warning")

    mw.after_agent({"messages": []}, runtime)

    # Both bookkeeping structures must end up empty together.
    assert mw._pending_warnings == {}
    assert mw._pending_warning_touch_order == OrderedDict()
|
||||
|
||||
def test_thread_safe_mutations(self):
|
||||
"""Verify lock is used for mutations (basic structural test)."""
|
||||
mw = LoopDetectionMiddleware()
|
||||
@@ -331,6 +565,99 @@ class TestLoopDetection:
|
||||
assert "default" in mw._history
|
||||
|
||||
|
||||
class TestLoopDetectionAgentGraphIntegration:
    """Run the middleware inside a real ``create_agent`` graph, sync and async.

    Both tests script the fake model to issue the same ``bash`` tool call three
    times and then a final answer, and check what each model request contained.
    """

    @staticmethod
    def _loop_warnings(messages):
        """Return the ``HumanMessage(name="loop_warning")`` entries of ``messages``."""
        return [message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"]

    def test_loop_warning_is_transient_in_real_agent_graph(self):
        """after_model queues the warning; wrap_model_call injects it request-only."""

        @as_tool
        def bash(command: str) -> str:
            """Run a fake shell command."""
            return f"ran: {command}"

        # Same tool name and args each time; only the call id differs.
        repeated_calls = [[{"name": "bash", "id": f"call_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
        mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
        model = _CapturingFakeMessagesListChatModel(
            responses=[
                AIMessage(content="", tool_calls=repeated_calls[0]),
                AIMessage(content="", tool_calls=repeated_calls[1]),
                AIMessage(content="", tool_calls=repeated_calls[2]),
                AIMessage(content="final answer"),
            ],
        )
        graph = create_agent(model=model, tools=[bash], middleware=[mw])

        result = graph.invoke(
            {"messages": [("user", "inspect the directory")]},
            context={"thread_id": "integration-thread", "run_id": "integration-run"},
            config={"recursion_limit": 20},
        )

        assert len(model.seen_messages) == 4
        loop_warnings_by_call = [self._loop_warnings(messages) for messages in model.seen_messages]
        # The third response triggers detection, so only the fourth request carries the warning.
        assert loop_warnings_by_call[0] == []
        assert loop_warnings_by_call[1] == []
        assert loop_warnings_by_call[2] == []
        assert len(loop_warnings_by_call[3]) == 1
        assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content

        # The warning sits after the last ToolMessage, keeping tool-call pairing intact.
        fourth_request = model.seen_messages[3]
        assert isinstance(fourth_request[-2], ToolMessage)
        assert fourth_request[-2].tool_call_id == "call_ls_2"
        assert fourth_request[-1] is loop_warnings_by_call[3][0]

        # Nothing is persisted to graph state and no queue entries are left behind.
        persisted_loop_warnings = self._loop_warnings(result["messages"])
        assert persisted_loop_warnings == []
        assert result["messages"][-1].content == "final answer"
        assert mw._pending_warnings == {}
        assert mw._pending_warning_touch_order == OrderedDict()

    @pytest.mark.asyncio
    async def test_loop_warning_is_transient_in_async_agent_graph(self):
        """awrap_model_call injects loop_warning request-only in async graph runs."""

        @as_tool
        async def bash(command: str) -> str:
            """Run a fake shell command."""
            return f"ran: {command}"

        # Same tool name and args each time; only the call id differs.
        repeated_calls = [[{"name": "bash", "id": f"call_async_ls_{i}", "args": {"command": "ls"}}] for i in range(3)]
        mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
        model = _CapturingFakeMessagesListChatModel(
            responses=[
                AIMessage(content="", tool_calls=repeated_calls[0]),
                AIMessage(content="", tool_calls=repeated_calls[1]),
                AIMessage(content="", tool_calls=repeated_calls[2]),
                AIMessage(content="async final answer"),
            ],
        )
        graph = create_agent(model=model, tools=[bash], middleware=[mw])

        result = await graph.ainvoke(
            {"messages": [("user", "inspect the directory asynchronously")]},
            context={"thread_id": "async-integration-thread", "run_id": "async-integration-run"},
            config={"recursion_limit": 20},
        )

        assert len(model.seen_messages) == 4
        loop_warnings_by_call = [self._loop_warnings(messages) for messages in model.seen_messages]
        # The third response triggers detection, so only the fourth request carries the warning.
        assert loop_warnings_by_call[0] == []
        assert loop_warnings_by_call[1] == []
        assert loop_warnings_by_call[2] == []
        assert len(loop_warnings_by_call[3]) == 1
        assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content

        # The warning sits after the last ToolMessage, keeping tool-call pairing intact.
        fourth_request = model.seen_messages[3]
        assert isinstance(fourth_request[-2], ToolMessage)
        assert fourth_request[-2].tool_call_id == "call_async_ls_2"
        assert fourth_request[-1] is loop_warnings_by_call[3][0]

        # Nothing is persisted to graph state and no queue entries are left behind.
        persisted_loop_warnings = self._loop_warnings(result["messages"])
        assert persisted_loop_warnings == []
        assert result["messages"][-1].content == "async final answer"
        assert mw._pending_warnings == {}
        assert mw._pending_warning_touch_order == OrderedDict()
|
||||
|
||||
|
||||
class TestAppendText:
|
||||
"""Unit tests for LoopDetectionMiddleware._append_text."""
|
||||
|
||||
@@ -507,33 +834,29 @@ class TestToolFrequencyDetection:
|
||||
for i in range(4):
|
||||
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
|
||||
|
||||
# 5th call to read_file (different file each time) triggers freq warning
|
||||
# 5th call queues a per-tool-type frequency warning; state untouched.
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
# Warning is appended to the AIMessage content; tool_calls preserved
|
||||
# so the tools node still runs and Moonshot/OpenAI tool-call pairing
|
||||
# validation does not break.
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert msg.tool_calls
|
||||
assert "read_file" in msg.content
|
||||
assert "LOOP DETECTED" in msg.content
|
||||
assert result is None
|
||||
queued = mw._pending_warnings.get(_pending_key(), [])
|
||||
assert queued
|
||||
assert "read_file" in queued[0]
|
||||
assert "LOOP DETECTED" in queued[0]
|
||||
|
||||
def test_freq_warn_only_injected_once(self):
|
||||
def test_freq_warn_only_queued_once(self):
|
||||
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
|
||||
runtime = _make_runtime()
|
||||
|
||||
for i in range(2):
|
||||
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
|
||||
|
||||
# 3rd triggers warning
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
# 3rd queues a frequency warning.
|
||||
mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
|
||||
assert len(mw._pending_warnings[_pending_key()]) == 1
|
||||
|
||||
# 4th should not re-warn (already warned for read_file)
|
||||
# 4th: same tool name, no additional enqueue.
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime)
|
||||
assert result is None
|
||||
assert len(mw._pending_warnings[_pending_key()]) == 1
|
||||
|
||||
def test_freq_hard_stop_at_limit(self):
|
||||
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6)
|
||||
@@ -565,10 +888,10 @@ class TestToolFrequencyDetection:
|
||||
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
|
||||
assert result is None
|
||||
|
||||
# 3rd read_file triggers (read_file count = 3)
|
||||
# 3rd read_file triggers — warning is queued (state unchanged).
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
assert result is None
|
||||
assert "read_file" in mw._pending_warnings[_pending_key()][0]
|
||||
|
||||
def test_freq_reset_clears_state(self):
|
||||
mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10)
|
||||
@@ -600,10 +923,10 @@ class TestToolFrequencyDetection:
|
||||
assert "thread-A" not in mw._tool_freq
|
||||
assert "thread-A" not in mw._tool_freq_warned
|
||||
|
||||
# thread-B state should still be intact — 3rd call triggers warn
|
||||
# thread-B state should still be intact — 3rd call queues a warn.
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
assert result is None
|
||||
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0]
|
||||
|
||||
# thread-A restarted from 0 — should not trigger
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a)
|
||||
@@ -623,10 +946,11 @@ class TestToolFrequencyDetection:
|
||||
for i in range(2):
|
||||
mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b)
|
||||
|
||||
# 3rd call on thread A — triggers (count=3 for thread A only)
|
||||
# 3rd call on thread A — queues a warning (count=3 for thread A only).
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
assert result is None
|
||||
assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0]
|
||||
assert not mw._pending_warnings.get(_pending_key("thread-B"))
|
||||
|
||||
def test_multi_tool_single_response_counted(self):
|
||||
"""When a single response has multiple tool calls, each is counted."""
|
||||
@@ -643,10 +967,10 @@ class TestToolFrequencyDetection:
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
|
||||
# Response 3: 1 more → count = 5 → triggers warn
|
||||
# Response 3: 1 more → count = 5 → queues warn.
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime)
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
assert result is None
|
||||
assert "read_file" in mw._pending_warnings[_pending_key()][0]
|
||||
|
||||
def test_override_tool_uses_override_thresholds(self):
|
||||
"""A tool in tool_freq_overrides uses its own thresholds, not the global ones."""
|
||||
@@ -674,10 +998,14 @@ class TestToolFrequencyDetection:
|
||||
for i in range(2):
|
||||
mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime)
|
||||
|
||||
# 3rd read_file call hits global warn=3 (read_file has no override)
|
||||
# 3rd read_file call hits global warn=3 (read_file has no override).
|
||||
# Warning delivery is deferred to wrap_model_call so the just-emitted
|
||||
# AIMessage(tool_calls=...) is not mutated before ToolMessages exist.
|
||||
result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime)
|
||||
assert result is not None
|
||||
assert "read_file" in result["messages"][0].content
|
||||
assert result is None
|
||||
queued = mw._pending_warnings.get(_pending_key(), [])
|
||||
assert queued
|
||||
assert "read_file" in queued[0]
|
||||
|
||||
def test_hash_detection_takes_priority(self):
|
||||
"""Hash-based hard stop fires before frequency check for identical calls."""
|
||||
@@ -736,11 +1064,13 @@ class TestFromConfig:
|
||||
mw = LoopDetectionMiddleware.from_config(self._config())
|
||||
assert mw._tool_freq_overrides == {}
|
||||
|
||||
def test_constructed_middleware_detects_loops(self):
|
||||
def test_constructed_middleware_queues_loop_warning(self):
|
||||
mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4))
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
assert result is None
|
||||
queued = mw._pending_warnings.get(_pending_key(), [])
|
||||
assert queued
|
||||
assert "LOOP DETECTED" in queued[0]
|
||||
|
||||
@@ -50,6 +50,8 @@ Intercepts clarification tool calls and converts them into proper user-facing re
|
||||
|
||||
Detects when the agent is making the same tool call repeatedly without making progress. When a loop is detected, the middleware intervenes to break the cycle and prevents the agent from burning turns indefinitely.
|
||||
|
||||
Warning interventions are queued per thread and run, then drained on the next model call as a single hidden `HumanMessage(name="loop_warning")` appended after existing tool results. This keeps provider tool-call pairing valid. Run start/end hooks clear stale or undelivered warnings, and hard stops still strip tool calls before forcing a final text response.
|
||||
|
||||
**Configuration**: built-in, no user configuration.
|
||||
|
||||
---
|
||||
|
||||
@@ -50,6 +50,8 @@ import { Callout } from "nextra/components";
|
||||
|
||||
检测 Agent 是否在没有取得进展的情况下重复进行相同的工具调用。检测到循环时,中间件会介入打破循环,防止 Agent 无限消耗轮次。
|
||||
|
||||
Warning 介入会按 thread 和 run 排队,并在下一次模型调用时合并为一条隐藏的 `HumanMessage(name="loop_warning")`,追加到已有工具结果之后。这样不会破坏 provider 对 tool-call/tool-message 配对的校验。Run 开始和结束时会清理过期或未送达的 warning;达到 hard stop 时仍会清空 tool calls 并强制生成最终文本回复。
|
||||
|
||||
**配置**:内置,无需用户配置。
|
||||
|
||||
---
|
||||
|
||||
Reference in New Issue
Block a user