feat: switch memory updater to async LLM calls (#2138)

* docs: mark memory updater async migration as completed

- Update TODO.md to mark the replacement of sync model.invoke()
  with async model.ainvoke() in title_middleware and memory updater
  as completed using [x] format

Addresses #2131

* feat: switch memory updater to async LLM calls

- Add async aupdate_memory() method using await model.ainvoke()
- Convert sync update_memory() to use async wrapper
- Add _run_async_update_sync() for nested loop context handling
- Maintain backward compatibility with existing sync API
- Add ThreadPoolExecutor for async execution from sync contexts

Addresses #2131

* test: add tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns

Addresses #2131

* fix: apply ruff formatting to memory updater

- Format multi-line expressions to single line
- Ensure code style consistency with project standards
- Fix lint issues caught by GitHub Actions

* test: add comprehensive tests for async memory updater

- Add test_async_update_memory_uses_ainvoke() to verify async path
- Convert existing tests to use AsyncMock and ainvoke assertions
- Add test_sync_update_memory_wrapper_works_in_running_loop()
- Update all model mocks to use async await patterns
- Ensure backward compatibility with sync API

* fix: satisfy ruff formatting in memory updater test

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
luo jiyin
2026-04-14 11:10:42 +08:00
committed by GitHub
parent 55bc09ac33
commit 07fc25d285
3 changed files with 278 additions and 82 deletions
+118 -9
View File
@@ -1,9 +1,13 @@
from unittest.mock import MagicMock, patch
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from deerflow.agents.memory.prompt import format_conversation_for_update
from deerflow.agents.memory.updater import (
MemoryUpdater,
_extract_text,
_run_async_update_sync,
clear_memory_data,
create_memory_fact,
delete_memory_fact,
@@ -523,15 +527,16 @@ class TestUpdateMemoryStructuredResponse:
model = MagicMock()
response = MagicMock()
response.content = content
model.invoke.return_value = response
model.ainvoke = AsyncMock(return_value=response)
return model
def test_string_response_parses(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
@@ -546,6 +551,7 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg])
assert result is True
model.ainvoke.assert_awaited_once()
def test_list_content_response_parses(self):
"""LLM response as list-of-blocks should be extracted, not repr'd."""
@@ -570,6 +576,29 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
def test_async_update_memory_uses_ainvoke(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi there"
ai_msg.tool_calls = []
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
assert result is True
model.ainvoke.assert_awaited_once()
def test_correction_hint_injected_when_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
@@ -592,7 +621,7 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" in prompt
def test_correction_hint_empty_when_not_detected(self):
@@ -617,9 +646,89 @@ class TestUpdateMemoryStructuredResponse:
result = updater.update_memory([msg, ai_msg], correction_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" not in prompt
def test_sync_update_memory_wrapper_works_in_running_loop(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is True
model.ainvoke.assert_awaited_once()
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
updater = MemoryUpdater()
with (
patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Hello from loop"
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Hi"
ai_msg.tool_calls = []
async def run_in_loop():
return updater.update_memory([msg, ai_msg])
result = asyncio.run(run_in_loop())
assert result is False
class TestRunAsyncUpdateSync:
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
class CloseableAwaitable:
def __init__(self):
self.closed = False
def __await__(self):
pytest.fail("awaitable should not have been awaited")
yield
def close(self):
self.closed = True
awaitable = CloseableAwaitable()
with patch(
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
side_effect=RuntimeError("executor down"),
):
async def run_in_loop():
return _run_async_update_sync(awaitable)
result = asyncio.run(run_in_loop())
assert result is False
assert awaitable.closed is True
class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
@@ -694,7 +803,7 @@ class TestReinforcementHint:
model = MagicMock()
response = MagicMock()
response.content = f"```json\n{json_response}\n```"
model.invoke.return_value = response
model.ainvoke = AsyncMock(return_value=response)
return model
def test_reinforcement_hint_injected_when_detected(self):
@@ -719,7 +828,7 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
@@ -744,7 +853,7 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
@@ -769,6 +878,6 @@ class TestReinforcementHint:
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
prompt = model.ainvoke.await_args.args[0]
assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt