diff --git a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py index 9215aefc5..3e3ebdd81 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py @@ -20,11 +20,13 @@ from collections.abc import Awaitable, Callable from typing import Any, override from langchain.agents.middleware import TodoListMiddleware -from langchain.agents.middleware.todo import PlanningState, Todo +from langchain.agents.middleware.todo import Todo from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config from langchain_core.messages import AIMessage, HumanMessage from langgraph.runtime import Runtime +from deerflow.agents.thread_state import ThreadState + def _todos_in_messages(messages: list[Any]) -> bool: """Return True if any AIMessage in *messages* contains a write_todos tool call.""" @@ -113,10 +115,12 @@ class TodoMiddleware(TodoListMiddleware): and injects a reminder message so the model can continue tracking progress. """ + state_schema = ThreadState + @override def before_model( self, - state: PlanningState, + state: ThreadState, runtime: Runtime, ) -> dict[str, Any] | None: """Inject a todo-list reminder when write_todos has left the context window.""" @@ -154,7 +158,7 @@ class TodoMiddleware(TodoListMiddleware): @override async def abefore_model( self, - state: PlanningState, + state: ThreadState, runtime: Runtime, ) -> dict[str, Any] | None: """Async version of before_model.""" @@ -249,12 +253,12 @@ class TodoMiddleware(TodoListMiddleware): self._drop_completion_reminder_key_locked(key) @override - def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + def before_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None: self._clear_other_run_completion_reminders(runtime) return None @override - async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + async def abefore_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None: self._clear_other_run_completion_reminders(runtime) return None @@ -262,7 +266,7 @@ class TodoMiddleware(TodoListMiddleware): @override def after_model( self, - state: PlanningState, + state: ThreadState, runtime: Runtime, ) -> dict[str, Any] | None: """Prevent premature agent exit when todo items are still incomplete. @@ -308,7 +312,7 @@ class TodoMiddleware(TodoListMiddleware): @hook_config(can_jump_to=["model"]) async def aafter_model( self, - state: PlanningState, + state: ThreadState, runtime: Runtime, ) -> dict[str, Any] | None: """Async version of after_model.""" @@ -349,11 +353,11 @@ class TodoMiddleware(TodoListMiddleware): return await handler(self._augment_request(request)) @override - def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + def after_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None: self._clear_current_run_completion_reminders(runtime) return None @override - async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + async def aafter_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None: self._clear_current_run_completion_reminders(runtime) return None diff --git a/backend/tests/test_todo_middleware.py b/backend/tests/test_todo_middleware.py index 934e730f2..1848b906e 100644 --- a/backend/tests/test_todo_middleware.py +++ b/backend/tests/test_todo_middleware.py @@ -17,6 +17,7 @@ from deerflow.agents.middlewares.todo_middleware import ( _reminder_in_messages, _todos_in_messages, ) +from deerflow.agents.thread_state import ThreadState def _ai_with_write_todos(): @@ -510,6 +511,42 @@ class TestWrapModelCall: class TestTodoMiddlewareAgentGraphIntegration: + def test_reuses_thread_state_todos_schema_in_real_agent_graph(self): + mw = TodoMiddleware() + model = _CapturingFakeMessagesListChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "write_todos", + "id": "todos-1", + "args": { + "todos": [ + {"content": "Step 1", "status": "pending"}, + ] + }, + } + ], + ), + AIMessage(content="final"), + ], + ) + + graph = create_agent( + model=model, + tools=[], + middleware=[mw], + state_schema=ThreadState, + ) + + result = graph.invoke( + {"messages": [("user", "create a todo")]}, + context={"thread_id": "schema-thread", "run_id": "schema-run"}, + ) + + assert result["todos"] == [{"content": "Step 1", "status": "pending"}] + def test_completion_reminder_is_transient_in_real_agent_graph(self): mw = TodoMiddleware() model = _CapturingFakeMessagesListChatModel(