feat: flush memory before summarization (#2176)

* feat: flush memory before summarization

* fix: keep agent-scoped memory on summarization flush

* fix: harden summarization hook plumbing

* fix: address summarization review feedback

* style: format memory middleware
This commit is contained in:
DanielWalnut
2026-04-14 15:01:06 +08:00
committed by GitHub
parent e4f896e90d
commit 4ba3167f48
10 changed files with 667 additions and 189 deletions
@@ -8,6 +8,7 @@ import pytest
from deerflow.agents.lead_agent import agent as lead_agent_module
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.summarization_config import SummarizationConfig
@@ -145,6 +146,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
)
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
captured: dict[str, object] = {}
fake_model = object()
@@ -156,10 +158,32 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
return fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "SummarizationMiddleware", lambda **kwargs: kwargs)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)
middleware = lead_agent_module._create_summarization_middleware()
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert middleware["model"] is fake_model
def test_create_summarization_middleware_registers_memory_flush_hook_when_memory_enabled(monkeypatch):
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True),
)
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
captured: dict[str, object] = {}
def _fake_middleware(**kwargs):
captured.update(kwargs)
return kwargs
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
lead_agent_module._create_summarization_middleware()
assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
+73
View File
@@ -1,3 +1,5 @@
import threading
import time
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
@@ -89,3 +91,74 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
correction_detected=False,
reinforcement_detected=True,
)
def test_flush_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
queue = MemoryUpdateQueue()
existing_timer = MagicMock()
queue._timer = existing_timer
created_timer = MagicMock()
with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
queue.flush_nowait()
existing_timer.cancel.assert_called_once_with()
timer_cls.assert_called_once_with(0, queue._process_queue)
assert created_timer.daemon is True
created_timer.start.assert_called_once_with()
assert queue._timer is created_timer
def test_add_nowait_cancels_existing_timer_and_starts_immediate_timer() -> None:
queue = MemoryUpdateQueue()
existing_timer = MagicMock()
queue._timer = existing_timer
created_timer = MagicMock()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls,
):
queue.add_nowait(thread_id="thread-1", messages=["conversation"], agent_name="lead-agent")
existing_timer.cancel.assert_called_once_with()
timer_cls.assert_called_once_with(0, queue._process_queue)
assert queue.pending_count == 1
assert queue._queue[0].agent_name == "lead-agent"
assert created_timer.daemon is True
created_timer.start.assert_called_once_with()
def test_process_queue_reschedules_immediately_when_already_processing() -> None:
queue = MemoryUpdateQueue()
queue._processing = True
created_timer = MagicMock()
with patch("deerflow.agents.memory.queue.threading.Timer", return_value=created_timer) as timer_cls:
queue._process_queue()
timer_cls.assert_called_once_with(0, queue._process_queue)
assert created_timer.daemon is True
created_timer.start.assert_called_once_with()
def test_flush_nowait_is_non_blocking() -> None:
queue = MemoryUpdateQueue()
started = threading.Event()
finished = threading.Event()
def _slow_process_queue() -> None:
started.set()
time.sleep(0.2)
finished.set()
queue._process_queue = _slow_process_queue
start = time.perf_counter()
queue.flush_nowait()
elapsed = time.perf_counter() - start
assert started.wait(0.1) is True
assert elapsed < 0.1
assert finished.is_set() is False
assert finished.wait(1.0) is True
+10 -10
View File
@@ -3,14 +3,14 @@
Covers two functions introduced to prevent ephemeral file-upload context from
persisting in long-term memory:
- _filter_messages_for_memory (memory_middleware)
- filter_messages_for_memory (message_processing)
- _strip_upload_mentions_from_memory (updater)
"""
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
# ---------------------------------------------------------------------------
# Helpers
@@ -31,7 +31,7 @@ def _ai(text: str, tool_calls=None) -> AIMessage:
# ===========================================================================
# _filter_messages_for_memory
# filter_messages_for_memory
# ===========================================================================
@@ -45,7 +45,7 @@ class TestFilterMessagesForMemory:
_human(_UPLOAD_BLOCK),
_ai("I have read the file. It says: Hello."),
]
result = _filter_messages_for_memory(msgs)
result = filter_messages_for_memory(msgs)
assert result == []
def test_upload_with_real_question_preserves_question(self):
@@ -56,7 +56,7 @@ class TestFilterMessagesForMemory:
_human(combined),
_ai("The file contains: Hello DeerFlow."),
]
result = _filter_messages_for_memory(msgs)
result = filter_messages_for_memory(msgs)
assert len(result) == 2
human_result = result[0]
@@ -71,7 +71,7 @@ class TestFilterMessagesForMemory:
_human("What is the capital of France?"),
_ai("The capital of France is Paris."),
]
result = _filter_messages_for_memory(msgs)
result = filter_messages_for_memory(msgs)
assert len(result) == 2
assert result[0].content == "What is the capital of France?"
assert result[1].content == "The capital of France is Paris."
@@ -84,7 +84,7 @@ class TestFilterMessagesForMemory:
ToolMessage(content="Search results", tool_call_id="1"),
_ai("Here are the results."),
]
result = _filter_messages_for_memory(msgs)
result = filter_messages_for_memory(msgs)
human_msgs = [m for m in result if m.type == "human"]
ai_msgs = [m for m in result if m.type == "ai"]
assert len(human_msgs) == 1
@@ -101,7 +101,7 @@ class TestFilterMessagesForMemory:
_human("What is 2 + 2?"),
_ai("4"),
]
result = _filter_messages_for_memory(msgs)
result = filter_messages_for_memory(msgs)
human_contents = [m.content for m in result if m.type == "human"]
ai_contents = [m.content for m in result if m.type == "ai"]
@@ -121,14 +121,14 @@ class TestFilterMessagesForMemory:
]
)
msgs = [msg, _ai("Done.")]
result = _filter_messages_for_memory(msgs)
result = filter_messages_for_memory(msgs)
assert result == []
def test_file_path_not_in_filtered_content(self):
"""After filtering, no upload file path should appear in any message."""
combined = _UPLOAD_BLOCK + "\n\nSummarise the file please."
msgs = [_human(combined), _ai("It says hello.")]
result = _filter_messages_for_memory(msgs)
result = filter_messages_for_memory(msgs)
all_content = " ".join(m.content for m in result if isinstance(m.content, str))
assert "/mnt/user-data/uploads/" not in all_content
assert "<uploaded_files>" not in all_content
@@ -0,0 +1,186 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
from deerflow.config.memory_config import MemoryConfig
def _messages() -> list:
return [
HumanMessage(content="user-1"),
AIMessage(content="assistant-1"),
HumanMessage(content="user-2"),
AIMessage(content="assistant-2"),
]
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
context = {}
if thread_id is not None:
context["thread_id"] = thread_id
if agent_name is not None:
context["agent_name"] = agent_name
return SimpleNamespace(context=context)
def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
model = MagicMock()
model.invoke.return_value = SimpleNamespace(text="compressed summary")
return DeerFlowSummarizationMiddleware(
model=model,
trigger=trigger,
keep=keep,
token_counter=len,
before_summarization=before_summarization,
)
def test_before_summarization_hook_receives_messages_before_compression() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
result = middleware.before_model({"messages": _messages()}, _runtime())
assert len(captured) == 1
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
assert [message.content for message in captured[0].preserved_messages] == ["user-2", "assistant-2"]
assert captured[0].thread_id == "thread-1"
assert captured[0].agent_name is None
assert isinstance(result["messages"][0], RemoveMessage)
assert result["messages"][1].content.startswith("Here is a summary")
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
result = middleware.before_model({"messages": _messages()}, _runtime())
assert captured == []
assert result is None
def test_before_summarization_hook_exception_does_not_block_compression(caplog: pytest.LogCaptureFixture) -> None:
def _broken_hook(_: SummarizationEvent) -> None:
raise RuntimeError("hook failure")
middleware = _middleware(before_summarization=[_broken_hook])
with caplog.at_level("ERROR"):
result = middleware.before_model({"messages": _messages()}, _runtime())
assert "before_summarization hook _broken_hook failed" in caplog.text
assert isinstance(result["messages"][0], RemoveMessage)
def test_multiple_before_summarization_hooks_run_in_registration_order() -> None:
call_order: list[str] = []
def _hook(name: str):
return lambda _: call_order.append(name)
middleware = _middleware(before_summarization=[_hook("first"), _hook("second"), _hook("third")])
middleware.before_model({"messages": _messages()}, _runtime())
assert call_order == ["first", "second", "third"]
@pytest.mark.anyio
async def test_abefore_model_calls_hooks_same_as_sync() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
await middleware.abefore_model({"messages": _messages()}, _runtime())
assert len(captured) == 1
assert [message.content for message in captured[0].messages_to_summarize] == ["user-1", "assistant-1"]
def test_memory_flush_hook_skips_when_memory_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=False))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="thread-1",
agent_name=None,
runtime=_runtime(),
)
)
queue.add_nowait.assert_not_called()
def test_memory_flush_hook_skips_when_thread_id_missing(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id=None,
agent_name=None,
runtime=_runtime(None),
)
)
queue.add_nowait.assert_not_called()
def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
messages = [
HumanMessage(content="Question"),
AIMessage(content="Calling tool", tool_calls=[{"name": "search", "id": "tool-1", "args": {}}]),
AIMessage(content="Final answer"),
]
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(messages),
preserved_messages=(),
thread_id="thread-1",
agent_name=None,
runtime=_runtime(),
)
)
queue.add_nowait.assert_called_once()
add_kwargs = queue.add_nowait.call_args.kwargs
assert add_kwargs["thread_id"] == "thread-1"
assert [message.content for message in add_kwargs["messages"]] == ["Question", "Final answer"]
assert add_kwargs["correction_detected"] is False
assert add_kwargs["reinforcement_detected"] is False
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="thread-1",
agent_name="research-agent",
runtime=_runtime(agent_name="research-agent"),
)
)
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"