mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(memory): replace short-lived asyncio.run() with persistent event loop (#2627)
* fix(memory): replace short-lived asyncio.run() with persistent event loop to prevent zombie httpx connections The memory updater used asyncio.run() inside daemon threads, creating and destroying short-lived event loops on every update. Langchain providers (e.g. langchain-anthropic) cache httpx AsyncClient instances globally via @lru_cache, so SSL connections created on a loop that is subsequently destroyed become zombie connections in the shared pool. When the main agent's lead run later reuses one of these connections, httpx/anyio triggers RuntimeError: Event loop is closed during connection cleanup. Replace the ThreadPoolExecutor + asyncio.run() pattern with a _MemoryLoopRunner that maintains a single persistent event loop in a daemon thread for the process lifetime. Since the loop never closes, connections bound to it never become invalid. The _run_async_update_sync function now submits coroutines to this persistent loop via run_coroutine_threadsafe instead of creating throwaway loops. * update the code to address the review comments * Fix the review comments of 2615 P1 — user_id forwarded through sync path: Added user_id parameter to _prepare_update_prompt, _finalize_update, and _do_update_memory_sync, and forwarded it to get_memory_data(agent_name, user_id=user_id) and save(..., user_id=user_id). The update_memory() entry point now passes user_id through both the executor.submit path and the direct call path. Added TestUserIdForwarding with two regression tests (sync + async) verifying get_memory_data and save receive the correct user_id. P2 — aupdate_memory() delegates to sync: Replaced the model.ainvoke() call with asyncio.to_thread(self._do_update_memory_sync, ...). This eliminates the unsafe async provider client path entirely — all memory updater entry points now use the isolated sync model.invoke() path. Updated the test from asserting ainvoke is awaited to asserting invoke is called and ainvoke is not. Nit — duplicate comment removed: Removed the duplicated # Matches sentences... comment on line 230. * Chore(test): update the code of test_memory_updater --------- Co-authored-by: rayhpeng <rayhpeng@gmail.com>
This commit is contained in:
@@ -1,13 +1,10 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.memory.prompt import format_conversation_for_update
|
||||
from deerflow.agents.memory.updater import (
|
||||
MemoryUpdater,
|
||||
_extract_text,
|
||||
_run_async_update_sync,
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
delete_memory_fact,
|
||||
@@ -528,6 +525,7 @@ class TestUpdateMemoryStructuredResponse:
|
||||
response = MagicMock()
|
||||
response.content = content
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
model.invoke = MagicMock(return_value=response)
|
||||
return model
|
||||
|
||||
def test_string_response_parses(self):
|
||||
@@ -551,7 +549,7 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg])
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
model.invoke.assert_called_once()
|
||||
|
||||
def test_list_content_response_parses(self):
|
||||
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
||||
@@ -576,7 +574,8 @@ class TestUpdateMemoryStructuredResponse:
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_async_update_memory_uses_ainvoke(self):
|
||||
def test_async_update_memory_delegates_to_sync(self):
|
||||
"""aupdate_memory should delegate to sync _do_update_memory_sync via to_thread."""
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
@@ -597,8 +596,9 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
assert model.ainvoke.await_args.kwargs["config"] == {"run_name": "memory_agent"}
|
||||
# aupdate_memory delegates to sync path — model.invoke, not ainvoke
|
||||
model.invoke.assert_called_once()
|
||||
model.ainvoke.assert_not_called()
|
||||
|
||||
def test_correction_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
@@ -622,7 +622,7 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args.args[0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
|
||||
def test_correction_hint_empty_when_not_detected(self):
|
||||
@@ -647,7 +647,7 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args.args[0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
def test_sync_update_memory_wrapper_works_in_running_loop(self):
|
||||
@@ -675,9 +675,9 @@ class TestUpdateMemoryStructuredResponse:
|
||||
result = asyncio.run(run_in_loop())
|
||||
|
||||
assert result is True
|
||||
model.ainvoke.assert_awaited_once()
|
||||
model.invoke.assert_called_once()
|
||||
|
||||
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
|
||||
def test_sync_update_memory_returns_false_when_executor_down(self):
|
||||
updater = MemoryUpdater()
|
||||
|
||||
with (
|
||||
@@ -702,33 +702,67 @@ class TestUpdateMemoryStructuredResponse:
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestRunAsyncUpdateSync:
|
||||
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
|
||||
class CloseableAwaitable:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
class TestSyncUpdateIsolatesProviderClientPool:
|
||||
"""Regression tests for issue #2615.
|
||||
|
||||
def __await__(self):
|
||||
pytest.fail("awaitable should not have been awaited")
|
||||
yield
|
||||
The sync ``update_memory`` path must use ``model.invoke()`` (sync HTTP)
|
||||
and never touch the async provider client pool shared with the lead agent.
|
||||
"""
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
def test_sync_update_uses_invoke_not_ainvoke(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = valid_json
|
||||
model.invoke = MagicMock(return_value=response)
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
|
||||
awaitable = CloseableAwaitable()
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
|
||||
side_effect=RuntimeError("executor down"),
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Hello"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Hi"
|
||||
ai_msg.tool_calls = []
|
||||
result = updater.update_memory([msg, ai_msg])
|
||||
|
||||
async def run_in_loop():
|
||||
return _run_async_update_sync(awaitable)
|
||||
assert result is True
|
||||
model.invoke.assert_called_once()
|
||||
model.ainvoke.assert_not_called()
|
||||
|
||||
result = asyncio.run(run_in_loop())
|
||||
def test_no_event_loop_created_during_sync_update(self):
|
||||
"""Sync update must not create or destroy any event loop."""
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = valid_json
|
||||
model.invoke = MagicMock(return_value=response)
|
||||
|
||||
assert result is False
|
||||
assert awaitable.closed is True
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
patch("asyncio.run", side_effect=AssertionError("asyncio.run must not be called from sync update path")),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Hello"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Hi"
|
||||
ai_msg.tool_calls = []
|
||||
result = updater.update_memory([msg, ai_msg])
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestFactDeduplicationCaseInsensitive:
|
||||
@@ -805,6 +839,7 @@ class TestReinforcementHint:
|
||||
response = MagicMock()
|
||||
response.content = f"```json\n{json_response}\n```"
|
||||
model.ainvoke = AsyncMock(return_value=response)
|
||||
model.invoke = MagicMock(return_value=response)
|
||||
return model
|
||||
|
||||
def test_reinforcement_hint_injected_when_detected(self):
|
||||
@@ -829,7 +864,7 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args.args[0]
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||
@@ -854,7 +889,7 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args.args[0]
|
||||
assert "Positive reinforcement signals were detected" not in prompt
|
||||
|
||||
def test_both_hints_present_when_both_detected(self):
|
||||
@@ -879,7 +914,7 @@ class TestReinforcementHint:
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.ainvoke.await_args.args[0]
|
||||
prompt = model.invoke.call_args.args[0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
@@ -908,11 +943,11 @@ class TestFinalizeCacheIsolation:
|
||||
)
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = new_fact_json
|
||||
mock_model = AsyncMock()
|
||||
mock_model.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_model = MagicMock()
|
||||
mock_model.invoke = MagicMock(return_value=mock_response)
|
||||
|
||||
saved_objects: list[dict] = []
|
||||
save_mock = MagicMock(side_effect=lambda m, a=None: saved_objects.append(m) or False) # always fails
|
||||
save_mock = MagicMock(side_effect=lambda m, a=None, **_: saved_objects.append(m) or False) # always fails
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=mock_model),
|
||||
@@ -929,6 +964,85 @@ class TestFinalizeCacheIsolation:
|
||||
ai_msg.tool_calls = []
|
||||
updater.update_memory([msg, ai_msg], thread_id="t1")
|
||||
|
||||
# save_mock must have been exercised — otherwise the deepcopy-on-save-failure path isn't covered
|
||||
save_mock.assert_called_once()
|
||||
assert len(saved_objects) == 1, "save must have been called with the updated memory object"
|
||||
|
||||
# original_memory must not have been mutated — deepcopy isolates the mutation
|
||||
assert len(original_memory["facts"]) == 1, "original_memory must not be mutated by _apply_updates"
|
||||
assert original_memory["facts"][0]["content"] == "original"
|
||||
|
||||
|
||||
class TestUserIdForwarding:
|
||||
"""Regression: user_id must flow through the entire sync update path.
|
||||
|
||||
When MemoryUpdateQueue captures context.user_id and passes it into
|
||||
update_memory(..., user_id=context.user_id), the sync path must forward
|
||||
it into _prepare_update_prompt → get_memory_data() and
|
||||
_finalize_update → save(), so per-user memory isolation is maintained.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_model(content):
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = content
|
||||
model.invoke = MagicMock(return_value=response)
|
||||
return model
|
||||
|
||||
def test_sync_update_forwards_user_id_to_load_and_save(self):
|
||||
"""update_memory must pass user_id to get_memory_data and storage.save."""
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load,
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Hello"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Hi"
|
||||
ai_msg.tool_calls = []
|
||||
result = updater.update_memory([msg, ai_msg], user_id="user-42")
|
||||
|
||||
assert result is True
|
||||
mock_load.assert_called_once_with(None, user_id="user-42")
|
||||
mock_storage.save.assert_called_once()
|
||||
save_call = mock_storage.save.call_args
|
||||
assert save_call.kwargs.get("user_id") == "user-42" or (len(save_call.args) > 2 and save_call.args[2] == "user-42")
|
||||
|
||||
def test_async_update_forwards_user_id_to_load_and_save(self):
|
||||
"""aupdate_memory must pass user_id through to the sync delegate."""
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()) as mock_load,
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Hello"
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Hi"
|
||||
ai_msg.tool_calls = []
|
||||
result = asyncio.run(updater.aupdate_memory([msg, ai_msg], user_id="user-99"))
|
||||
|
||||
assert result is True
|
||||
mock_load.assert_called_once_with(None, user_id="user-99")
|
||||
save_call = mock_storage.save.call_args
|
||||
assert save_call.kwargs.get("user_id") == "user-99" or (len(save_call.args) > 2 and save_call.args[2] == "user-99")
|
||||
|
||||
Reference in New Issue
Block a user