fix(todo): reuse thread state schema (#3206)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
QY
2026-05-26 23:58:08 +08:00
committed by GitHub
parent da41701f87
commit 92905e9e3e
2 changed files with 50 additions and 9 deletions
@@ -20,11 +20,13 @@ from collections.abc import Awaitable, Callable
from typing import Any, override from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo from langchain.agents.middleware.todo import Todo
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadState
def _todos_in_messages(messages: list[Any]) -> bool: def _todos_in_messages(messages: list[Any]) -> bool:
"""Return True if any AIMessage in *messages* contains a write_todos tool call.""" """Return True if any AIMessage in *messages* contains a write_todos tool call."""
@@ -113,10 +115,12 @@ class TodoMiddleware(TodoListMiddleware):
and injects a reminder message so the model can continue tracking progress. and injects a reminder message so the model can continue tracking progress.
""" """
state_schema = ThreadState
@override @override
def before_model( def before_model(
self, self,
state: PlanningState, state: ThreadState,
runtime: Runtime, runtime: Runtime,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window.""" """Inject a todo-list reminder when write_todos has left the context window."""
@@ -154,7 +158,7 @@ class TodoMiddleware(TodoListMiddleware):
@override @override
async def abefore_model( async def abefore_model(
self, self,
state: PlanningState, state: ThreadState,
runtime: Runtime, runtime: Runtime,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Async version of before_model.""" """Async version of before_model."""
@@ -249,12 +253,12 @@ class TodoMiddleware(TodoListMiddleware):
self._drop_completion_reminder_key_locked(key) self._drop_completion_reminder_key_locked(key)
@override @override
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: def before_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime) self._clear_other_run_completion_reminders(runtime)
return None return None
@override @override
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: async def abefore_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime) self._clear_other_run_completion_reminders(runtime)
return None return None
@@ -262,7 +266,7 @@ class TodoMiddleware(TodoListMiddleware):
@override @override
def after_model( def after_model(
self, self,
state: PlanningState, state: ThreadState,
runtime: Runtime, runtime: Runtime,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Prevent premature agent exit when todo items are still incomplete. """Prevent premature agent exit when todo items are still incomplete.
@@ -308,7 +312,7 @@ class TodoMiddleware(TodoListMiddleware):
@hook_config(can_jump_to=["model"]) @hook_config(can_jump_to=["model"])
async def aafter_model( async def aafter_model(
self, self,
state: PlanningState, state: ThreadState,
runtime: Runtime, runtime: Runtime,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Async version of after_model.""" """Async version of after_model."""
@@ -349,11 +353,11 @@ class TodoMiddleware(TodoListMiddleware):
return await handler(self._augment_request(request)) return await handler(self._augment_request(request))
@override @override
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: def after_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime) self._clear_current_run_completion_reminders(runtime)
return None return None
@override @override
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: async def aafter_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime) self._clear_current_run_completion_reminders(runtime)
return None return None
+37
View File
@@ -17,6 +17,7 @@ from deerflow.agents.middlewares.todo_middleware import (
_reminder_in_messages, _reminder_in_messages,
_todos_in_messages, _todos_in_messages,
) )
from deerflow.agents.thread_state import ThreadState
def _ai_with_write_todos(): def _ai_with_write_todos():
@@ -510,6 +511,42 @@ class TestWrapModelCall:
class TestTodoMiddlewareAgentGraphIntegration: class TestTodoMiddlewareAgentGraphIntegration:
def test_reuses_thread_state_todos_schema_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "write_todos",
"id": "todos-1",
"args": {
"todos": [
{"content": "Step 1", "status": "pending"},
]
},
}
],
),
AIMessage(content="final"),
],
)
graph = create_agent(
model=model,
tools=[],
middleware=[mw],
state_schema=ThreadState,
)
result = graph.invoke(
{"messages": [("user", "create a todo")]},
context={"thread_id": "schema-thread", "run_id": "schema-run"},
)
assert result["todos"] == [{"content": "Step 1", "status": "pending"}]
def test_completion_reminder_is_transient_in_real_agent_graph(self): def test_completion_reminder_is_transient_in_real_agent_graph(self):
mw = TodoMiddleware() mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel( model = _CapturingFakeMessagesListChatModel(