mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(chat): preserve messages after summarization (#3280)
* fix(chat): preserve messages after summarization * make format * fix(chat): address summarization review comments
This commit is contained in:
@@ -5,7 +5,10 @@ from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
|
||||
@@ -22,6 +25,23 @@ def _messages() -> list:
|
||||
]
|
||||
|
||||
|
||||
class _StaticChatModel(BaseChatModel):
|
||||
text: str = "ok"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "static-test-chat-model"
|
||||
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=self.text))])
|
||||
|
||||
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content="<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>",
|
||||
@@ -114,6 +134,32 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
|
||||
assert result["messages"][1].content.startswith("Here is a summary")
|
||||
|
||||
|
||||
def test_summarization_middleware_emits_frontend_update_key_in_agent_stream() -> None:
|
||||
middleware = DeerFlowSummarizationMiddleware(
|
||||
model=_StaticChatModel(text="compressed summary"),
|
||||
trigger=("messages", 4),
|
||||
keep=("messages", 2),
|
||||
token_counter=len,
|
||||
)
|
||||
agent = create_agent(
|
||||
model=_StaticChatModel(text="done"),
|
||||
tools=[],
|
||||
middleware=[middleware],
|
||||
)
|
||||
|
||||
chunks = list(agent.stream({"messages": _messages()}, stream_mode="updates"))
|
||||
update = next(
|
||||
(chunk["DeerFlowSummarizationMiddleware.before_model"] for chunk in chunks if "DeerFlowSummarizationMiddleware.before_model" in chunk),
|
||||
None,
|
||||
)
|
||||
|
||||
assert update is not None
|
||||
emitted = update["messages"]
|
||||
assert isinstance(emitted[0], RemoveMessage)
|
||||
assert emitted[1].name == "summary"
|
||||
assert emitted[1].content == ("Here is a summary of the conversation to date:\n\ncompressed summary")
|
||||
|
||||
|
||||
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
|
||||
Reference in New Issue
Block a user