mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
fix(memory): replace short-lived asyncio.run() with persistent event loop (#2627)
* fix(memory): replace short-lived asyncio.run() with persistent event loop to prevent zombie httpx connections The memory updater used asyncio.run() inside daemon threads, creating and destroying short-lived event loops on every update. Langchain providers (e.g. langchain-anthropic) cache httpx AsyncClient instances globally via @lru_cache, so SSL connections created on a loop that is subsequently destroyed become zombie connections in the shared pool. When the main agent's lead run later reuses one of these connections, httpx/anyio triggers RuntimeError: Event loop is closed during connection cleanup. Replace the ThreadPoolExecutor + asyncio.run() pattern with a _MemoryLoopRunner that maintains a single persistent event loop in a daemon thread for the process lifetime. Since the loop never closes, connections bound to it never become invalid. The _run_async_update_sync function now submits coroutines to this persistent loop via run_coroutine_threadsafe instead of creating throwaway loops. * update the code to address the review comments * Fix the review comments of 2615 P1 — user_id forwarded through sync path: Added user_id parameter to _prepare_update_prompt, _finalize_update, and _do_update_memory_sync, and forwarded it to get_memory_data(agent_name, user_id=user_id) and save(..., user_id=user_id). The update_memory() entry point now passes user_id through both the executor.submit path and the direct call path. Added TestUserIdForwarding with two regression tests (sync + async) verifying get_memory_data and save receive the correct user_id. P2 — aupdate_memory() delegates to sync: Replaced the model.ainvoke() call with asyncio.to_thread(self._do_update_memory_sync, ...). This eliminates the unsafe async provider client path entirely — all memory updater entry points now use the isolated sync model.invoke() path. Updated the test from asserting ainvoke is awaited to asserting invoke is called and ainvoke is not. Nit — duplicate comment removed: Removed the duplicated # Matches sentences... comment on line 230. * Chore(test): update the code of test_memory_updater --------- Co-authored-by: rayhpeng <rayhpeng@gmail.com>
This commit is contained in:
@@ -9,7 +9,6 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from deerflow.agents.memory.prompt import (
|
from deerflow.agents.memory.prompt import (
|
||||||
@@ -26,6 +25,12 @@ from deerflow.models import create_chat_model
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Thread pool for offloading sync memory updates when called from an async
|
||||||
|
# context. Unlike the previous asyncio.run() approach, this runs *sync*
|
||||||
|
# model.invoke() calls — no event loop is created, so the langchain async
|
||||||
|
# httpx client pool (globally cached via @lru_cache) is never touched and
|
||||||
|
# cross-loop connection reuse is impossible.
|
||||||
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
|
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
|
||||||
max_workers=4,
|
max_workers=4,
|
||||||
thread_name_prefix="memory-updater-sync",
|
thread_name_prefix="memory-updater-sync",
|
||||||
@@ -222,39 +227,6 @@ def _extract_text(content: Any) -> str:
|
|||||||
return str(content)
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
|
|
||||||
"""Run an async memory update from sync code, including nested-loop contexts."""
|
|
||||||
handed_off = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
loop = None
|
|
||||||
|
|
||||||
if loop is not None and loop.is_running():
|
|
||||||
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
|
|
||||||
handed_off = True
|
|
||||||
return future.result()
|
|
||||||
|
|
||||||
handed_off = True
|
|
||||||
return asyncio.run(coro)
|
|
||||||
except Exception:
|
|
||||||
if not handed_off:
|
|
||||||
close = getattr(coro, "close", None)
|
|
||||||
if callable(close):
|
|
||||||
try:
|
|
||||||
close()
|
|
||||||
except Exception:
|
|
||||||
logger.debug(
|
|
||||||
"Failed to close un-awaited memory update coroutine",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.exception("Failed to run async memory update from sync context")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Matches sentences that describe a file-upload *event* rather than general
|
# Matches sentences that describe a file-upload *event* rather than general
|
||||||
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
||||||
# such as "User works with CSV files" or "prefers PDF export".
|
# such as "User works with CSV files" or "prefers PDF export".
|
||||||
@@ -349,13 +321,14 @@ class MemoryUpdater:
|
|||||||
agent_name: str | None,
|
agent_name: str | None,
|
||||||
correction_detected: bool,
|
correction_detected: bool,
|
||||||
reinforcement_detected: bool,
|
reinforcement_detected: bool,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> tuple[dict[str, Any], str] | None:
|
) -> tuple[dict[str, Any], str] | None:
|
||||||
"""Load memory and build the update prompt for a conversation."""
|
"""Load memory and build the update prompt for a conversation."""
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
if not config.enabled or not messages:
|
if not config.enabled or not messages:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
current_memory = get_memory_data(agent_name)
|
current_memory = get_memory_data(agent_name, user_id=user_id)
|
||||||
conversation_text = format_conversation_for_update(messages)
|
conversation_text = format_conversation_for_update(messages)
|
||||||
if not conversation_text.strip():
|
if not conversation_text.strip():
|
||||||
return None
|
return None
|
||||||
@@ -377,6 +350,7 @@ class MemoryUpdater:
|
|||||||
response_content: Any,
|
response_content: Any,
|
||||||
thread_id: str | None,
|
thread_id: str | None,
|
||||||
agent_name: str | None,
|
agent_name: str | None,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Parse the model response, apply updates, and persist memory."""
|
"""Parse the model response, apply updates, and persist memory."""
|
||||||
response_text = _extract_text(response_content).strip()
|
response_text = _extract_text(response_content).strip()
|
||||||
@@ -390,7 +364,7 @@ class MemoryUpdater:
|
|||||||
# cannot corrupt the still-cached original object reference.
|
# cannot corrupt the still-cached original object reference.
|
||||||
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
||||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||||
return get_memory_storage().save(updated_memory, agent_name)
|
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
||||||
|
|
||||||
async def aupdate_memory(
|
async def aupdate_memory(
|
||||||
self,
|
self,
|
||||||
@@ -399,28 +373,63 @@ class MemoryUpdater:
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Update memory asynchronously based on conversation messages."""
|
"""Update memory asynchronously by delegating to the sync path.
|
||||||
|
|
||||||
|
Uses ``asyncio.to_thread`` to run the *sync* ``model.invoke()`` path
|
||||||
|
in a worker thread so no second event loop is created and the
|
||||||
|
langchain async httpx client pool (shared with the lead agent) is
|
||||||
|
never touched. This eliminates the cross-loop connection-reuse bug
|
||||||
|
described in issue #2615.
|
||||||
|
"""
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
self._do_update_memory_sync,
|
||||||
|
messages=messages,
|
||||||
|
thread_id=thread_id,
|
||||||
|
agent_name=agent_name,
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _do_update_memory_sync(
|
||||||
|
self,
|
||||||
|
messages: list[Any],
|
||||||
|
thread_id: str | None = None,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
correction_detected: bool = False,
|
||||||
|
reinforcement_detected: bool = False,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Pure-sync memory update using ``model.invoke()``.
|
||||||
|
|
||||||
|
Uses the *sync* LLM call path so no event loop is created. This
|
||||||
|
guarantees that the langchain provider's globally cached async
|
||||||
|
httpx ``AsyncClient`` / connection pool (the one shared with the
|
||||||
|
lead agent) is never touched — no cross-loop connection reuse is
|
||||||
|
possible.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
prepared = await asyncio.to_thread(
|
prepared = self._prepare_update_prompt(
|
||||||
self._prepare_update_prompt,
|
|
||||||
messages=messages,
|
messages=messages,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
correction_detected=correction_detected,
|
correction_detected=correction_detected,
|
||||||
reinforcement_detected=reinforcement_detected,
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if prepared is None:
|
if prepared is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
current_memory, prompt = prepared
|
current_memory, prompt = prepared
|
||||||
model = self._get_model()
|
model = self._get_model()
|
||||||
response = await model.ainvoke(prompt, config={"run_name": "memory_agent"})
|
response = model.invoke(prompt, config={"run_name": "memory_agent"})
|
||||||
return await asyncio.to_thread(
|
return self._finalize_update(
|
||||||
self._finalize_update,
|
|
||||||
current_memory=current_memory,
|
current_memory=current_memory,
|
||||||
response_content=response.content,
|
response_content=response.content,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||||
@@ -438,7 +447,16 @@ class MemoryUpdater:
|
|||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Synchronously update memory via the async updater path.
|
"""Synchronously update memory using the sync LLM path.
|
||||||
|
|
||||||
|
Uses ``model.invoke()`` (sync HTTP) which operates on a completely
|
||||||
|
separate connection pool from the async ``AsyncClient`` shared by
|
||||||
|
the lead agent. This eliminates the cross-loop connection-reuse
|
||||||
|
bug described in issue #2615.
|
||||||
|
|
||||||
|
When called from within a running event loop (e.g. from a LangGraph
|
||||||
|
node), the blocking sync call is offloaded to a thread pool so the
|
||||||
|
caller's loop is not blocked.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of conversation messages.
|
messages: List of conversation messages.
|
||||||
@@ -451,14 +469,34 @@ class MemoryUpdater:
|
|||||||
Returns:
|
Returns:
|
||||||
True if update was successful, False otherwise.
|
True if update was successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
return _run_async_update_sync(
|
try:
|
||||||
self.aupdate_memory(
|
loop = asyncio.get_running_loop()
|
||||||
messages=messages,
|
except RuntimeError:
|
||||||
thread_id=thread_id,
|
loop = None
|
||||||
agent_name=agent_name,
|
|
||||||
correction_detected=correction_detected,
|
if loop is not None and loop.is_running():
|
||||||
reinforcement_detected=reinforcement_detected,
|
try:
|
||||||
)
|
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(
|
||||||
|
self._do_update_memory_sync,
|
||||||
|
messages=messages,
|
||||||
|
thread_id=thread_id,
|
||||||
|
agent_name=agent_name,
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
return future.result()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to offload memory update to executor")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self._do_update_memory_sync(
|
||||||
|
messages=messages,
|
||||||
|
thread_id=thread_id,
|
||||||
|
agent_name=agent_name,
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_updates(
|
def _apply_updates(
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from deerflow.agents.memory.prompt import format_conversation_for_update
|
from deerflow.agents.memory.prompt import format_conversation_for_update
|
||||||
from deerflow.agents.memory.updater import (
|
from deerflow.agents.memory.updater import (
|
||||||
MemoryUpdater,
|
MemoryUpdater,
|
||||||
_extract_text,
|
_extract_text,
|
||||||
_run_async_update_sync,
|
|
||||||
clear_memory_data,
|
clear_memory_data,
|
||||||
create_memory_fact,
|
create_memory_fact,
|
||||||
delete_memory_fact,
|
delete_memory_fact,
|
||||||
@@ -528,6 +525,7 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.content = content
|
response.content = content
|
||||||
model.ainvoke = AsyncMock(return_value=response)
|
model.ainvoke = AsyncMock(return_value=response)
|
||||||
|
model.invoke = MagicMock(return_value=response)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def test_string_response_parses(self):
|
def test_string_response_parses(self):
|
||||||
@@ -551,7 +549,7 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = updater.update_memory([msg, ai_msg])
|
result = updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
model.ainvoke.assert_awaited_once()
|
model.invoke.assert_called_once()
|
||||||
|
|
||||||
def test_list_content_response_parses(self):
|
def test_list_content_response_parses(self):
|
||||||
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
||||||
@@ -576,7 +574,8 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
def test_async_update_memory_uses_ainvoke(self):
|
def test_async_update_memory_delegates_to_sync(self):
|
||||||
|
"""aupdate_memory should delegate to sync _do_update_memory_sync via to_thread."""
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
model = self._make_mock_model(valid_json)
|
model = self._make_mock_model(valid_json)
|
||||||
@@ -597,8 +596,9 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
|
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
model.ainvoke.assert_awaited_once()
|
# aupdate_memory delegates to sync path — model.invoke, not ainvoke
|
||||||
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "memory_agent"}
|
model.invoke.assert_called_once()
|
||||||
|
model.ainvoke.assert_not_called()
|
||||||
|
|
||||||
def test_correction_hint_injected_when_detected(self):
|
def test_correction_hint_injected_when_detected(self):
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
@@ -622,7 +622,7 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.ainvoke.await_args.args[0]
|
prompt = model.invoke.call_args.args[0]
|
||||||
assert "Explicit correction signals were detected" in prompt
|
assert "Explicit correction signals were detected" in prompt
|
||||||
|
|
||||||
def test_correction_hint_empty_when_not_detected(self):
|
def test_correction_hint_empty_when_not_detected(self):
|
||||||
@@ -647,7 +647,7 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.ainvoke.await_args.args[0]
|
prompt = model.invoke.call_args.args[0]
|
||||||
assert "Explicit correction signals were detected" not in prompt
|
assert "Explicit correction signals were detected" not in prompt
|
||||||
|
|
||||||
def test_sync_update_memory_wrapper_works_in_running_loop(self):
|
def test_sync_update_memory_wrapper_works_in_running_loop(self):
|
||||||
@@ -675,9 +675,9 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = asyncio.run(run_in_loop())
|
result = asyncio.run(run_in_loop())
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
model.ainvoke.assert_awaited_once()
|
model.invoke.assert_called_once()
|
||||||
|
|
||||||
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
|
def test_sync_update_memory_returns_false_when_executor_down(self):
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
@@ -702,33 +702,67 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
class TestRunAsyncUpdateSync:
|
class TestSyncUpdateIsolatesProviderClientPool:
|
||||||
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
|
"""Regression tests for issue #2615.
|
||||||
class CloseableAwaitable:
|
|
||||||
def __init__(self):
|
|
||||||
self.closed = False
|
|
||||||
|
|
||||||
def __await__(self):
|
The sync ``update_memory`` path must use ``model.invoke()`` (sync HTTP)
|
||||||
pytest.fail("awaitable should not have been awaited")
|
and never touch the async provider client pool shared with the lead agent.
|
||||||
yield
|
"""
|
||||||
|
|
||||||
def close(self):
|
def test_sync_update_uses_invoke_not_ainvoke(self):
|
||||||
self.closed = True
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
model = MagicMock()
|
||||||
|
response = MagicMock()
|
||||||
|
response.content = valid_json
|
||||||
|
model.invoke = MagicMock(return_value=response)
|
||||||
|
model.ainvoke = AsyncMock(return_value=response)
|
||||||
|
|
||||||
awaitable = CloseableAwaitable()
|
with (
|
||||||
|
patch.object(updater, "_get_model", return_value=model),
|
||||||
with patch(
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||||
side_effect=RuntimeError("executor down"),
|
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||||
):
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
result = updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
async def run_in_loop():
|
assert result is True
|
||||||
return _run_async_update_sync(awaitable)
|
model.invoke.assert_called_once()
|
||||||
|
model.ainvoke.assert_not_called()
|
||||||
|
|
||||||
result = asyncio.run(run_in_loop())
|
def test_no_event_loop_created_during_sync_update(self):
|
||||||
|
"""Sync update must not create or destroy any event loop."""
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
model = MagicMock()
|
||||||
|
response = MagicMock()
|
||||||
|
response.content = valid_json
|
||||||
|
model.invoke = MagicMock(return_value=response)
|
||||||
|
|
||||||
assert result is False
|
with (
|
||||||
assert awaitable.closed is True
|
patch.object(updater, "_get_model", return_value=model),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||||
|
patch("asyncio.run", side_effect=AssertionError("asyncio.run must not be called from sync update path")),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
result = updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
class TestFactDeduplicationCaseInsensitive:
|
class TestFactDeduplicationCaseInsensitive:
|
||||||
@@ -805,6 +839,7 @@ class TestReinforcementHint:
|
|||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.content = f"```json\n{json_response}\n```"
|
response.content = f"```json\n{json_response}\n```"
|
||||||
model.ainvoke = AsyncMock(return_value=response)
|
model.ainvoke = AsyncMock(return_value=response)
|
||||||
|
model.invoke = MagicMock(return_value=response)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def test_reinforcement_hint_injected_when_detected(self):
|
def test_reinforcement_hint_injected_when_detected(self):
|
||||||
@@ -829,7 +864,7 @@ class TestReinforcementHint:
|
|||||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.ainvoke.await_args.args[0]
|
prompt = model.invoke.call_args.args[0]
|
||||||
assert "Positive reinforcement signals were detected" in prompt
|
assert "Positive reinforcement signals were detected" in prompt
|
||||||
|
|
||||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||||
@@ -854,7 +889,7 @@ class TestReinforcementHint:
|
|||||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.ainvoke.await_args.args[0]
|
prompt = model.invoke.call_args.args[0]
|
||||||
assert "Positive reinforcement signals were detected" not in prompt
|
assert "Positive reinforcement signals were detected" not in prompt
|
||||||
|
|
||||||
def test_both_hints_present_when_both_detected(self):
|
def test_both_hints_present_when_both_detected(self):
|
||||||
@@ -879,7 +914,7 @@ class TestReinforcementHint:
|
|||||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.ainvoke.await_args.args[0]
|
prompt = model.invoke.call_args.args[0]
|
||||||
assert "Explicit correction signals were detected" in prompt
|
assert "Explicit correction signals were detected" in prompt
|
||||||
assert "Positive reinforcement signals were detected" in prompt
|
assert "Positive reinforcement signals were detected" in prompt
|
||||||
|
|
||||||
@@ -908,11 +943,11 @@ class TestFinalizeCacheIsolation:
|
|||||||
)
|
)
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.content = new_fact_json
|
mock_response.content = new_fact_json
|
||||||
mock_model = AsyncMock()
|
mock_model = MagicMock()
|
||||||
mock_model.ainvoke = AsyncMock(return_value=mock_response)
|
mock_model.invoke = MagicMock(return_value=mock_response)
|
||||||
|
|
||||||
saved_objects: list[dict] = []
|
saved_objects: list[dict] = []
|
||||||
save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False) # always fails
|
save_mock = MagicMock(side_effect=lambda m, a=None, **_: saved_objects.append(m) or False) # always fails
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(updater, "_get_model", return_value=mock_model),
|
patch.object(updater, "_get_model", return_value=mock_model),
|
||||||
@@ -929,6 +964,85 @@ class TestFinalizeCacheIsolation:
|
|||||||
ai_msg.tool_calls = []
|
ai_msg.tool_calls = []
|
||||||
updater.update_memory([msg, ai_msg], thread_id="t1")
|
updater.update_memory([msg, ai_msg], thread_id="t1")
|
||||||
|
|
||||||
|
# save_mock must have been exercised — otherwise the deepcopy-on-save-failure path isn't covered
|
||||||
|
save_mock.assert_called_once()
|
||||||
|
assert len(saved_objects) == 1, "save must have been called with the updated memory object"
|
||||||
|
|
||||||
# original_memory must not have been mutated — deepcopy isolates the mutation
|
# original_memory must not have been mutated — deepcopy isolates the mutation
|
||||||
assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
|
assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
|
||||||
assert original_memory["facts"][0]["content"] == "original"
|
assert original_memory["facts"][0]["content"] == "original"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserIdForwarding:
|
||||||
|
"""Regression: user_id must flow through the entire sync update path.
|
||||||
|
|
||||||
|
When MemoryUpdateQueue captures context.user_id and passes it into
|
||||||
|
update_memory(..., user_id=context.user_id), the sync path must forward
|
||||||
|
it into _prepare_update_prompt → get_memory_data() and
|
||||||
|
_finalize_update → save(), so per-user memory isolation is maintained.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_mock_model(content):
|
||||||
|
model = MagicMock()
|
||||||
|
response = MagicMock()
|
||||||
|
response.content = content
|
||||||
|
model.invoke = MagicMock(return_value=response)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def test_sync_update_forwards_user_id_to_load_and_save(self):
|
||||||
|
"""update_memory must pass user_id to get_memory_data and storage.save."""
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
model = self._make_mock_model(valid_json)
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.save = MagicMock(return_value=True)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(updater, "_get_model", return_value=model),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load,
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
result = updater.update_memory([msg, ai_msg], user_id="user-42")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_load.assert_called_once_with(None, user_id="user-42")
|
||||||
|
mock_storage.save.assert_called_once()
|
||||||
|
save_call = mock_storage.save.call_args
|
||||||
|
assert save_call.kwargs.get("user_id") == "user-42" or (len(save_call.args) > 2 and save_call.args[2] == "user-42")
|
||||||
|
|
||||||
|
def test_async_update_forwards_user_id_to_load_and_save(self):
|
||||||
|
"""aupdate_memory must pass user_id through to the sync delegate."""
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
model = self._make_mock_model(valid_json)
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.save = MagicMock(return_value=True)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(updater, "_get_model", return_value=model),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load,
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
result = asyncio.run(updater.aupdate_memory([msg, ai_msg], user_id="user-99"))
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_load.assert_called_once_with(None, user_id="user-99")
|
||||||
|
save_call = mock_storage.save.call_args
|
||||||
|
assert save_call.kwargs.get("user_id") == "user-99" or (len(save_call.args) > 2 and save_call.args[2] == "user-99")
|
||||||
|
|||||||
Reference in New Issue
Block a user