fix(middleware): Prevent todo completion reminder IMMessage leak (#2907)
* fix(middleware): Prevent todo completion reminder IMMessage leak (#2892) * make format * fix(middleware): Clear stale todo reminder counts (#2892) * add size guard for _completion_reminder_counts and add a integration test
This commit is contained in:
@@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list.
|
|||||||
|
|
||||||
Additionally, this middleware prevents the agent from exiting the loop while
|
Additionally, this middleware prevents the agent from exiting the loop while
|
||||||
there are still incomplete todo items. When the model produces a final response
|
there are still incomplete todo items. When the model produces a final response
|
||||||
(no tool calls) but todos are not yet complete, the middleware injects a reminder
|
(no tool calls) but todos are not yet complete, the middleware queues a reminder
|
||||||
and jumps back to the model node to force continued engagement.
|
for the next model request and jumps back to the model node to force continued
|
||||||
|
engagement. The completion reminder is injected via ``wrap_model_call`` instead
|
||||||
|
of being persisted into graph state as a normal user-visible message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, override
|
from typing import Any, override
|
||||||
|
|
||||||
from langchain.agents.middleware import TodoListMiddleware
|
from langchain.agents.middleware import TodoListMiddleware
|
||||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||||
from langchain.agents.middleware.types import hook_config
|
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
@@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str:
|
|||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_completion_reminder(todos: list[Todo]) -> str:
|
||||||
|
"""Format a completion reminder for incomplete todo items."""
|
||||||
|
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||||
|
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||||
|
return (
|
||||||
|
"<system_reminder>\n"
|
||||||
|
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||||
|
f"{incomplete_text}\n\n"
|
||||||
|
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||||
|
"as you finish them, and only respond when all items are done.\n"
|
||||||
|
"</system_reminder>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
|
||||||
|
|
||||||
|
|
||||||
|
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
|
||||||
|
"""Return True when an AIMessage is not a clean final answer.
|
||||||
|
|
||||||
|
Todo completion reminders should only fire when the model has produced a
|
||||||
|
plain final response. Provider/tool parsing details have moved across
|
||||||
|
LangChain versions and integrations, so keep all tool-intent/error signals
|
||||||
|
behind this helper instead of checking one concrete field at the call site.
|
||||||
|
"""
|
||||||
|
if message.tool_calls:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if getattr(message, "invalid_tool_calls", None):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Backward/provider compatibility: some integrations preserve raw or legacy
|
||||||
|
# tool-call intent in additional_kwargs even when structured tool_calls is
|
||||||
|
# empty. If this helper changes, update the matching sentinel test
|
||||||
|
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
|
||||||
|
# if that test fails after a LangChain upgrade, review this helper so new
|
||||||
|
# tool-call/error fields are not silently treated as clean final answers.
|
||||||
|
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
|
||||||
|
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
response_metadata = getattr(message, "response_metadata", {}) or {}
|
||||||
|
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
|
||||||
|
|
||||||
|
|
||||||
class TodoMiddleware(TodoListMiddleware):
|
class TodoMiddleware(TodoListMiddleware):
|
||||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||||
|
|
||||||
@@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
formatted = _format_todos(todos)
|
formatted = _format_todos(todos)
|
||||||
reminder = HumanMessage(
|
reminder = HumanMessage(
|
||||||
name="todo_reminder",
|
name="todo_reminder",
|
||||||
|
additional_kwargs={"hide_from_ui": True},
|
||||||
content=(
|
content=(
|
||||||
"<system_reminder>\n"
|
"<system_reminder>\n"
|
||||||
"Your todo list from earlier is no longer visible in the current context window, "
|
"Your todo list from earlier is no longer visible in the current context window, "
|
||||||
@@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
# Maximum number of completion reminders before allowing the agent to exit.
|
# Maximum number of completion reminders before allowing the agent to exit.
|
||||||
# This prevents infinite loops when the agent cannot make further progress.
|
# This prevents infinite loops when the agent cannot make further progress.
|
||||||
_MAX_COMPLETION_REMINDERS = 2
|
_MAX_COMPLETION_REMINDERS = 2
|
||||||
|
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
|
||||||
|
_MAX_COMPLETION_REMINDER_KEYS = 4096
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
|
||||||
|
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
|
||||||
|
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
|
||||||
|
self._completion_reminder_next_order = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_thread_id(runtime: Runtime) -> str:
|
||||||
|
context = getattr(runtime, "context", None)
|
||||||
|
thread_id = context.get("thread_id") if context else None
|
||||||
|
return str(thread_id) if thread_id else "default"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_run_id(runtime: Runtime) -> str:
|
||||||
|
context = getattr(runtime, "context", None)
|
||||||
|
run_id = context.get("run_id") if context else None
|
||||||
|
return str(run_id) if run_id else "default"
|
||||||
|
|
||||||
|
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
||||||
|
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
||||||
|
|
||||||
|
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||||
|
self._completion_reminder_next_order += 1
|
||||||
|
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
|
||||||
|
|
||||||
|
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
|
||||||
|
keys = set(self._pending_completion_reminders)
|
||||||
|
keys.update(self._completion_reminder_counts)
|
||||||
|
keys.update(self._completion_reminder_touch_order)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||||
|
self._pending_completion_reminders.pop(key, None)
|
||||||
|
self._completion_reminder_counts.pop(key, None)
|
||||||
|
self._completion_reminder_touch_order.pop(key, None)
|
||||||
|
|
||||||
|
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
|
||||||
|
keys = self._completion_reminder_keys_locked()
|
||||||
|
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
|
||||||
|
if overflow <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
candidates = [key for key in keys if key != protected_key]
|
||||||
|
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
|
||||||
|
for key in candidates[:overflow]:
|
||||||
|
self._drop_completion_reminder_key_locked(key)
|
||||||
|
|
||||||
|
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
|
||||||
|
key = self._pending_key(runtime)
|
||||||
|
with self._lock:
|
||||||
|
self._pending_completion_reminders.setdefault(key, []).append(reminder)
|
||||||
|
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
|
||||||
|
self._touch_completion_reminder_key_locked(key)
|
||||||
|
self._prune_completion_reminder_state_locked(protected_key=key)
|
||||||
|
|
||||||
|
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
|
||||||
|
key = self._pending_key(runtime)
|
||||||
|
with self._lock:
|
||||||
|
return self._completion_reminder_counts.get(key, 0)
|
||||||
|
|
||||||
|
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
|
||||||
|
key = self._pending_key(runtime)
|
||||||
|
with self._lock:
|
||||||
|
reminders = self._pending_completion_reminders.pop(key, [])
|
||||||
|
if reminders or key in self._completion_reminder_counts:
|
||||||
|
self._touch_completion_reminder_key_locked(key)
|
||||||
|
return reminders
|
||||||
|
|
||||||
|
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||||
|
thread_id, current_run_id = self._pending_key(runtime)
|
||||||
|
with self._lock:
|
||||||
|
for key in self._completion_reminder_keys_locked():
|
||||||
|
if key[0] == thread_id and key[1] != current_run_id:
|
||||||
|
self._drop_completion_reminder_key_locked(key)
|
||||||
|
|
||||||
|
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||||
|
key = self._pending_key(runtime)
|
||||||
|
with self._lock:
|
||||||
|
self._drop_completion_reminder_key_locked(key)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
self._clear_other_run_completion_reminders(runtime)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
self._clear_other_run_completion_reminders(runtime)
|
||||||
|
return None
|
||||||
|
|
||||||
@hook_config(can_jump_to=["model"])
|
@hook_config(can_jump_to=["model"])
|
||||||
@override
|
@override
|
||||||
@@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
if base_result is not None:
|
if base_result is not None:
|
||||||
return base_result
|
return base_result
|
||||||
|
|
||||||
# 2. Only intervene when the agent wants to exit (no tool calls).
|
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
|
||||||
|
# intent or tool-call parse errors should be handled by the tool path
|
||||||
|
# instead of being masked by todo reminders.
|
||||||
messages = state.get("messages") or []
|
messages = state.get("messages") or []
|
||||||
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
||||||
if not last_ai or last_ai.tool_calls:
|
if not last_ai or _has_tool_call_intent_or_error(last_ai):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 3. Allow exit when all todos are completed or there are no todos.
|
# 3. Allow exit when all todos are completed or there are no todos.
|
||||||
@@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
||||||
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
|
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 5. Inject a reminder and force the agent back to the model.
|
# 5. Queue a reminder for the next model request and jump back. We must
|
||||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
# not persist this control prompt as a normal HumanMessage, otherwise it
|
||||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
# can leak into user-visible message streams and saved transcripts.
|
||||||
reminder = HumanMessage(
|
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
|
||||||
name="todo_completion_reminder",
|
return {"jump_to": "model"}
|
||||||
content=(
|
|
||||||
"<system_reminder>\n"
|
|
||||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
|
||||||
f"{incomplete_text}\n\n"
|
|
||||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
|
||||||
"as you finish them, and only respond when all items are done.\n"
|
|
||||||
"</system_reminder>"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return {"jump_to": "model", "messages": [reminder]}
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@hook_config(can_jump_to=["model"])
|
@hook_config(can_jump_to=["model"])
|
||||||
@@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Async version of after_model."""
|
"""Async version of after_model."""
|
||||||
return self.after_model(state, runtime)
|
return self.after_model(state, runtime)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_pending_completion_reminders(reminders: list[str]) -> str:
|
||||||
|
return "\n\n".join(dict.fromkeys(reminders))
|
||||||
|
|
||||||
|
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
||||||
|
reminders = self._drain_completion_reminders(request.runtime)
|
||||||
|
if not reminders:
|
||||||
|
return request
|
||||||
|
new_messages = [
|
||||||
|
*request.messages,
|
||||||
|
HumanMessage(
|
||||||
|
content=self._format_pending_completion_reminders(reminders),
|
||||||
|
name="todo_completion_reminder",
|
||||||
|
additional_kwargs={"hide_from_ui": True},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return request.override(messages=new_messages)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def wrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
|
) -> ModelCallResult:
|
||||||
|
return handler(self._augment_request(request))
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def awrap_model_call(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
|
) -> ModelCallResult:
|
||||||
|
return await handler(self._augment_request(request))
|
||||||
|
|
||||||
|
@override
|
||||||
|
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
self._clear_current_run_completion_reminders(runtime)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
self._clear_current_run_completion_reminders(runtime)
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
"""Tests for TodoMiddleware context-loss detection."""
|
"""Tests for TodoMiddleware context-loss detection."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import MagicMock
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
from pydantic import PrivateAttr
|
||||||
|
|
||||||
from deerflow.agents.middlewares.todo_middleware import (
|
from deerflow.agents.middlewares.todo_middleware import (
|
||||||
TodoMiddleware,
|
TodoMiddleware,
|
||||||
_completion_reminder_count,
|
_completion_reminder_count,
|
||||||
_format_todos,
|
_format_todos,
|
||||||
|
_has_tool_call_intent_or_error,
|
||||||
_reminder_in_messages,
|
_reminder_in_messages,
|
||||||
_todos_in_messages,
|
_todos_in_messages,
|
||||||
)
|
)
|
||||||
@@ -22,9 +27,35 @@ def _reminder_msg():
|
|||||||
return HumanMessage(name="todo_reminder", content="reminder")
|
return HumanMessage(name="todo_reminder", content="reminder")
|
||||||
|
|
||||||
|
|
||||||
|
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
|
||||||
|
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def seen_messages(self) -> list[list[Any]]:
|
||||||
|
return self._seen_messages
|
||||||
|
|
||||||
|
def bind_tools(self, tools, *, tool_choice=None, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||||
|
self._seen_messages.append(list(messages))
|
||||||
|
return super()._generate(
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
run_manager=run_manager,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_runtime():
|
def _make_runtime():
|
||||||
runtime = MagicMock()
|
runtime = MagicMock()
|
||||||
runtime.context = {"thread_id": "test-thread"}
|
runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
|
||||||
|
return runtime
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runtime_for(thread_id: str, run_id: str):
|
||||||
|
runtime = _make_runtime()
|
||||||
|
runtime.context = {"thread_id": thread_id, "run_id": run_id}
|
||||||
return runtime
|
return runtime
|
||||||
|
|
||||||
|
|
||||||
@@ -161,10 +192,62 @@ def _completion_reminder_msg():
|
|||||||
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
|
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
|
||||||
|
|
||||||
|
|
||||||
|
def _todo_completion_reminders(messages):
|
||||||
|
reminders = []
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
|
||||||
|
reminders.append(message)
|
||||||
|
return reminders
|
||||||
|
|
||||||
|
|
||||||
def _ai_no_tool_calls():
|
def _ai_no_tool_calls():
|
||||||
return AIMessage(content="I'm done!")
|
return AIMessage(content="I'm done!")
|
||||||
|
|
||||||
|
|
||||||
|
def _ai_with_invalid_tool_calls():
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
{
|
||||||
|
"type": "invalid_tool_call",
|
||||||
|
"id": "write_file:36",
|
||||||
|
"name": "write_file",
|
||||||
|
"args": "{invalid",
|
||||||
|
"error": "Failed to parse tool arguments",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ai_with_raw_provider_tool_calls():
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[],
|
||||||
|
invalid_tool_calls=[],
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "raw-tool-call",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ai_with_legacy_function_call():
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ai_with_tool_finish_reason():
|
||||||
|
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
|
||||||
|
|
||||||
|
|
||||||
def _incomplete_todos():
|
def _incomplete_todos():
|
||||||
return [
|
return [
|
||||||
{"status": "completed", "content": "Step 1"},
|
{"status": "completed", "content": "Step 1"},
|
||||||
@@ -194,6 +277,36 @@ class TestCompletionReminderCount:
|
|||||||
assert _completion_reminder_count(msgs) == 1
|
assert _completion_reminder_count(msgs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCallIntentOrError:
|
||||||
|
def test_false_for_plain_final_answer(self):
|
||||||
|
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
|
||||||
|
|
||||||
|
def test_true_for_structured_tool_calls(self):
|
||||||
|
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
|
||||||
|
|
||||||
|
def test_true_for_invalid_tool_calls(self):
|
||||||
|
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
|
||||||
|
|
||||||
|
def test_true_for_raw_provider_tool_calls(self):
|
||||||
|
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
|
||||||
|
|
||||||
|
def test_true_for_legacy_function_call(self):
|
||||||
|
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
|
||||||
|
|
||||||
|
def test_true_for_tool_finish_reason(self):
|
||||||
|
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
|
||||||
|
|
||||||
|
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
|
||||||
|
# Sentinel for LangChain compatibility: if future AIMessage versions add
|
||||||
|
# new top-level tool/function-call fields, this test should fail. When
|
||||||
|
# it does, update `_has_tool_call_intent_or_error()` so the completion
|
||||||
|
# reminder guard explicitly decides whether each new field means "not a
|
||||||
|
# clean final answer"; the helper has a matching comment pointing back
|
||||||
|
# to this sentinel.
|
||||||
|
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
|
||||||
|
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
|
||||||
|
|
||||||
|
|
||||||
class TestAfterModel:
|
class TestAfterModel:
|
||||||
def test_returns_none_when_agent_still_using_tools(self):
|
def test_returns_none_when_agent_still_using_tools(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
@@ -235,68 +348,299 @@ class TestAfterModel:
|
|||||||
}
|
}
|
||||||
assert mw.after_model(state, _make_runtime()) is None
|
assert mw.after_model(state, _make_runtime()) is None
|
||||||
|
|
||||||
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
|
def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
|
runtime = _make_runtime()
|
||||||
state = {
|
state = {
|
||||||
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
|
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
result = mw.after_model(state, _make_runtime())
|
result = mw.after_model(state, runtime)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "model"
|
assert result["jump_to"] == "model"
|
||||||
assert len(result["messages"]) == 1
|
assert "messages" not in result
|
||||||
reminder = result["messages"][0]
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.runtime = runtime
|
||||||
|
request.messages = state["messages"]
|
||||||
|
request.override.return_value = "patched-request"
|
||||||
|
handler = MagicMock(return_value="response")
|
||||||
|
|
||||||
|
assert mw.wrap_model_call(request, handler) == "response"
|
||||||
|
request.override.assert_called_once()
|
||||||
|
reminder = request.override.call_args.kwargs["messages"][-1]
|
||||||
assert isinstance(reminder, HumanMessage)
|
assert isinstance(reminder, HumanMessage)
|
||||||
assert reminder.name == "todo_completion_reminder"
|
assert reminder.name == "todo_completion_reminder"
|
||||||
|
assert reminder.additional_kwargs["hide_from_ui"] is True
|
||||||
assert "Step 2" in reminder.content
|
assert "Step 2" in reminder.content
|
||||||
assert "Step 3" in reminder.content
|
assert "Step 3" in reminder.content
|
||||||
|
handler.assert_called_once_with("patched-request")
|
||||||
|
|
||||||
def test_reminder_lists_only_incomplete_items(self):
|
def test_reminder_lists_only_incomplete_items(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
|
runtime = _make_runtime()
|
||||||
state = {
|
state = {
|
||||||
"messages": [_ai_no_tool_calls()],
|
"messages": [_ai_no_tool_calls()],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
result = mw.after_model(state, _make_runtime())
|
result = mw.after_model(state, runtime)
|
||||||
content = result["messages"][0].content
|
assert result is not None
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.runtime = runtime
|
||||||
|
request.messages = state["messages"]
|
||||||
|
request.override.return_value = "patched-request"
|
||||||
|
mw.wrap_model_call(request, MagicMock(return_value="response"))
|
||||||
|
content = request.override.call_args.kwargs["messages"][-1].content
|
||||||
assert "Step 1" not in content # completed — should not appear
|
assert "Step 1" not in content # completed — should not appear
|
||||||
assert "Step 2" in content
|
assert "Step 2" in content
|
||||||
assert "Step 3" in content
|
assert "Step 3" in content
|
||||||
|
|
||||||
def test_allows_exit_after_max_reminders(self):
|
def test_allows_exit_after_max_reminders(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
|
runtime = _make_runtime()
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [
|
||||||
_completion_reminder_msg(),
|
|
||||||
_completion_reminder_msg(),
|
|
||||||
_ai_no_tool_calls(),
|
_ai_no_tool_calls(),
|
||||||
],
|
],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
|
assert mw.after_model(state, runtime) is not None
|
||||||
|
assert mw.after_model(state, runtime) is not None
|
||||||
|
assert mw.after_model(state, runtime) is None
|
||||||
|
|
||||||
|
def test_still_sends_reminder_before_cap(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
runtime = _make_runtime()
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
_ai_no_tool_calls(),
|
||||||
|
],
|
||||||
|
"todos": _incomplete_todos(),
|
||||||
|
}
|
||||||
|
assert mw.after_model(state, runtime) is not None
|
||||||
|
result = mw.after_model(state, runtime)
|
||||||
|
assert result is not None
|
||||||
|
assert result["jump_to"] == "model"
|
||||||
|
|
||||||
|
def test_does_not_trigger_for_invalid_tool_calls(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
state = {
|
||||||
|
"messages": [_ai_with_invalid_tool_calls()],
|
||||||
|
"todos": _incomplete_todos(),
|
||||||
|
}
|
||||||
assert mw.after_model(state, _make_runtime()) is None
|
assert mw.after_model(state, _make_runtime()) is None
|
||||||
|
|
||||||
def test_still_sends_reminder_before_cap(self):
|
def test_does_not_trigger_for_raw_provider_tool_calls(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
state = {
|
state = {
|
||||||
"messages": [
|
"messages": [_ai_with_raw_provider_tool_calls()],
|
||||||
_completion_reminder_msg(), # 1 reminder so far
|
|
||||||
_ai_no_tool_calls(),
|
|
||||||
],
|
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
result = mw.after_model(state, _make_runtime())
|
assert mw.after_model(state, _make_runtime()) is None
|
||||||
assert result is not None
|
|
||||||
assert result["jump_to"] == "model"
|
def test_does_not_trigger_for_legacy_function_call(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
state = {
|
||||||
|
"messages": [_ai_with_legacy_function_call()],
|
||||||
|
"todos": _incomplete_todos(),
|
||||||
|
}
|
||||||
|
assert mw.after_model(state, _make_runtime()) is None
|
||||||
|
|
||||||
|
def test_does_not_trigger_for_tool_finish_reason(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
state = {
|
||||||
|
"messages": [_ai_with_tool_finish_reason()],
|
||||||
|
"todos": _incomplete_todos(),
|
||||||
|
}
|
||||||
|
assert mw.after_model(state, _make_runtime()) is None
|
||||||
|
|
||||||
|
|
||||||
class TestAafterModel:
|
class TestAafterModel:
|
||||||
def test_delegates_to_sync(self):
|
def test_delegates_to_sync(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
|
runtime = _make_runtime()
|
||||||
state = {
|
state = {
|
||||||
"messages": [_ai_no_tool_calls()],
|
"messages": [_ai_no_tool_calls()],
|
||||||
"todos": _incomplete_todos(),
|
"todos": _incomplete_todos(),
|
||||||
}
|
}
|
||||||
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
|
result = asyncio.run(mw.aafter_model(state, runtime))
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "model"
|
assert result["jump_to"] == "model"
|
||||||
assert result["messages"][0].name == "todo_completion_reminder"
|
assert "messages" not in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestWrapModelCall:
|
||||||
|
def test_no_pending_reminder_passthrough(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
request = MagicMock()
|
||||||
|
request.runtime = _make_runtime()
|
||||||
|
request.messages = [HumanMessage(content="hi")]
|
||||||
|
handler = MagicMock(return_value="response")
|
||||||
|
|
||||||
|
assert mw.wrap_model_call(request, handler) == "response"
|
||||||
|
request.override.assert_not_called()
|
||||||
|
handler.assert_called_once_with(request)
|
||||||
|
|
||||||
|
def test_pending_reminder_is_injected_once(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
runtime = _make_runtime()
|
||||||
|
state = {
|
||||||
|
"messages": [_ai_no_tool_calls()],
|
||||||
|
"todos": _incomplete_todos(),
|
||||||
|
}
|
||||||
|
mw.after_model(state, runtime)
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.runtime = runtime
|
||||||
|
request.messages = state["messages"]
|
||||||
|
request.override.return_value = "patched-request"
|
||||||
|
handler = MagicMock(return_value="response")
|
||||||
|
|
||||||
|
assert mw.wrap_model_call(request, handler) == "response"
|
||||||
|
injected_messages = request.override.call_args.kwargs["messages"]
|
||||||
|
assert injected_messages[-1].name == "todo_completion_reminder"
|
||||||
|
|
||||||
|
request.override.reset_mock()
|
||||||
|
handler.reset_mock()
|
||||||
|
handler.return_value = "second-response"
|
||||||
|
assert mw.wrap_model_call(request, handler) == "second-response"
|
||||||
|
request.override.assert_not_called()
|
||||||
|
handler.assert_called_once_with(request)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTodoMiddlewareAgentGraphIntegration:
|
||||||
|
def test_completion_reminder_is_transient_in_real_agent_graph(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
model = _CapturingFakeMessagesListChatModel(
|
||||||
|
responses=[
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "write_todos",
|
||||||
|
"id": "todos-1",
|
||||||
|
"args": {
|
||||||
|
"todos": [
|
||||||
|
{"content": "Step 1", "status": "completed"},
|
||||||
|
{"content": "Step 2", "status": "pending"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessage(content="premature final 1"),
|
||||||
|
AIMessage(content="premature final 2"),
|
||||||
|
AIMessage(content="premature final 3"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
graph = create_agent(model=model, tools=[], middleware=[mw])
|
||||||
|
|
||||||
|
result = graph.invoke(
|
||||||
|
{"messages": [("user", "finish all todos")]},
|
||||||
|
context={"thread_id": "integration-thread", "run_id": "integration-run"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(model.seen_messages) == 4
|
||||||
|
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
|
||||||
|
assert reminders_by_call[0] == []
|
||||||
|
assert reminders_by_call[1] == []
|
||||||
|
assert len(reminders_by_call[2]) == 1
|
||||||
|
assert len(reminders_by_call[3]) == 1
|
||||||
|
assert "Step 1" not in reminders_by_call[2][0].content
|
||||||
|
assert "Step 2" in reminders_by_call[2][0].content
|
||||||
|
|
||||||
|
persisted_reminders = _todo_completion_reminders(result["messages"])
|
||||||
|
assert persisted_reminders == []
|
||||||
|
assert result["messages"][-1].content == "premature final 3"
|
||||||
|
assert result["todos"] == [
|
||||||
|
{"content": "Step 1", "status": "completed"},
|
||||||
|
{"content": "Step 2", "status": "pending"},
|
||||||
|
]
|
||||||
|
assert mw._pending_completion_reminders == {}
|
||||||
|
assert mw._completion_reminder_counts == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunScopedReminderCleanup:
|
||||||
|
def test_before_agent_clears_stale_count_without_pending_reminder(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
stale_runtime = _make_runtime()
|
||||||
|
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
|
||||||
|
current_runtime = _make_runtime()
|
||||||
|
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
|
||||||
|
other_thread_runtime = _make_runtime()
|
||||||
|
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
|
||||||
|
|
||||||
|
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||||
|
assert mw.after_model(state, stale_runtime) is not None
|
||||||
|
assert mw.after_model(state, other_thread_runtime) is not None
|
||||||
|
|
||||||
|
# Simulate a model call that drained the pending message, followed by an
|
||||||
|
# abnormal run end where after_agent did not clear the reminder count.
|
||||||
|
assert mw._drain_completion_reminders(stale_runtime)
|
||||||
|
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
|
||||||
|
|
||||||
|
mw.before_agent({}, current_runtime)
|
||||||
|
|
||||||
|
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
||||||
|
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
|
||||||
|
|
||||||
|
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
mw._MAX_COMPLETION_REMINDER_KEYS = 2
|
||||||
|
first_runtime = _make_runtime_for("thread-a", "run-a")
|
||||||
|
second_runtime = _make_runtime_for("thread-b", "run-b")
|
||||||
|
third_runtime = _make_runtime_for("thread-c", "run-c")
|
||||||
|
|
||||||
|
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||||
|
assert mw.after_model(state, first_runtime) is not None
|
||||||
|
|
||||||
|
# Simulate the normal model request path: pending reminder is consumed,
|
||||||
|
# but the run count remains until after_agent() or stale cleanup.
|
||||||
|
assert mw._drain_completion_reminders(first_runtime)
|
||||||
|
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
|
||||||
|
|
||||||
|
assert mw.after_model(state, second_runtime) is not None
|
||||||
|
assert mw.after_model(state, third_runtime) is not None
|
||||||
|
|
||||||
|
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
|
||||||
|
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
|
||||||
|
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
|
||||||
|
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
|
||||||
|
|
||||||
|
def test_size_guard_prunes_pending_and_count_state_together(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
mw._MAX_COMPLETION_REMINDER_KEYS = 1
|
||||||
|
stale_runtime = _make_runtime_for("thread-a", "run-a")
|
||||||
|
current_runtime = _make_runtime_for("thread-b", "run-b")
|
||||||
|
|
||||||
|
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||||
|
assert mw.after_model(state, stale_runtime) is not None
|
||||||
|
assert mw.after_model(state, current_runtime) is not None
|
||||||
|
|
||||||
|
assert mw._drain_completion_reminders(stale_runtime) == []
|
||||||
|
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
||||||
|
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestAwrapModelCall:
|
||||||
|
def test_async_pending_reminder_is_injected(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
runtime = _make_runtime()
|
||||||
|
state = {
|
||||||
|
"messages": [_ai_no_tool_calls()],
|
||||||
|
"todos": _incomplete_todos(),
|
||||||
|
}
|
||||||
|
mw.after_model(state, runtime)
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.runtime = runtime
|
||||||
|
request.messages = state["messages"]
|
||||||
|
request.override.return_value = "patched-request"
|
||||||
|
handler = AsyncMock(return_value="response")
|
||||||
|
|
||||||
|
result = asyncio.run(mw.awrap_model_call(request, handler))
|
||||||
|
assert result == "response"
|
||||||
|
injected_messages = request.override.call_args.kwargs["messages"]
|
||||||
|
assert injected_messages[-1].name == "todo_completion_reminder"
|
||||||
|
handler.assert_awaited_once_with("patched-request")
|
||||||
|
|||||||
@@ -26,6 +26,13 @@ export type MessageGroup =
|
|||||||
| AssistantClarificationGroup
|
| AssistantClarificationGroup
|
||||||
| AssistantSubagentGroup;
|
| AssistantSubagentGroup;
|
||||||
|
|
||||||
|
const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([
|
||||||
|
"summary",
|
||||||
|
"loop_warning",
|
||||||
|
"todo_reminder",
|
||||||
|
"todo_completion_reminder",
|
||||||
|
]);
|
||||||
|
|
||||||
export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||||
if (messages.length === 0) {
|
if (messages.length === 0) {
|
||||||
return [];
|
return [];
|
||||||
@@ -53,10 +60,6 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message.name === "todo_reminder") {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (message.type === "human") {
|
if (message.type === "human") {
|
||||||
groups.push({ id: message.id, type: "human", messages: [message] });
|
groups.push({ id: message.id, type: "human", messages: [message] });
|
||||||
continue;
|
continue;
|
||||||
@@ -368,8 +371,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
|
|||||||
export function isHiddenFromUIMessage(message: Message) {
|
export function isHiddenFromUIMessage(message: Message) {
|
||||||
return (
|
return (
|
||||||
message.additional_kwargs?.hide_from_ui === true ||
|
message.additional_kwargs?.hide_from_ui === true ||
|
||||||
message.name === "summary" ||
|
(typeof message.name === "string" &&
|
||||||
message.name === "loop_warning"
|
HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -63,3 +63,37 @@ test("aggregates token usage messages once per assistant turn", () => {
|
|||||||
),
|
),
|
||||||
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
|
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("hides internal todo reminder messages from message groups", () => {
|
||||||
|
const messages = [
|
||||||
|
{
|
||||||
|
id: "human-1",
|
||||||
|
type: "human",
|
||||||
|
content: "Audit the middleware",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "todo-reminder-1",
|
||||||
|
type: "human",
|
||||||
|
name: "todo_completion_reminder",
|
||||||
|
content: "<system_reminder>finish todos</system_reminder>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "todo-reminder-2",
|
||||||
|
type: "human",
|
||||||
|
name: "todo_reminder",
|
||||||
|
content: "<system_reminder>remember todos</system_reminder>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "ai-1",
|
||||||
|
type: "ai",
|
||||||
|
content: "Done",
|
||||||
|
},
|
||||||
|
] as Message[];
|
||||||
|
|
||||||
|
const groups = getMessageGroups(messages);
|
||||||
|
|
||||||
|
expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]);
|
||||||
|
expect(
|
||||||
|
groups.flatMap((group) => group.messages).map((message) => message.id),
|
||||||
|
).toEqual(["human-1", "ai-1"]);
|
||||||
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user