mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-27 02:16:01 +00:00
fix(todo): reuse thread state schema (#3206)
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -20,11 +20,13 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
|
||||
def _todos_in_messages(messages: list[Any]) -> bool:
|
||||
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
|
||||
@@ -113,10 +115,12 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
and injects a reminder message so the model can continue tracking progress.
|
||||
"""
|
||||
|
||||
state_schema = ThreadState
|
||||
|
||||
@override
|
||||
def before_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
state: ThreadState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||||
@@ -154,7 +158,7 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
@override
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
state: ThreadState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
@@ -249,12 +253,12 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
def before_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
async def abefore_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@@ -262,7 +266,7 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
@override
|
||||
def after_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
state: ThreadState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Prevent premature agent exit when todo items are still incomplete.
|
||||
@@ -308,7 +312,7 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
@hook_config(can_jump_to=["model"])
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
state: ThreadState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of after_model."""
|
||||
@@ -349,11 +353,11 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
return await handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
def after_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
async def aafter_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user