fix(todo): reuse thread state schema (#3206)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
QY
2026-05-26 23:58:08 +08:00
committed by GitHub
parent da41701f87
commit 92905e9e3e
2 changed files with 50 additions and 9 deletions
@@ -20,11 +20,13 @@ from collections.abc import Awaitable, Callable
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.todo import Todo
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadState
def _todos_in_messages(messages: list[Any]) -> bool:
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
@@ -113,10 +115,12 @@ class TodoMiddleware(TodoListMiddleware):
and injects a reminder message so the model can continue tracking progress.
"""
state_schema = ThreadState
@override
def before_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window."""
@@ -154,7 +158,7 @@ class TodoMiddleware(TodoListMiddleware):
@override
async def abefore_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of before_model."""
@@ -249,12 +253,12 @@ class TodoMiddleware(TodoListMiddleware):
self._drop_completion_reminder_key_locked(key)
@override
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
def before_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@override
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
async def abefore_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@@ -262,7 +266,7 @@ class TodoMiddleware(TodoListMiddleware):
@override
def after_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Prevent premature agent exit when todo items are still incomplete.
@@ -308,7 +312,7 @@ class TodoMiddleware(TodoListMiddleware):
@hook_config(can_jump_to=["model"])
async def aafter_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of after_model."""
@@ -349,11 +353,11 @@ class TodoMiddleware(TodoListMiddleware):
return await handler(self._augment_request(request))
@override
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
def after_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
@override
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
async def aafter_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
+37
View File
@@ -17,6 +17,7 @@ from deerflow.agents.middlewares.todo_middleware import (
_reminder_in_messages,
_todos_in_messages,
)
from deerflow.agents.thread_state import ThreadState
def _ai_with_write_todos():
@@ -510,6 +511,42 @@ class TestWrapModelCall:
class TestTodoMiddlewareAgentGraphIntegration:
def test_reuses_thread_state_todos_schema_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "write_todos",
"id": "todos-1",
"args": {
"todos": [
{"content": "Step 1", "status": "pending"},
]
},
}
],
),
AIMessage(content="final"),
],
)
graph = create_agent(
model=model,
tools=[],
middleware=[mw],
state_schema=ThreadState,
)
result = graph.invoke(
{"messages": [("user", "create a todo")]},
context={"thread_id": "schema-thread", "run_id": "schema-run"},
)
assert result["todos"] == [{"content": "Step 1", "status": "pending"}]
def test_completion_reminder_is_transient_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(