From 0c37509b3854f18de64df3d5b0b2b43bc09d9b6f Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 15 May 2026 16:12:37 +0200 Subject: [PATCH] fix(middleware): Prevent todo completion reminder IMMessage leak (#2907) * fix(middleware): Prevent todo completion reminder IMMessage leak (#2892) * make format * fix(middleware): Clear stale todo reminder counts (#2892) * add size guard for _completion_reminder_counts and add a integration test --- .../agents/middlewares/todo_middleware.py | 222 +++++++++- backend/tests/test_todo_middleware.py | 384 +++++++++++++++++- frontend/src/core/messages/utils.ts | 15 +- .../tests/unit/core/messages/utils.test.ts | 34 ++ 4 files changed, 608 insertions(+), 47 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py index b8cd10884..9215aefc5 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py @@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list. Additionally, this middleware prevents the agent from exiting the loop while there are still incomplete todo items. When the model produces a final response -(no tool calls) but todos are not yet complete, the middleware injects a reminder -and jumps back to the model node to force continued engagement. +(no tool calls) but todos are not yet complete, the middleware queues a reminder +for the next model request and jumps back to the model node to force continued +engagement. The completion reminder is injected via ``wrap_model_call`` instead +of being persisted into graph state as a normal user-visible message. """ from __future__ import annotations +import threading +from collections.abc import Awaitable, Callable from typing import Any, override from langchain.agents.middleware import TodoListMiddleware from langchain.agents.middleware.todo import PlanningState, Todo -from langchain.agents.middleware.types import hook_config +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config from langchain_core.messages import AIMessage, HumanMessage from langgraph.runtime import Runtime @@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str: return "\n".join(lines) +def _format_completion_reminder(todos: list[Todo]) -> str: + """Format a completion reminder for incomplete todo items.""" + incomplete = [t for t in todos if t.get("status") != "completed"] + incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete) + return ( + "\n" + "You have incomplete todo items that must be finished before giving your final response:\n\n" + f"{incomplete_text}\n\n" + "Please continue working on these tasks. Call `write_todos` to mark items as completed " + "as you finish them, and only respond when all items are done.\n" + "" + ) + + +_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"} + + +def _has_tool_call_intent_or_error(message: AIMessage) -> bool: + """Return True when an AIMessage is not a clean final answer. + + Todo completion reminders should only fire when the model has produced a + plain final response. Provider/tool parsing details have moved across + LangChain versions and integrations, so keep all tool-intent/error signals + behind this helper instead of checking one concrete field at the call site. + """ + if message.tool_calls: + return True + + if getattr(message, "invalid_tool_calls", None): + return True + + # Backward/provider compatibility: some integrations preserve raw or legacy + # tool-call intent in additional_kwargs even when structured tool_calls is + # empty. If this helper changes, update the matching sentinel test + # `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`; + # if that test fails after a LangChain upgrade, review this helper so new + # tool-call/error fields are not silently treated as clean final answers. + additional_kwargs = getattr(message, "additional_kwargs", {}) or {} + if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"): + return True + + response_metadata = getattr(message, "response_metadata", {}) or {} + return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS + + class TodoMiddleware(TodoListMiddleware): """Extends TodoListMiddleware with `write_todos` context-loss detection. @@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware): formatted = _format_todos(todos) reminder = HumanMessage( name="todo_reminder", + additional_kwargs={"hide_from_ui": True}, content=( "\n" "Your todo list from earlier is no longer visible in the current context window, " @@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware): # Maximum number of completion reminders before allowing the agent to exit. # This prevents infinite loops when the agent cannot make further progress. _MAX_COMPLETION_REMINDERS = 2 + # Hard cap for per-run reminder bookkeeping in long-lived middleware instances. + _MAX_COMPLETION_REMINDER_KEYS = 4096 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._lock = threading.Lock() + self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {} + self._completion_reminder_counts: dict[tuple[str, str], int] = {} + self._completion_reminder_touch_order: dict[tuple[str, str], int] = {} + self._completion_reminder_next_order = 0 + + @staticmethod + def _get_thread_id(runtime: Runtime) -> str: + context = getattr(runtime, "context", None) + thread_id = context.get("thread_id") if context else None + return str(thread_id) if thread_id else "default" + + @staticmethod + def _get_run_id(runtime: Runtime) -> str: + context = getattr(runtime, "context", None) + run_id = context.get("run_id") if context else None + return str(run_id) if run_id else "default" + + def _pending_key(self, runtime: Runtime) -> tuple[str, str]: + return self._get_thread_id(runtime), self._get_run_id(runtime) + + def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None: + self._completion_reminder_next_order += 1 + self._completion_reminder_touch_order[key] = self._completion_reminder_next_order + + def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]: + keys = set(self._pending_completion_reminders) + keys.update(self._completion_reminder_counts) + keys.update(self._completion_reminder_touch_order) + return keys + + def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None: + self._pending_completion_reminders.pop(key, None) + self._completion_reminder_counts.pop(key, None) + self._completion_reminder_touch_order.pop(key, None) + + def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None: + keys = self._completion_reminder_keys_locked() + overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS + if overflow <= 0: + return + + candidates = [key for key in keys if key != protected_key] + candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0)) + for key in candidates[:overflow]: + self._drop_completion_reminder_key_locked(key) + + def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None: + key = self._pending_key(runtime) + with self._lock: + self._pending_completion_reminders.setdefault(key, []).append(reminder) + self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1 + self._touch_completion_reminder_key_locked(key) + self._prune_completion_reminder_state_locked(protected_key=key) + + def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int: + key = self._pending_key(runtime) + with self._lock: + return self._completion_reminder_counts.get(key, 0) + + def _drain_completion_reminders(self, runtime: Runtime) -> list[str]: + key = self._pending_key(runtime) + with self._lock: + reminders = self._pending_completion_reminders.pop(key, []) + if reminders or key in self._completion_reminder_counts: + self._touch_completion_reminder_key_locked(key) + return reminders + + def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None: + thread_id, current_run_id = self._pending_key(runtime) + with self._lock: + for key in self._completion_reminder_keys_locked(): + if key[0] == thread_id and key[1] != current_run_id: + self._drop_completion_reminder_key_locked(key) + + def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None: + key = self._pending_key(runtime) + with self._lock: + self._drop_completion_reminder_key_locked(key) + + @override + def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_other_run_completion_reminders(runtime) + return None + + @override + async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_other_run_completion_reminders(runtime) + return None @hook_config(can_jump_to=["model"]) @override @@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware): if base_result is not None: return base_result - # 2. Only intervene when the agent wants to exit (no tool calls). + # 2. Only intervene when the agent wants to exit cleanly. Tool-call + # intent or tool-call parse errors should be handled by the tool path + # instead of being masked by todo reminders. messages = state.get("messages") or [] last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None) - if not last_ai or last_ai.tool_calls: + if not last_ai or _has_tool_call_intent_or_error(last_ai): return None # 3. Allow exit when all todos are completed or there are no todos. @@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware): return None # 4. Enforce a reminder cap to prevent infinite re-engagement loops. - if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS: + if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS: return None - # 5. Inject a reminder and force the agent back to the model. - incomplete = [t for t in todos if t.get("status") != "completed"] - incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete) - reminder = HumanMessage( - name="todo_completion_reminder", - content=( - "\n" - "You have incomplete todo items that must be finished before giving your final response:\n\n" - f"{incomplete_text}\n\n" - "Please continue working on these tasks. Call `write_todos` to mark items as completed " - "as you finish them, and only respond when all items are done.\n" - "" - ), - ) - return {"jump_to": "model", "messages": [reminder]} + # 5. Queue a reminder for the next model request and jump back. We must + # not persist this control prompt as a normal HumanMessage, otherwise it + # can leak into user-visible message streams and saved transcripts. + self._queue_completion_reminder(runtime, _format_completion_reminder(todos)) + return {"jump_to": "model"} @override @hook_config(can_jump_to=["model"]) @@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware): ) -> dict[str, Any] | None: """Async version of after_model.""" return self.after_model(state, runtime) + + @staticmethod + def _format_pending_completion_reminders(reminders: list[str]) -> str: + return "\n\n".join(dict.fromkeys(reminders)) + + def _augment_request(self, request: ModelRequest) -> ModelRequest: + reminders = self._drain_completion_reminders(request.runtime) + if not reminders: + return request + new_messages = [ + *request.messages, + HumanMessage( + content=self._format_pending_completion_reminders(reminders), + name="todo_completion_reminder", + additional_kwargs={"hide_from_ui": True}, + ), + ] + return request.override(messages=new_messages) + + @override + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + return handler(self._augment_request(request)) + + @override + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + return await handler(self._augment_request(request)) + + @override + def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_current_run_completion_reminders(runtime) + return None + + @override + async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_current_run_completion_reminders(runtime) + return None diff --git a/backend/tests/test_todo_middleware.py b/backend/tests/test_todo_middleware.py index efeee9eb0..934e730f2 100644 --- a/backend/tests/test_todo_middleware.py +++ b/backend/tests/test_todo_middleware.py @@ -1,14 +1,19 @@ """Tests for TodoMiddleware context-loss detection.""" import asyncio -from unittest.mock import MagicMock +from typing import Any +from unittest.mock import AsyncMock, MagicMock +from langchain.agents import create_agent +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel from langchain_core.messages import AIMessage, HumanMessage +from pydantic import PrivateAttr from deerflow.agents.middlewares.todo_middleware import ( TodoMiddleware, _completion_reminder_count, _format_todos, + _has_tool_call_intent_or_error, _reminder_in_messages, _todos_in_messages, ) @@ -22,9 +27,35 @@ def _reminder_msg(): return HumanMessage(name="todo_reminder", content="reminder") +class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel): + _seen_messages: list[list[Any]] = PrivateAttr(default_factory=list) + + @property + def seen_messages(self) -> list[list[Any]]: + return self._seen_messages + + def bind_tools(self, tools, *, tool_choice=None, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self._seen_messages.append(list(messages)) + return super()._generate( + messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ) + + def _make_runtime(): runtime = MagicMock() - runtime.context = {"thread_id": "test-thread"} + runtime.context = {"thread_id": "test-thread", "run_id": "test-run"} + return runtime + + +def _make_runtime_for(thread_id: str, run_id: str): + runtime = _make_runtime() + runtime.context = {"thread_id": thread_id, "run_id": run_id} return runtime @@ -161,10 +192,62 @@ def _completion_reminder_msg(): return HumanMessage(name="todo_completion_reminder", content="finish your todos") +def _todo_completion_reminders(messages): + reminders = [] + for message in messages: + if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder": + reminders.append(message) + return reminders + + def _ai_no_tool_calls(): return AIMessage(content="I'm done!") +def _ai_with_invalid_tool_calls(): + return AIMessage( + content="", + tool_calls=[], + invalid_tool_calls=[ + { + "type": "invalid_tool_call", + "id": "write_file:36", + "name": "write_file", + "args": "{invalid", + "error": "Failed to parse tool arguments", + } + ], + ) + + +def _ai_with_raw_provider_tool_calls(): + return AIMessage( + content="", + tool_calls=[], + invalid_tool_calls=[], + additional_kwargs={ + "tool_calls": [ + { + "id": "raw-tool-call", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path":"report.md"}'}, + } + ] + }, + ) + + +def _ai_with_legacy_function_call(): + return AIMessage( + content="", + additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}}, + ) + + +def _ai_with_tool_finish_reason(): + return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"}) + + def _incomplete_todos(): return [ {"status": "completed", "content": "Step 1"}, @@ -194,6 +277,36 @@ class TestCompletionReminderCount: assert _completion_reminder_count(msgs) == 1 +class TestToolCallIntentOrError: + def test_false_for_plain_final_answer(self): + assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False + + def test_true_for_structured_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True + + def test_true_for_invalid_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True + + def test_true_for_raw_provider_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True + + def test_true_for_legacy_function_call(self): + assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True + + def test_true_for_tool_finish_reason(self): + assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True + + def test_langchain_ai_message_tool_fields_are_explicitly_handled(self): + # Sentinel for LangChain compatibility: if future AIMessage versions add + # new top-level tool/function-call fields, this test should fail. When + # it does, update `_has_tool_call_intent_or_error()` so the completion + # reminder guard explicitly decides whether each new field means "not a + # clean final answer"; the helper has a matching comment pointing back + # to this sentinel. + tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())} + assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"} + + class TestAfterModel: def test_returns_none_when_agent_still_using_tools(self): mw = TodoMiddleware() @@ -235,68 +348,299 @@ class TestAfterModel: } assert mw.after_model(state, _make_runtime()) is None - def test_injects_reminder_and_jumps_to_model_when_incomplete(self): + def test_queues_reminder_and_jumps_to_model_when_incomplete(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [HumanMessage(content="hi"), _ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) + result = mw.after_model(state, runtime) assert result is not None assert result["jump_to"] == "model" - assert len(result["messages"]) == 1 - reminder = result["messages"][0] + assert "messages" not in result + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + request.override.assert_called_once() + reminder = request.override.call_args.kwargs["messages"][-1] assert isinstance(reminder, HumanMessage) assert reminder.name == "todo_completion_reminder" + assert reminder.additional_kwargs["hide_from_ui"] is True assert "Step 2" in reminder.content assert "Step 3" in reminder.content + handler.assert_called_once_with("patched-request") def test_reminder_lists_only_incomplete_items(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [_ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) - content = result["messages"][0].content + result = mw.after_model(state, runtime) + assert result is not None + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + mw.wrap_model_call(request, MagicMock(return_value="response")) + content = request.override.call_args.kwargs["messages"][-1].content assert "Step 1" not in content # completed — should not appear assert "Step 2" in content assert "Step 3" in content def test_allows_exit_after_max_reminders(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [ - _completion_reminder_msg(), - _completion_reminder_msg(), _ai_no_tool_calls(), ], "todos": _incomplete_todos(), } + assert mw.after_model(state, runtime) is not None + assert mw.after_model(state, runtime) is not None + assert mw.after_model(state, runtime) is None + + def test_still_sends_reminder_before_cap(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [ + _ai_no_tool_calls(), + ], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, runtime) is not None + result = mw.after_model(state, runtime) + assert result is not None + assert result["jump_to"] == "model" + + def test_does_not_trigger_for_invalid_tool_calls(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_invalid_tool_calls()], + "todos": _incomplete_todos(), + } assert mw.after_model(state, _make_runtime()) is None - def test_still_sends_reminder_before_cap(self): + def test_does_not_trigger_for_raw_provider_tool_calls(self): mw = TodoMiddleware() state = { - "messages": [ - _completion_reminder_msg(), # 1 reminder so far - _ai_no_tool_calls(), - ], + "messages": [_ai_with_raw_provider_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) - assert result is not None - assert result["jump_to"] == "model" + assert mw.after_model(state, _make_runtime()) is None + + def test_does_not_trigger_for_legacy_function_call(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_legacy_function_call()], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, _make_runtime()) is None + + def test_does_not_trigger_for_tool_finish_reason(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_tool_finish_reason()], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, _make_runtime()) is None class TestAafterModel: def test_delegates_to_sync(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [_ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = asyncio.run(mw.aafter_model(state, _make_runtime())) + result = asyncio.run(mw.aafter_model(state, runtime)) assert result is not None assert result["jump_to"] == "model" - assert result["messages"][0].name == "todo_completion_reminder" + assert "messages" not in result + + +class TestWrapModelCall: + def test_no_pending_reminder_passthrough(self): + mw = TodoMiddleware() + request = MagicMock() + request.runtime = _make_runtime() + request.messages = [HumanMessage(content="hi")] + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + request.override.assert_not_called() + handler.assert_called_once_with(request) + + def test_pending_reminder_is_injected_once(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [_ai_no_tool_calls()], + "todos": _incomplete_todos(), + } + mw.after_model(state, runtime) + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + injected_messages = request.override.call_args.kwargs["messages"] + assert injected_messages[-1].name == "todo_completion_reminder" + + request.override.reset_mock() + handler.reset_mock() + handler.return_value = "second-response" + assert mw.wrap_model_call(request, handler) == "second-response" + request.override.assert_not_called() + handler.assert_called_once_with(request) + + +class TestTodoMiddlewareAgentGraphIntegration: + def test_completion_reminder_is_transient_in_real_agent_graph(self): + mw = TodoMiddleware() + model = _CapturingFakeMessagesListChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "write_todos", + "id": "todos-1", + "args": { + "todos": [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "pending"}, + ] + }, + } + ], + ), + AIMessage(content="premature final 1"), + AIMessage(content="premature final 2"), + AIMessage(content="premature final 3"), + ], + ) + graph = create_agent(model=model, tools=[], middleware=[mw]) + + result = graph.invoke( + {"messages": [("user", "finish all todos")]}, + context={"thread_id": "integration-thread", "run_id": "integration-run"}, + ) + + assert len(model.seen_messages) == 4 + reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages] + assert reminders_by_call[0] == [] + assert reminders_by_call[1] == [] + assert len(reminders_by_call[2]) == 1 + assert len(reminders_by_call[3]) == 1 + assert "Step 1" not in reminders_by_call[2][0].content + assert "Step 2" in reminders_by_call[2][0].content + + persisted_reminders = _todo_completion_reminders(result["messages"]) + assert persisted_reminders == [] + assert result["messages"][-1].content == "premature final 3" + assert result["todos"] == [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "pending"}, + ] + assert mw._pending_completion_reminders == {} + assert mw._completion_reminder_counts == {} + + +class TestRunScopedReminderCleanup: + def test_before_agent_clears_stale_count_without_pending_reminder(self): + mw = TodoMiddleware() + stale_runtime = _make_runtime() + stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"} + current_runtime = _make_runtime() + current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"} + other_thread_runtime = _make_runtime() + other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"} + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, stale_runtime) is not None + assert mw.after_model(state, other_thread_runtime) is not None + + # Simulate a model call that drained the pending message, followed by an + # abnormal run end where after_agent did not clear the reminder count. + assert mw._drain_completion_reminders(stale_runtime) + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1 + + mw.before_agent({}, current_runtime) + + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1 + + def test_size_guard_prunes_oldest_count_only_reminder_state(self): + mw = TodoMiddleware() + mw._MAX_COMPLETION_REMINDER_KEYS = 2 + first_runtime = _make_runtime_for("thread-a", "run-a") + second_runtime = _make_runtime_for("thread-b", "run-b") + third_runtime = _make_runtime_for("thread-c", "run-c") + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, first_runtime) is not None + + # Simulate the normal model request path: pending reminder is consumed, + # but the run count remains until after_agent() or stale cleanup. + assert mw._drain_completion_reminders(first_runtime) + assert mw._completion_reminder_count_for_runtime(first_runtime) == 1 + + assert mw.after_model(state, second_runtime) is not None + assert mw.after_model(state, third_runtime) is not None + + assert mw._completion_reminder_count_for_runtime(first_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(second_runtime) == 1 + assert mw._completion_reminder_count_for_runtime(third_runtime) == 1 + assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order + + def test_size_guard_prunes_pending_and_count_state_together(self): + mw = TodoMiddleware() + mw._MAX_COMPLETION_REMINDER_KEYS = 1 + stale_runtime = _make_runtime_for("thread-a", "run-a") + current_runtime = _make_runtime_for("thread-b", "run-b") + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, stale_runtime) is not None + assert mw.after_model(state, current_runtime) is not None + + assert mw._drain_completion_reminders(stale_runtime) == [] + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(current_runtime) == 1 + + +class TestAwrapModelCall: + def test_async_pending_reminder_is_injected(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [_ai_no_tool_calls()], + "todos": _incomplete_todos(), + } + mw.after_model(state, runtime) + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = AsyncMock(return_value="response") + + result = asyncio.run(mw.awrap_model_call(request, handler)) + assert result == "response" + injected_messages = request.override.call_args.kwargs["messages"] + assert injected_messages[-1].name == "todo_completion_reminder" + handler.assert_awaited_once_with("patched-request") diff --git a/frontend/src/core/messages/utils.ts b/frontend/src/core/messages/utils.ts index e20daa1b6..3f1fef9ad 100644 --- a/frontend/src/core/messages/utils.ts +++ b/frontend/src/core/messages/utils.ts @@ -26,6 +26,13 @@ export type MessageGroup = | AssistantClarificationGroup | AssistantSubagentGroup; +const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([ + "summary", + "loop_warning", + "todo_reminder", + "todo_completion_reminder", +]); + export function getMessageGroups(messages: Message[]): MessageGroup[] { if (messages.length === 0) { return []; @@ -53,10 +60,6 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] { continue; } - if (message.name === "todo_reminder") { - continue; - } - if (message.type === "human") { groups.push({ id: message.id, type: "human", messages: [message] }); continue; @@ -368,8 +371,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) { export function isHiddenFromUIMessage(message: Message) { return ( message.additional_kwargs?.hide_from_ui === true || - message.name === "summary" || - message.name === "loop_warning" + (typeof message.name === "string" && + HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name)) ); } diff --git a/frontend/tests/unit/core/messages/utils.test.ts b/frontend/tests/unit/core/messages/utils.test.ts index 24d014c7e..cbc245583 100644 --- a/frontend/tests/unit/core/messages/utils.test.ts +++ b/frontend/tests/unit/core/messages/utils.test.ts @@ -63,3 +63,37 @@ test("aggregates token usage messages once per assistant turn", () => { ), ).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]); }); + +test("hides internal todo reminder messages from message groups", () => { + const messages = [ + { + id: "human-1", + type: "human", + content: "Audit the middleware", + }, + { + id: "todo-reminder-1", + type: "human", + name: "todo_completion_reminder", + content: "finish todos", + }, + { + id: "todo-reminder-2", + type: "human", + name: "todo_reminder", + content: "remember todos", + }, + { + id: "ai-1", + type: "ai", + content: "Done", + }, + ] as Message[]; + + const groups = getMessageGroups(messages); + + expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]); + expect( + groups.flatMap((group) => group.messages).map((message) => message.id), + ).toEqual(["human-1", "ai-1"]); +});