mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
feat: flush memory before summarization (#2176)
* feat: flush memory before summarization * fix: keep agent-scoped memory on summarization flush * fix: harden summarization hook plumbing * fix: address summarization review feedback * style: format memory middleware
This commit is contained in:
@@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
from deerflow.agents.lead_agent import agent as lead_agent_module
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.summarization_config import SummarizationConfig
|
||||
@@ -145,6 +146,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
"get_summarization_config",
|
||||
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
|
||||
)
|
||||
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
fake_model = object()
|
||||
@@ -156,10 +158,32 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
return fake_model
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
|
||||
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
|
||||
|
||||
middleware = lead_agent_module._create_summarization_middleware()
|
||||
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
assert middleware["model"] is fake_model
|
||||
|
||||
|
||||
def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
lead_agent_module,
|
||||
"get_summarization_config",
|
||||
lambda: SummarizationConfig(enabled=True),
|
||||
)
|
||||
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _fake_middleware(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs
|
||||
|
||||
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
|
||||
|
||||
lead_agent_module._create_summarization_middleware()
|
||||
|
||||
assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
@@ -89,3 +91,74 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
|
||||
|
||||
def test_flush_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
existing_timer = MagicMock()
|
||||
queue._timer = existing_timer
|
||||
created_timer = MagicMock()
|
||||
|
||||
with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
|
||||
queue.flush_nowait()
|
||||
|
||||
existing_timer.cancel.assert_called_once_with()
|
||||
timer_cls.assert_called_once_with(0, queue._process_queue)
|
||||
assert created_timer.daemon is True
|
||||
created_timer.start.assert_called_once_with()
|
||||
assert queue._timer is created_timer
|
||||
|
||||
|
||||
def test_add_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
existing_timer = MagicMock()
|
||||
queue._timer = existing_timer
|
||||
created_timer = MagicMock()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls,
|
||||
):
|
||||
queue.add_nowait(thread_id="thread-1", messages=["conversation"], agent_name="lead-agent")
|
||||
|
||||
existing_timer.cancel.assert_called_once_with()
|
||||
timer_cls.assert_called_once_with(0, queue._process_queue)
|
||||
assert queue.pending_count == 1
|
||||
assert queue._queue[0].agent_name == "lead-agent"
|
||||
assert created_timer.daemon is True
|
||||
created_timer.start.assert_called_once_with()
|
||||
|
||||
|
||||
def test_process_queue_reschedules_immediately_when_already_processing() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue._processing = True
|
||||
created_timer = MagicMock()
|
||||
|
||||
with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
|
||||
queue._process_queue()
|
||||
|
||||
timer_cls.assert_called_once_with(0, queue._process_queue)
|
||||
assert created_timer.daemon is True
|
||||
created_timer.start.assert_called_once_with()
|
||||
|
||||
|
||||
def test_flush_nowait_is_non_blocking() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
started = threading.Event()
|
||||
finished = threading.Event()
|
||||
|
||||
def _slow_process_queue() -> None:
|
||||
started.set()
|
||||
time.sleep(0.2)
|
||||
finished.set()
|
||||
|
||||
queue._process_queue = _slow_process_queue
|
||||
|
||||
start = time.perf_counter()
|
||||
queue.flush_nowait()
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
assert started.wait(0.1) is True
|
||||
assert elapsed < 0.1
|
||||
assert finished.is_set() is False
|
||||
assert finished.wait(1.0) is True
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
Covers two functions introduced to prevent ephemeral file-upload context from
|
||||
persisting in long-term memory:
|
||||
|
||||
- _filter_messages_for_memory (memory_middleware)
|
||||
- filter_messages_for_memory (message_processing)
|
||||
- _strip_upload_mentions_from_memory (updater)
|
||||
"""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -31,7 +31,7 @@ def _ai(text: str, tool_calls=None) -> AIMessage:
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _filter_messages_for_memory
|
||||
# filter_messages_for_memory
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestFilterMessagesForMemory:
|
||||
_human(_UPLOAD_BLOCK),
|
||||
_ai("I have read the file. It says: Hello."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
assert result == []
|
||||
|
||||
def test_upload_with_real_question_preserves_question(self):
|
||||
@@ -56,7 +56,7 @@ class TestFilterMessagesForMemory:
|
||||
_human(combined),
|
||||
_ai("The file contains: Hello DeerFlow."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
|
||||
assert len(result) == 2
|
||||
human_result = result[0]
|
||||
@@ -71,7 +71,7 @@ class TestFilterMessagesForMemory:
|
||||
_human("What is the capital of France?"),
|
||||
_ai("The capital of France is Paris."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "What is the capital of France?"
|
||||
assert result[1].content == "The capital of France is Paris."
|
||||
@@ -84,7 +84,7 @@ class TestFilterMessagesForMemory:
|
||||
ToolMessage(content="Search results", tool_call_id="1"),
|
||||
_ai("Here are the results."),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
human_msgs = [m for m in result if m.type == "human"]
|
||||
ai_msgs = [m for m in result if m.type == "ai"]
|
||||
assert len(human_msgs) == 1
|
||||
@@ -101,7 +101,7 @@ class TestFilterMessagesForMemory:
|
||||
_human("What is 2 + 2?"),
|
||||
_ai("4"),
|
||||
]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
human_contents = [m.content for m in result if m.type == "human"]
|
||||
ai_contents = [m.content for m in result if m.type == "ai"]
|
||||
|
||||
@@ -121,14 +121,14 @@ class TestFilterMessagesForMemory:
|
||||
]
|
||||
)
|
||||
msgs = [msg, _ai("Done.")]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
assert result == []
|
||||
|
||||
def test_file_path_not_in_filtered_content(self):
|
||||
"""After filtering, no upload file path should appear in any message."""
|
||||
combined = _UPLOAD_BLOCK + "\n\nSummarise the file please."
|
||||
msgs = [_human(combined), _ai("It says hello.")]
|
||||
result = _filter_messages_for_memory(msgs)
|
||||
result = filter_messages_for_memory(msgs)
|
||||
all_content = " ".join(m.content for m in result if isinstance(m.content, str))
|
||||
assert "/mnt/user-data/uploads/" not in all_content
|
||||
assert "<uploaded_files>" not in all_content
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
||||
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
|
||||
def _messages() -> list:
|
||||
return [
|
||||
HumanMessage(content="user-1"),
|
||||
AIMessage(content="assistant-1"),
|
||||
HumanMessage(content="user-2"),
|
||||
AIMessage(content="assistant-2"),
|
||||
]
|
||||
|
||||
|
||||
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||
context = {}
|
||||
if thread_id is not None:
|
||||
context["thread_id"] = thread_id
|
||||
if agent_name is not None:
|
||||
context["agent_name"] = agent_name
|
||||
return SimpleNamespace(context=context)
|
||||
|
||||
|
||||
def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
|
||||
model = MagicMock()
|
||||
model.invoke.return_value = SimpleNamespace(text="compressed summary")
|
||||
return DeerFlowSummarizationMiddleware(
|
||||
model=model,
|
||||
trigger=trigger,
|
||||
keep=keep,
|
||||
token_counter=len,
|
||||
before_summarization=before_summarization,
|
||||
)
|
||||
|
||||
|
||||
def test_before_summarization_hook_receives_messages_before_compression() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
|
||||
result = middleware.before_model({"messages": _messages()}, _runtime())
|
||||
|
||||
assert len(captured) == 1
|
||||
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
|
||||
assert [message.content for message in captured[0].preserved_messages] == ["user-2", "assistant-2"]
|
||||
assert captured[0].thread_id == "thread-1"
|
||||
assert captured[0].agent_name is None
|
||||
assert isinstance(result["messages"][0], RemoveMessage)
|
||||
assert result["messages"][1].content.startswith("Here is a summary")
|
||||
|
||||
|
||||
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
|
||||
|
||||
result = middleware.before_model({"messages": _messages()}, _runtime())
|
||||
|
||||
assert captured == []
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None:
|
||||
def _broken_hook(_: SummarizationEvent) -> None:
|
||||
raise RuntimeError("hook failure")
|
||||
|
||||
middleware = _middleware(before_summarization=[_broken_hook])
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
result = middleware.before_model({"messages": _messages()}, _runtime())
|
||||
|
||||
assert "before_summarization hook _broken_hook failed" in caplog.text
|
||||
assert isinstance(result["messages"][0], RemoveMessage)
|
||||
|
||||
|
||||
def test_multiple_before_summarization_hooks_run_in_registration_order() -> None:
|
||||
call_order: list[str] = []
|
||||
|
||||
def _hook(name: str):
|
||||
return lambda _: call_order.append(name)
|
||||
|
||||
middleware = _middleware(before_summarization=[_hook("first"), _hook("second"), _hook("third")])
|
||||
|
||||
middleware.before_model({"messages": _messages()}, _runtime())
|
||||
|
||||
assert call_order == ["first", "second", "third"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_abefore_model_calls_hooks_same_as_sync() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
|
||||
await middleware.abefore_model({"messages": _messages()}, _runtime())
|
||||
|
||||
assert len(captured) == 1
|
||||
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
|
||||
|
||||
|
||||
def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=False))
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
|
||||
|
||||
memory_flush_hook(
|
||||
SummarizationEvent(
|
||||
messages_to_summarize=tuple(_messages()[:2]),
|
||||
preserved_messages=(),
|
||||
thread_id="thread-1",
|
||||
agent_name=None,
|
||||
runtime=_runtime(),
|
||||
)
|
||||
)
|
||||
|
||||
queue.add_nowait.assert_not_called()
|
||||
|
||||
|
||||
def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
|
||||
|
||||
memory_flush_hook(
|
||||
SummarizationEvent(
|
||||
messages_to_summarize=tuple(_messages()[:2]),
|
||||
preserved_messages=(),
|
||||
thread_id=None,
|
||||
agent_name=None,
|
||||
runtime=_runtime(None),
|
||||
)
|
||||
)
|
||||
|
||||
queue.add_nowait.assert_not_called()
|
||||
|
||||
|
||||
def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
messages = [
|
||||
HumanMessage(content="Question"),
|
||||
AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]),
|
||||
AIMessage(content="Final answer"),
|
||||
]
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
|
||||
|
||||
memory_flush_hook(
|
||||
SummarizationEvent(
|
||||
messages_to_summarize=tuple(messages),
|
||||
preserved_messages=(),
|
||||
thread_id="thread-1",
|
||||
agent_name=None,
|
||||
runtime=_runtime(),
|
||||
)
|
||||
)
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
add_kwargs = queue.add_nowait.call_args.kwargs
|
||||
assert add_kwargs["thread_id"] == "thread-1"
|
||||
assert [message.content for message in add_kwargs["messages"]] == ["Question", "Final answer"]
|
||||
assert add_kwargs["correction_detected"] is False
|
||||
assert add_kwargs["reinforcement_detected"] is False
|
||||
|
||||
|
||||
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
|
||||
|
||||
memory_flush_hook(
|
||||
SummarizationEvent(
|
||||
messages_to_summarize=tuple(_messages()[:2]),
|
||||
preserved_messages=(),
|
||||
thread_id="thread-1",
|
||||
agent_name="research-agent",
|
||||
runtime=_runtime(agent_name="research-agent"),
|
||||
)
|
||||
)
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||
Reference in New Issue
Block a user