mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-26 18:06:00 +00:00
fix(todo): reuse thread state schema (#3206)
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -20,11 +20,13 @@ from collections.abc import Awaitable, Callable
|
|||||||
from typing import Any, override
|
from typing import Any, override
|
||||||
|
|
||||||
from langchain.agents.middleware import TodoListMiddleware
|
from langchain.agents.middleware import TodoListMiddleware
|
||||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
from langchain.agents.middleware.todo import Todo
|
||||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.agents.thread_state import ThreadState
|
||||||
|
|
||||||
|
|
||||||
def _todos_in_messages(messages: list[Any]) -> bool:
|
def _todos_in_messages(messages: list[Any]) -> bool:
|
||||||
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
|
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
|
||||||
@@ -113,10 +115,12 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
and injects a reminder message so the model can continue tracking progress.
|
and injects a reminder message so the model can continue tracking progress.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
state_schema = ThreadState
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def before_model(
|
def before_model(
|
||||||
self,
|
self,
|
||||||
state: PlanningState,
|
state: ThreadState,
|
||||||
runtime: Runtime,
|
runtime: Runtime,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Inject a todo-list reminder when write_todos has left the context window."""
|
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||||||
@@ -154,7 +158,7 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
@override
|
@override
|
||||||
async def abefore_model(
|
async def abefore_model(
|
||||||
self,
|
self,
|
||||||
state: PlanningState,
|
state: ThreadState,
|
||||||
runtime: Runtime,
|
runtime: Runtime,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Async version of before_model."""
|
"""Async version of before_model."""
|
||||||
@@ -249,12 +253,12 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
self._drop_completion_reminder_key_locked(key)
|
self._drop_completion_reminder_key_locked(key)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
def before_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
self._clear_other_run_completion_reminders(runtime)
|
self._clear_other_run_completion_reminders(runtime)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
async def abefore_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
self._clear_other_run_completion_reminders(runtime)
|
self._clear_other_run_completion_reminders(runtime)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -262,7 +266,7 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
@override
|
@override
|
||||||
def after_model(
|
def after_model(
|
||||||
self,
|
self,
|
||||||
state: PlanningState,
|
state: ThreadState,
|
||||||
runtime: Runtime,
|
runtime: Runtime,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Prevent premature agent exit when todo items are still incomplete.
|
"""Prevent premature agent exit when todo items are still incomplete.
|
||||||
@@ -308,7 +312,7 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
@hook_config(can_jump_to=["model"])
|
@hook_config(can_jump_to=["model"])
|
||||||
async def aafter_model(
|
async def aafter_model(
|
||||||
self,
|
self,
|
||||||
state: PlanningState,
|
state: ThreadState,
|
||||||
runtime: Runtime,
|
runtime: Runtime,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Async version of after_model."""
|
"""Async version of after_model."""
|
||||||
@@ -349,11 +353,11 @@ class TodoMiddleware(TodoListMiddleware):
|
|||||||
return await handler(self._augment_request(request))
|
return await handler(self._augment_request(request))
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
def after_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
self._clear_current_run_completion_reminders(runtime)
|
self._clear_current_run_completion_reminders(runtime)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
async def aafter_agent(self, state: ThreadState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
self._clear_current_run_completion_reminders(runtime)
|
self._clear_current_run_completion_reminders(runtime)
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from deerflow.agents.middlewares.todo_middleware import (
|
|||||||
_reminder_in_messages,
|
_reminder_in_messages,
|
||||||
_todos_in_messages,
|
_todos_in_messages,
|
||||||
)
|
)
|
||||||
|
from deerflow.agents.thread_state import ThreadState
|
||||||
|
|
||||||
|
|
||||||
def _ai_with_write_todos():
|
def _ai_with_write_todos():
|
||||||
@@ -510,6 +511,42 @@ class TestWrapModelCall:
|
|||||||
|
|
||||||
|
|
||||||
class TestTodoMiddlewareAgentGraphIntegration:
|
class TestTodoMiddlewareAgentGraphIntegration:
|
||||||
|
def test_reuses_thread_state_todos_schema_in_real_agent_graph(self):
|
||||||
|
mw = TodoMiddleware()
|
||||||
|
model = _CapturingFakeMessagesListChatModel(
|
||||||
|
responses=[
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "write_todos",
|
||||||
|
"id": "todos-1",
|
||||||
|
"args": {
|
||||||
|
"todos": [
|
||||||
|
{"content": "Step 1", "status": "pending"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessage(content="final"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = create_agent(
|
||||||
|
model=model,
|
||||||
|
tools=[],
|
||||||
|
middleware=[mw],
|
||||||
|
state_schema=ThreadState,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = graph.invoke(
|
||||||
|
{"messages": [("user", "create a todo")]},
|
||||||
|
context={"thread_id": "schema-thread", "run_id": "schema-run"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["todos"] == [{"content": "Step 1", "status": "pending"}]
|
||||||
|
|
||||||
def test_completion_reminder_is_transient_in_real_agent_graph(self):
|
def test_completion_reminder_is_transient_in_real_agent_graph(self):
|
||||||
mw = TodoMiddleware()
|
mw = TodoMiddleware()
|
||||||
model = _CapturingFakeMessagesListChatModel(
|
model = _CapturingFakeMessagesListChatModel(
|
||||||
|
|||||||
Reference in New Issue
Block a user