mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
fix(memory): isolate queued memory updates by agent (#2941)
* fix(memory): isolate queued memory updates by agent * fix(memory): include user in queue identity * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Fix the lint error --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,15 @@ class MemoryUpdateQueue:
|
|||||||
self._timer: threading.Timer | None = None
|
self._timer: threading.Timer | None = None
|
||||||
self._processing = False
|
self._processing = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _queue_key(
|
||||||
|
thread_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
agent_name: str | None,
|
||||||
|
) -> tuple[str, str | None, str | None]:
|
||||||
|
"""Return the debounce identity for a memory update target."""
|
||||||
|
return (thread_id, user_id, agent_name)
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -115,8 +124,9 @@ class MemoryUpdateQueue:
|
|||||||
correction_detected: bool,
|
correction_detected: bool,
|
||||||
reinforcement_detected: bool,
|
reinforcement_detected: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
queue_key = self._queue_key(thread_id, user_id, agent_name)
|
||||||
existing_context = next(
|
existing_context = next(
|
||||||
(context for context in self._queue if context.thread_id == thread_id),
|
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||||
@@ -130,7 +140,7 @@ class MemoryUpdateQueue:
|
|||||||
reinforcement_detected=merged_reinforcement_detected,
|
reinforcement_detected=merged_reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
|
||||||
self._queue.append(context)
|
self._queue.append(context)
|
||||||
|
|
||||||
def _reset_timer(self) -> None:
|
def _reset_timer(self) -> None:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
|
|||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
from deerflow.runtime.user_context import resolve_runtime_user_id
|
||||||
|
|
||||||
|
|
||||||
def memory_flush_hook(event: SummarizationEvent) -> None:
|
def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||||
@@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
|
|||||||
|
|
||||||
correction_detected = detect_correction(filtered_messages)
|
correction_detected = detect_correction(filtered_messages)
|
||||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||||
|
user_id = resolve_runtime_user_id(event.runtime)
|
||||||
queue = get_memory_queue()
|
queue = get_memory_queue()
|
||||||
queue.add_nowait(
|
queue.add_nowait(
|
||||||
thread_id=event.thread_id,
|
thread_id=event.thread_id,
|
||||||
messages=filtered_messages,
|
messages=filtered_messages,
|
||||||
agent_name=event.agent_name,
|
agent_name=event.agent_name,
|
||||||
|
user_id=user_id,
|
||||||
correction_detected=correction_detected,
|
correction_detected=correction_detected,
|
||||||
reinforcement_detected=reinforcement_detected,
|
reinforcement_detected=reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
@@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
|
|||||||
assert elapsed < 0.1
|
assert elapsed < 0.1
|
||||||
assert finished.is_set() is False
|
assert finished.is_set() is False
|
||||||
assert finished.wait(1.0) is True
|
assert finished.wait(1.0) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
|
||||||
|
queue = MemoryUpdateQueue()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch.object(queue, "_reset_timer"),
|
||||||
|
):
|
||||||
|
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
|
||||||
|
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
|
||||||
|
|
||||||
|
assert queue.pending_count == 2
|
||||||
|
assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
|
||||||
|
queue = MemoryUpdateQueue()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch.object(queue, "_reset_timer"),
|
||||||
|
):
|
||||||
|
queue.add(
|
||||||
|
thread_id="thread-1",
|
||||||
|
messages=["first"],
|
||||||
|
agent_name="agent-a",
|
||||||
|
correction_detected=True,
|
||||||
|
)
|
||||||
|
queue.add(
|
||||||
|
thread_id="thread-1",
|
||||||
|
messages=["second"],
|
||||||
|
agent_name="agent-a",
|
||||||
|
correction_detected=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert queue.pending_count == 1
|
||||||
|
assert queue._queue[0].agent_name == "agent-a"
|
||||||
|
assert queue._queue[0].messages == ["second"]
|
||||||
|
assert queue._queue[0].correction_detected is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
|
||||||
|
queue = MemoryUpdateQueue()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch.object(queue, "_reset_timer"),
|
||||||
|
):
|
||||||
|
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
|
||||||
|
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
|
||||||
|
|
||||||
|
mock_updater = MagicMock()
|
||||||
|
mock_updater.update_memory.return_value = True
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
|
||||||
|
patch("deerflow.agents.memory.queue.time.sleep"),
|
||||||
|
):
|
||||||
|
queue.flush()
|
||||||
|
|
||||||
|
assert mock_updater.update_memory.call_count == 2
|
||||||
|
mock_updater.update_memory.assert_has_calls(
|
||||||
|
[
|
||||||
|
call(
|
||||||
|
messages=["agent-a"],
|
||||||
|
thread_id="thread-1",
|
||||||
|
agent_name="agent-a",
|
||||||
|
correction_detected=False,
|
||||||
|
reinforcement_detected=False,
|
||||||
|
user_id=None,
|
||||||
|
),
|
||||||
|
call(
|
||||||
|
messages=["agent-b"],
|
||||||
|
thread_id="thread-1",
|
||||||
|
agent_name="agent-b",
|
||||||
|
correction_detected=False,
|
||||||
|
reinforcement_detected=False,
|
||||||
|
user_id=None,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -38,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
|
|||||||
mock_updater.update_memory.assert_called_once()
|
mock_updater.update_memory.assert_called_once()
|
||||||
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
||||||
assert call_kwargs["user_id"] == "alice"
|
assert call_kwargs["user_id"] == "alice"
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
|
||||||
|
q = MemoryUpdateQueue()
|
||||||
|
|
||||||
|
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||||
|
q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
|
||||||
|
q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
|
||||||
|
|
||||||
|
assert q.pending_count == 2
|
||||||
|
assert [context.user_id for context in q._queue] == ["alice", "bob"]
|
||||||
|
assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
|
||||||
|
q = MemoryUpdateQueue()
|
||||||
|
|
||||||
|
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
|
||||||
|
q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
|
||||||
|
q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
|
||||||
|
|
||||||
|
assert q.pending_count == 1
|
||||||
|
assert q._queue[0].messages == ["second"]
|
||||||
|
assert q._queue[0].user_id == "alice"
|
||||||
|
assert q._queue[0].agent_name == "researcher"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_nowait_keeps_different_users_separate():
|
||||||
|
q = MemoryUpdateQueue()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
|
||||||
|
patch.object(q, "_schedule_timer"),
|
||||||
|
):
|
||||||
|
q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
|
||||||
|
q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
|
||||||
|
|
||||||
|
assert q.pending_count == 2
|
||||||
|
assert [context.user_id for context in q._queue] == ["alice", "bob"]
|
||||||
|
|||||||
@@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
def _runtime(
|
||||||
|
thread_id: str | None = "thread-1",
|
||||||
|
agent_name: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> SimpleNamespace:
|
||||||
context = {}
|
context = {}
|
||||||
if thread_id is not None:
|
if thread_id is not None:
|
||||||
context["thread_id"] = thread_id
|
context["thread_id"] = thread_id
|
||||||
if agent_name is not None:
|
if agent_name is not None:
|
||||||
context["agent_name"] = agent_name
|
context["agent_name"] = agent_name
|
||||||
|
if user_id is not None:
|
||||||
|
context["user_id"] = user_id
|
||||||
return SimpleNamespace(context=context)
|
return SimpleNamespace(context=context)
|
||||||
|
|
||||||
|
|
||||||
@@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
|
|||||||
|
|
||||||
queue.add_nowait.assert_called_once()
|
queue.add_nowait.assert_called_once()
|
||||||
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
queue = MagicMock()
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
|
||||||
|
|
||||||
|
memory_flush_hook(
|
||||||
|
SummarizationEvent(
|
||||||
|
messages_to_summarize=tuple(_messages()[:2]),
|
||||||
|
preserved_messages=(),
|
||||||
|
thread_id="main",
|
||||||
|
agent_name="researcher",
|
||||||
|
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
queue.add_nowait.assert_called_once()
|
||||||
|
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
||||||
|
|||||||
Reference in New Issue
Block a user