fix(todo): reuse thread state schema (#3206)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
QY
2026-05-26 23:58:08 +08:00
committed by GitHub
parent da41701f87
commit 92905e9e3e
2 changed files with 50 additions and 9 deletions
@@ -20,11 +20,13 @@ from collections.abc import Awaitable, Callable
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.todo import Todo
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadState
def _todos_in_messages(messages: list[Any]) -> bool:
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
@@ -113,10 +115,12 @@ class TodoMiddleware(TodoListMiddleware):
and injects a reminder message so the model can continue tracking progress.
"""
state_schema = ThreadState
@override
def before_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window."""
@@ -154,7 +158,7 @@ class TodoMiddleware(TodoListMiddleware):
@override
async def abefore_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of before_model."""
@@ -249,12 +253,12 @@ class TodoMiddleware(TodoListMiddleware):
self._drop_completion_reminder_key_locked(key)
@override
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
def before_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@override
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
async def abefore_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@@ -262,7 +266,7 @@ class TodoMiddleware(TodoListMiddleware):
@override
def after_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Prevent premature agent exit when todo items are still incomplete.
@@ -308,7 +312,7 @@ class TodoMiddleware(TodoListMiddleware):
@hook_config(can_jump_to=["model"])
async def aafter_model(
self,
state: PlanningState,
state: ThreadState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of after_model."""
@@ -349,11 +353,11 @@ class TodoMiddleware(TodoListMiddleware):
return await handler(self._augment_request(request))
@override
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
def after_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
@override
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
async def aafter_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None