fix(middleware): Prevent todo completion reminder IMMessage leak (#2907)

* fix(middleware): Prevent todo completion reminder IMMessage leak (#2892)

* make format

* fix(middleware): Clear stale todo reminder counts (#2892)

* add size guard for _completion_reminder_counts and add a integration test
This commit is contained in:
Nan Gao
2026-05-15 16:12:37 +02:00
committed by GitHub
parent 181d836541
commit 0c37509b38
4 changed files with 608 additions and 47 deletions
@@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list.
Additionally, this middleware prevents the agent from exiting the loop while Additionally, this middleware prevents the agent from exiting the loop while
there are still incomplete todo items. When the model produces a final response there are still incomplete todo items. When the model produces a final response
(no tool calls) but todos are not yet complete, the middleware injects a reminder (no tool calls) but todos are not yet complete, the middleware queues a reminder
and jumps back to the model node to force continued engagement. for the next model request and jumps back to the model node to force continued
engagement. The completion reminder is injected via ``wrap_model_call`` instead
of being persisted into graph state as a normal user-visible message.
""" """
from __future__ import annotations from __future__ import annotations
import threading
from collections.abc import Awaitable, Callable
from typing import Any, override from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.types import hook_config from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
@@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str:
return "\n".join(lines) return "\n".join(lines)
def _format_completion_reminder(todos: list[Todo]) -> str:
"""Format a completion reminder for incomplete todo items."""
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
return (
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
)
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
"""Return True when an AIMessage is not a clean final answer.
Todo completion reminders should only fire when the model has produced a
plain final response. Provider/tool parsing details have moved across
LangChain versions and integrations, so keep all tool-intent/error signals
behind this helper instead of checking one concrete field at the call site.
"""
if message.tool_calls:
return True
if getattr(message, "invalid_tool_calls", None):
return True
# Backward/provider compatibility: some integrations preserve raw or legacy
# tool-call intent in additional_kwargs even when structured tool_calls is
# empty. If this helper changes, update the matching sentinel test
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
# if that test fails after a LangChain upgrade, review this helper so new
# tool-call/error fields are not silently treated as clean final answers.
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
return True
response_metadata = getattr(message, "response_metadata", {}) or {}
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
class TodoMiddleware(TodoListMiddleware): class TodoMiddleware(TodoListMiddleware):
"""Extends TodoListMiddleware with `write_todos` context-loss detection. """Extends TodoListMiddleware with `write_todos` context-loss detection.
@@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware):
formatted = _format_todos(todos) formatted = _format_todos(todos)
reminder = HumanMessage( reminder = HumanMessage(
name="todo_reminder", name="todo_reminder",
additional_kwargs={"hide_from_ui": True},
content=( content=(
"<system_reminder>\n" "<system_reminder>\n"
"Your todo list from earlier is no longer visible in the current context window, " "Your todo list from earlier is no longer visible in the current context window, "
@@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware):
# Maximum number of completion reminders before allowing the agent to exit. # Maximum number of completion reminders before allowing the agent to exit.
# This prevents infinite loops when the agent cannot make further progress. # This prevents infinite loops when the agent cannot make further progress.
_MAX_COMPLETION_REMINDERS = 2 _MAX_COMPLETION_REMINDERS = 2
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
_MAX_COMPLETION_REMINDER_KEYS = 4096
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._lock = threading.Lock()
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
self._completion_reminder_next_order = 0
@staticmethod
def _get_thread_id(runtime: Runtime) -> str:
context = getattr(runtime, "context", None)
thread_id = context.get("thread_id") if context else None
return str(thread_id) if thread_id else "default"
@staticmethod
def _get_run_id(runtime: Runtime) -> str:
context = getattr(runtime, "context", None)
run_id = context.get("run_id") if context else None
return str(run_id) if run_id else "default"
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
return self._get_thread_id(runtime), self._get_run_id(runtime)
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
self._completion_reminder_next_order += 1
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
keys = set(self._pending_completion_reminders)
keys.update(self._completion_reminder_counts)
keys.update(self._completion_reminder_touch_order)
return keys
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
self._pending_completion_reminders.pop(key, None)
self._completion_reminder_counts.pop(key, None)
self._completion_reminder_touch_order.pop(key, None)
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
keys = self._completion_reminder_keys_locked()
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
if overflow <= 0:
return
candidates = [key for key in keys if key != protected_key]
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
for key in candidates[:overflow]:
self._drop_completion_reminder_key_locked(key)
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
key = self._pending_key(runtime)
with self._lock:
self._pending_completion_reminders.setdefault(key, []).append(reminder)
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
self._touch_completion_reminder_key_locked(key)
self._prune_completion_reminder_state_locked(protected_key=key)
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
key = self._pending_key(runtime)
with self._lock:
return self._completion_reminder_counts.get(key, 0)
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
key = self._pending_key(runtime)
with self._lock:
reminders = self._pending_completion_reminders.pop(key, [])
if reminders or key in self._completion_reminder_counts:
self._touch_completion_reminder_key_locked(key)
return reminders
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
thread_id, current_run_id = self._pending_key(runtime)
with self._lock:
for key in self._completion_reminder_keys_locked():
if key[0] == thread_id and key[1] != current_run_id:
self._drop_completion_reminder_key_locked(key)
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
key = self._pending_key(runtime)
with self._lock:
self._drop_completion_reminder_key_locked(key)
@override
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@override
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@hook_config(can_jump_to=["model"]) @hook_config(can_jump_to=["model"])
@override @override
@@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware):
if base_result is not None: if base_result is not None:
return base_result return base_result
# 2. Only intervene when the agent wants to exit (no tool calls). # 2. Only intervene when the agent wants to exit cleanly. Tool-call
# intent or tool-call parse errors should be handled by the tool path
# instead of being masked by todo reminders.
messages = state.get("messages") or [] messages = state.get("messages") or []
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None) last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
if not last_ai or last_ai.tool_calls: if not last_ai or _has_tool_call_intent_or_error(last_ai):
return None return None
# 3. Allow exit when all todos are completed or there are no todos. # 3. Allow exit when all todos are completed or there are no todos.
@@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware):
return None return None
# 4. Enforce a reminder cap to prevent infinite re-engagement loops. # 4. Enforce a reminder cap to prevent infinite re-engagement loops.
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS: if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
return None return None
# 5. Inject a reminder and force the agent back to the model. # 5. Queue a reminder for the next model request and jump back. We must
incomplete = [t for t in todos if t.get("status") != "completed"] # not persist this control prompt as a normal HumanMessage, otherwise it
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete) # can leak into user-visible message streams and saved transcripts.
reminder = HumanMessage( self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
name="todo_completion_reminder", return {"jump_to": "model"}
content=(
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
),
)
return {"jump_to": "model", "messages": [reminder]}
@override @override
@hook_config(can_jump_to=["model"]) @hook_config(can_jump_to=["model"])
@@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware):
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Async version of after_model.""" """Async version of after_model."""
return self.after_model(state, runtime) return self.after_model(state, runtime)
@staticmethod
def _format_pending_completion_reminders(reminders: list[str]) -> str:
return "\n\n".join(dict.fromkeys(reminders))
def _augment_request(self, request: ModelRequest) -> ModelRequest:
reminders = self._drain_completion_reminders(request.runtime)
if not reminders:
return request
new_messages = [
*request.messages,
HumanMessage(
content=self._format_pending_completion_reminders(reminders),
name="todo_completion_reminder",
additional_kwargs={"hide_from_ui": True},
),
]
return request.override(messages=new_messages)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(self._augment_request(request))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._augment_request(request))
@override
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
@override
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
+364 -20
View File
@@ -1,14 +1,19 @@
"""Tests for TodoMiddleware context-loss detection.""" """Tests for TodoMiddleware context-loss detection."""
import asyncio import asyncio
from unittest.mock import MagicMock from typing import Any
from unittest.mock import AsyncMock, MagicMock
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from pydantic import PrivateAttr
from deerflow.agents.middlewares.todo_middleware import ( from deerflow.agents.middlewares.todo_middleware import (
TodoMiddleware, TodoMiddleware,
_completion_reminder_count, _completion_reminder_count,
_format_todos, _format_todos,
_has_tool_call_intent_or_error,
_reminder_in_messages, _reminder_in_messages,
_todos_in_messages, _todos_in_messages,
) )
@@ -22,9 +27,35 @@ def _reminder_msg():
return HumanMessage(name="todo_reminder", content="reminder") return HumanMessage(name="todo_reminder", content="reminder")
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
@property
def seen_messages(self) -> list[list[Any]]:
return self._seen_messages
def bind_tools(self, tools, *, tool_choice=None, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self._seen_messages.append(list(messages))
return super()._generate(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
def _make_runtime(): def _make_runtime():
runtime = MagicMock() runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"} runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
return runtime
def _make_runtime_for(thread_id: str, run_id: str):
runtime = _make_runtime()
runtime.context = {"thread_id": thread_id, "run_id": run_id}
return runtime return runtime
@@ -161,10 +192,62 @@ def _completion_reminder_msg():
return HumanMessage(name="todo_completion_reminder", content="finish your todos") return HumanMessage(name="todo_completion_reminder", content="finish your todos")
def _todo_completion_reminders(messages):
reminders = []
for message in messages:
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
reminders.append(message)
return reminders
def _ai_no_tool_calls(): def _ai_no_tool_calls():
return AIMessage(content="I'm done!") return AIMessage(content="I'm done!")
def _ai_with_invalid_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[
{
"type": "invalid_tool_call",
"id": "write_file:36",
"name": "write_file",
"args": "{invalid",
"error": "Failed to parse tool arguments",
}
],
)
def _ai_with_raw_provider_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[],
additional_kwargs={
"tool_calls": [
{
"id": "raw-tool-call",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
}
]
},
)
def _ai_with_legacy_function_call():
return AIMessage(
content="",
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
)
def _ai_with_tool_finish_reason():
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
def _incomplete_todos(): def _incomplete_todos():
return [ return [
{"status": "completed", "content": "Step 1"}, {"status": "completed", "content": "Step 1"},
@@ -194,6 +277,36 @@ class TestCompletionReminderCount:
assert _completion_reminder_count(msgs) == 1 assert _completion_reminder_count(msgs) == 1
class TestToolCallIntentOrError:
def test_false_for_plain_final_answer(self):
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
def test_true_for_structured_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
def test_true_for_invalid_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
def test_true_for_raw_provider_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
def test_true_for_legacy_function_call(self):
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
def test_true_for_tool_finish_reason(self):
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
# Sentinel for LangChain compatibility: if future AIMessage versions add
# new top-level tool/function-call fields, this test should fail. When
# it does, update `_has_tool_call_intent_or_error()` so the completion
# reminder guard explicitly decides whether each new field means "not a
# clean final answer"; the helper has a matching comment pointing back
# to this sentinel.
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
class TestAfterModel: class TestAfterModel:
def test_returns_none_when_agent_still_using_tools(self): def test_returns_none_when_agent_still_using_tools(self):
mw = TodoMiddleware() mw = TodoMiddleware()
@@ -235,68 +348,299 @@ class TestAfterModel:
} }
assert mw.after_model(state, _make_runtime()) is None assert mw.after_model(state, _make_runtime()) is None
def test_injects_reminder_and_jumps_to_model_when_incomplete(self): def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
mw = TodoMiddleware() mw = TodoMiddleware()
runtime = _make_runtime()
state = { state = {
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()], "messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
"todos": _incomplete_todos(), "todos": _incomplete_todos(),
} }
result = mw.after_model(state, _make_runtime()) result = mw.after_model(state, runtime)
assert result is not None assert result is not None
assert result["jump_to"] == "model" assert result["jump_to"] == "model"
assert len(result["messages"]) == 1 assert "messages" not in result
reminder = result["messages"][0]
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_called_once()
reminder = request.override.call_args.kwargs["messages"][-1]
assert isinstance(reminder, HumanMessage) assert isinstance(reminder, HumanMessage)
assert reminder.name == "todo_completion_reminder" assert reminder.name == "todo_completion_reminder"
assert reminder.additional_kwargs["hide_from_ui"] is True
assert "Step 2" in reminder.content assert "Step 2" in reminder.content
assert "Step 3" in reminder.content assert "Step 3" in reminder.content
handler.assert_called_once_with("patched-request")
def test_reminder_lists_only_incomplete_items(self): def test_reminder_lists_only_incomplete_items(self):
mw = TodoMiddleware() mw = TodoMiddleware()
runtime = _make_runtime()
state = { state = {
"messages": [_ai_no_tool_calls()], "messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(), "todos": _incomplete_todos(),
} }
result = mw.after_model(state, _make_runtime()) result = mw.after_model(state, runtime)
content = result["messages"][0].content assert result is not None
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
mw.wrap_model_call(request, MagicMock(return_value="response"))
content = request.override.call_args.kwargs["messages"][-1].content
assert "Step 1" not in content # completed — should not appear assert "Step 1" not in content # completed — should not appear
assert "Step 2" in content assert "Step 2" in content
assert "Step 3" in content assert "Step 3" in content
def test_allows_exit_after_max_reminders(self): def test_allows_exit_after_max_reminders(self):
mw = TodoMiddleware() mw = TodoMiddleware()
runtime = _make_runtime()
state = { state = {
"messages": [ "messages": [
_completion_reminder_msg(),
_completion_reminder_msg(),
_ai_no_tool_calls(), _ai_no_tool_calls(),
], ],
"todos": _incomplete_todos(), "todos": _incomplete_todos(),
} }
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is None
def test_still_sends_reminder_before_cap(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, runtime) is not None
result = mw.after_model(state, runtime)
assert result is not None
assert result["jump_to"] == "model"
def test_does_not_trigger_for_invalid_tool_calls(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_invalid_tool_calls()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None assert mw.after_model(state, _make_runtime()) is None
def test_still_sends_reminder_before_cap(self): def test_does_not_trigger_for_raw_provider_tool_calls(self):
mw = TodoMiddleware() mw = TodoMiddleware()
state = { state = {
"messages": [ "messages": [_ai_with_raw_provider_tool_calls()],
_completion_reminder_msg(), # 1 reminder so far
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(), "todos": _incomplete_todos(),
} }
result = mw.after_model(state, _make_runtime()) assert mw.after_model(state, _make_runtime()) is None
assert result is not None
assert result["jump_to"] == "model" def test_does_not_trigger_for_legacy_function_call(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_legacy_function_call()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_tool_finish_reason(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_tool_finish_reason()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
class TestAafterModel: class TestAafterModel:
def test_delegates_to_sync(self): def test_delegates_to_sync(self):
mw = TodoMiddleware() mw = TodoMiddleware()
runtime = _make_runtime()
state = { state = {
"messages": [_ai_no_tool_calls()], "messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(), "todos": _incomplete_todos(),
} }
result = asyncio.run(mw.aafter_model(state, _make_runtime())) result = asyncio.run(mw.aafter_model(state, runtime))
assert result is not None assert result is not None
assert result["jump_to"] == "model" assert result["jump_to"] == "model"
assert result["messages"][0].name == "todo_completion_reminder" assert "messages" not in result
class TestWrapModelCall:
def test_no_pending_reminder_passthrough(self):
mw = TodoMiddleware()
request = MagicMock()
request.runtime = _make_runtime()
request.messages = [HumanMessage(content="hi")]
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
def test_pending_reminder_is_injected_once(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
request.override.reset_mock()
handler.reset_mock()
handler.return_value = "second-response"
assert mw.wrap_model_call(request, handler) == "second-response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
class TestTodoMiddlewareAgentGraphIntegration:
def test_completion_reminder_is_transient_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "write_todos",
"id": "todos-1",
"args": {
"todos": [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
},
}
],
),
AIMessage(content="premature final 1"),
AIMessage(content="premature final 2"),
AIMessage(content="premature final 3"),
],
)
graph = create_agent(model=model, tools=[], middleware=[mw])
result = graph.invoke(
{"messages": [("user", "finish all todos")]},
context={"thread_id": "integration-thread", "run_id": "integration-run"},
)
assert len(model.seen_messages) == 4
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
assert reminders_by_call[0] == []
assert reminders_by_call[1] == []
assert len(reminders_by_call[2]) == 1
assert len(reminders_by_call[3]) == 1
assert "Step 1" not in reminders_by_call[2][0].content
assert "Step 2" in reminders_by_call[2][0].content
persisted_reminders = _todo_completion_reminders(result["messages"])
assert persisted_reminders == []
assert result["messages"][-1].content == "premature final 3"
assert result["todos"] == [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
assert mw._pending_completion_reminders == {}
assert mw._completion_reminder_counts == {}
class TestRunScopedReminderCleanup:
def test_before_agent_clears_stale_count_without_pending_reminder(self):
mw = TodoMiddleware()
stale_runtime = _make_runtime()
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
current_runtime = _make_runtime()
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
other_thread_runtime = _make_runtime()
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, other_thread_runtime) is not None
# Simulate a model call that drained the pending message, followed by an
# abnormal run end where after_agent did not clear the reminder count.
assert mw._drain_completion_reminders(stale_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
mw.before_agent({}, current_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 2
first_runtime = _make_runtime_for("thread-a", "run-a")
second_runtime = _make_runtime_for("thread-b", "run-b")
third_runtime = _make_runtime_for("thread-c", "run-c")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, first_runtime) is not None
# Simulate the normal model request path: pending reminder is consumed,
# but the run count remains until after_agent() or stale cleanup.
assert mw._drain_completion_reminders(first_runtime)
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
assert mw.after_model(state, second_runtime) is not None
assert mw.after_model(state, third_runtime) is not None
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
def test_size_guard_prunes_pending_and_count_state_together(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 1
stale_runtime = _make_runtime_for("thread-a", "run-a")
current_runtime = _make_runtime_for("thread-b", "run-b")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, current_runtime) is not None
assert mw._drain_completion_reminders(stale_runtime) == []
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
class TestAwrapModelCall:
def test_async_pending_reminder_is_injected(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = AsyncMock(return_value="response")
result = asyncio.run(mw.awrap_model_call(request, handler))
assert result == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
handler.assert_awaited_once_with("patched-request")
+9 -6
View File
@@ -26,6 +26,13 @@ export type MessageGroup =
| AssistantClarificationGroup | AssistantClarificationGroup
| AssistantSubagentGroup; | AssistantSubagentGroup;
const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([
"summary",
"loop_warning",
"todo_reminder",
"todo_completion_reminder",
]);
export function getMessageGroups(messages: Message[]): MessageGroup[] { export function getMessageGroups(messages: Message[]): MessageGroup[] {
if (messages.length === 0) { if (messages.length === 0) {
return []; return [];
@@ -53,10 +60,6 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
continue; continue;
} }
if (message.name === "todo_reminder") {
continue;
}
if (message.type === "human") { if (message.type === "human") {
groups.push({ id: message.id, type: "human", messages: [message] }); groups.push({ id: message.id, type: "human", messages: [message] });
continue; continue;
@@ -368,8 +371,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
export function isHiddenFromUIMessage(message: Message) { export function isHiddenFromUIMessage(message: Message) {
return ( return (
message.additional_kwargs?.hide_from_ui === true || message.additional_kwargs?.hide_from_ui === true ||
message.name === "summary" || (typeof message.name === "string" &&
message.name === "loop_warning" HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name))
); );
} }
@@ -63,3 +63,37 @@ test("aggregates token usage messages once per assistant turn", () => {
), ),
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]); ).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
}); });
test("hides internal todo reminder messages from message groups", () => {
const messages = [
{
id: "human-1",
type: "human",
content: "Audit the middleware",
},
{
id: "todo-reminder-1",
type: "human",
name: "todo_completion_reminder",
content: "<system_reminder>finish todos</system_reminder>",
},
{
id: "todo-reminder-2",
type: "human",
name: "todo_reminder",
content: "<system_reminder>remember todos</system_reminder>",
},
{
id: "ai-1",
type: "ai",
content: "Done",
},
] as Message[];
const groups = getMessageGroups(messages);
expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]);
expect(
groups.flatMap((group) => group.messages).map((message) => message.id),
).toEqual(["human-1", "ai-1"]);
});