mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
feat: switch memory updater to async LLM calls (#2138)
* docs: mark memory updater async migration as completed - Update TODO.md to mark the replacement of sync model.invoke() with async model.ainvoke() in title_middleware and memory updater as completed using [x] format Addresses #2131 * feat: switch memory updater to async LLM calls - Add async aupdate_memory() method using await model.ainvoke() - Convert sync update_memory() to use async wrapper - Add _run_async_update_sync() for nested loop context handling - Maintain backward compatibility with existing sync API - Add ThreadPoolExecutor for async execution from sync contexts Addresses #2131 * test: add tests for async memory updater - Add test_async_update_memory_uses_ainvoke() to verify async path - Convert existing tests to use AsyncMock and ainvoke assertions - Add test_sync_update_memory_wrapper_works_in_running_loop() - Update all model mocks to use async await patterns Addresses #2131 * fix: apply ruff formatting to memory updater - Format multi-line expressions to single line - Ensure code style consistency with project standards - Fix lint issues caught by GitHub Actions * test: add comprehensive tests for async memory updater - Add test_async_update_memory_uses_ainvoke() to verify async path - Convert existing tests to use AsyncMock and ainvoke assertions - Add test_sync_update_memory_wrapper_works_in_running_loop() - Update all model mocks to use async await patterns - Ensure backward compatibility with sync API * fix: satisfy ruff formatting in memory updater test --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -24,7 +24,7 @@
|
|||||||
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
|
- [ ] Optimize async concurrency in agent hot path (IM channels multi-task scenario)
|
||||||
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
|
- [ ] Replace `subprocess.run()` with `asyncio.create_subprocess_shell()` in `packages/harness/deerflow/sandbox/local/local_sandbox.py`
|
||||||
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
|
- Replace sync `requests` with `httpx.AsyncClient` in community tools (tavily, jina_ai, firecrawl, infoquest, image_search)
|
||||||
- Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
- [x] Replace sync `model.invoke()` with async `model.ainvoke()` in title_middleware and memory updater
|
||||||
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
|
- Consider `asyncio.to_thread()` wrapper for remaining blocking file I/O
|
||||||
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
|
- For production: use `langgraph up` (multi-worker) instead of `langgraph dev` (single-worker)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
"""Memory updater for reading, writing, and updating memory data."""
|
"""Memory updater for reading, writing, and updating memory data."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
|
import concurrent.futures
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Awaitable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from deerflow.agents.memory.prompt import (
|
from deerflow.agents.memory.prompt import (
|
||||||
@@ -21,6 +25,12 @@ from deerflow.models import create_chat_model
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
|
||||||
|
max_workers=4,
|
||||||
|
thread_name_prefix="memory-updater-sync",
|
||||||
|
)
|
||||||
|
atexit.register(lambda: _SYNC_MEMORY_UPDATER_EXECUTOR.shutdown(wait=False))
|
||||||
|
|
||||||
|
|
||||||
def _create_empty_memory() -> dict[str, Any]:
|
def _create_empty_memory() -> dict[str, Any]:
|
||||||
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
|
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
|
||||||
@@ -206,6 +216,39 @@ def _extract_text(content: Any) -> str:
|
|||||||
return str(content)
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
|
||||||
|
"""Run an async memory update from sync code, including nested-loop contexts."""
|
||||||
|
handed_off = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
|
||||||
|
if loop is not None and loop.is_running():
|
||||||
|
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
|
||||||
|
handed_off = True
|
||||||
|
return future.result()
|
||||||
|
|
||||||
|
handed_off = True
|
||||||
|
return asyncio.run(coro)
|
||||||
|
except Exception:
|
||||||
|
if not handed_off:
|
||||||
|
close = getattr(coro, "close", None)
|
||||||
|
if callable(close):
|
||||||
|
try:
|
||||||
|
close()
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to close un-awaited memory update coroutine",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.exception("Failed to run async memory update from sync context")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# Matches sentences that describe a file-upload *event* rather than general
|
# Matches sentences that describe a file-upload *event* rather than general
|
||||||
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
||||||
# such as "User works with CSV files" or "prefers PDF export".
|
# such as "User works with CSV files" or "prefers PDF export".
|
||||||
@@ -269,6 +312,113 @@ class MemoryUpdater:
|
|||||||
model_name = self._model_name or config.model_name
|
model_name = self._model_name or config.model_name
|
||||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||||
|
|
||||||
|
def _build_correction_hint(
|
||||||
|
self,
|
||||||
|
correction_detected: bool,
|
||||||
|
reinforcement_detected: bool,
|
||||||
|
) -> str:
|
||||||
|
"""Build optional prompt hints for correction and reinforcement signals."""
|
||||||
|
correction_hint = ""
|
||||||
|
if correction_detected:
|
||||||
|
correction_hint = (
|
||||||
|
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||||
|
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||||
|
"and record the correct approach as a fact with category "
|
||||||
|
'"correction" and confidence >= 0.95 when appropriate.'
|
||||||
|
)
|
||||||
|
if reinforcement_detected:
|
||||||
|
reinforcement_hint = (
|
||||||
|
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||||
|
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||||
|
"Record the confirmed approach, style, or preference as a fact with category "
|
||||||
|
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||||
|
)
|
||||||
|
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||||
|
|
||||||
|
return correction_hint
|
||||||
|
|
||||||
|
def _prepare_update_prompt(
|
||||||
|
self,
|
||||||
|
messages: list[Any],
|
||||||
|
agent_name: str | None,
|
||||||
|
correction_detected: bool,
|
||||||
|
reinforcement_detected: bool,
|
||||||
|
) -> tuple[dict[str, Any], str] | None:
|
||||||
|
"""Load memory and build the update prompt for a conversation."""
|
||||||
|
config = get_memory_config()
|
||||||
|
if not config.enabled or not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
current_memory = get_memory_data(agent_name)
|
||||||
|
conversation_text = format_conversation_for_update(messages)
|
||||||
|
if not conversation_text.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
correction_hint = self._build_correction_hint(
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
)
|
||||||
|
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||||
|
current_memory=json.dumps(current_memory, indent=2),
|
||||||
|
conversation=conversation_text,
|
||||||
|
correction_hint=correction_hint,
|
||||||
|
)
|
||||||
|
return current_memory, prompt
|
||||||
|
|
||||||
|
def _finalize_update(
|
||||||
|
self,
|
||||||
|
current_memory: dict[str, Any],
|
||||||
|
response_content: Any,
|
||||||
|
thread_id: str | None,
|
||||||
|
agent_name: str | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Parse the model response, apply updates, and persist memory."""
|
||||||
|
response_text = _extract_text(response_content).strip()
|
||||||
|
|
||||||
|
if response_text.startswith("```"):
|
||||||
|
lines = response_text.split("\n")
|
||||||
|
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||||
|
|
||||||
|
update_data = json.loads(response_text)
|
||||||
|
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
||||||
|
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||||
|
return get_memory_storage().save(updated_memory, agent_name)
|
||||||
|
|
||||||
|
async def aupdate_memory(
|
||||||
|
self,
|
||||||
|
messages: list[Any],
|
||||||
|
thread_id: str | None = None,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
correction_detected: bool = False,
|
||||||
|
reinforcement_detected: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Update memory asynchronously based on conversation messages."""
|
||||||
|
try:
|
||||||
|
prepared = self._prepare_update_prompt(
|
||||||
|
messages=messages,
|
||||||
|
agent_name=agent_name,
|
||||||
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
|
)
|
||||||
|
if prepared is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_memory, prompt = prepared
|
||||||
|
model = self._get_model()
|
||||||
|
response = await model.ainvoke(prompt)
|
||||||
|
return self._finalize_update(
|
||||||
|
current_memory=current_memory,
|
||||||
|
response_content=response.content,
|
||||||
|
thread_id=thread_id,
|
||||||
|
agent_name=agent_name,
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Memory update failed: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
def update_memory(
|
def update_memory(
|
||||||
self,
|
self,
|
||||||
messages: list[Any],
|
messages: list[Any],
|
||||||
@@ -277,7 +427,7 @@ class MemoryUpdater:
|
|||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Update memory based on conversation messages.
|
"""Synchronously update memory via the async updater path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of conversation messages.
|
messages: List of conversation messages.
|
||||||
@@ -289,78 +439,15 @@ class MemoryUpdater:
|
|||||||
Returns:
|
Returns:
|
||||||
True if update was successful, False otherwise.
|
True if update was successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
config = get_memory_config()
|
return _run_async_update_sync(
|
||||||
if not config.enabled:
|
self.aupdate_memory(
|
||||||
return False
|
messages=messages,
|
||||||
|
thread_id=thread_id,
|
||||||
if not messages:
|
agent_name=agent_name,
|
||||||
return False
|
correction_detected=correction_detected,
|
||||||
|
reinforcement_detected=reinforcement_detected,
|
||||||
try:
|
|
||||||
# Get current memory
|
|
||||||
current_memory = get_memory_data(agent_name)
|
|
||||||
|
|
||||||
# Format conversation for prompt
|
|
||||||
conversation_text = format_conversation_for_update(messages)
|
|
||||||
|
|
||||||
if not conversation_text.strip():
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Build prompt
|
|
||||||
correction_hint = ""
|
|
||||||
if correction_detected:
|
|
||||||
correction_hint = (
|
|
||||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
|
||||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
|
||||||
"and record the correct approach as a fact with category "
|
|
||||||
'"correction" and confidence >= 0.95 when appropriate.'
|
|
||||||
)
|
|
||||||
if reinforcement_detected:
|
|
||||||
reinforcement_hint = (
|
|
||||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
|
||||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
|
||||||
"Record the confirmed approach, style, or preference as a fact with category "
|
|
||||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
|
||||||
)
|
|
||||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
|
||||||
|
|
||||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
|
||||||
current_memory=json.dumps(current_memory, indent=2),
|
|
||||||
conversation=conversation_text,
|
|
||||||
correction_hint=correction_hint,
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
# Call LLM
|
|
||||||
model = self._get_model()
|
|
||||||
response = model.invoke(prompt)
|
|
||||||
response_text = _extract_text(response.content).strip()
|
|
||||||
|
|
||||||
# Parse response
|
|
||||||
# Remove markdown code blocks if present
|
|
||||||
if response_text.startswith("```"):
|
|
||||||
lines = response_text.split("\n")
|
|
||||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
|
||||||
|
|
||||||
update_data = json.loads(response_text)
|
|
||||||
|
|
||||||
# Apply updates
|
|
||||||
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
|
||||||
|
|
||||||
# Strip file-upload mentions from all summaries before saving.
|
|
||||||
# Uploaded files are session-scoped and won't exist in future sessions,
|
|
||||||
# so recording upload events in long-term memory causes the agent to
|
|
||||||
# try (and fail) to locate those files in subsequent conversations.
|
|
||||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
|
||||||
|
|
||||||
# Save
|
|
||||||
return get_memory_storage().save(updated_memory, agent_name)
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Memory update failed: %s", e)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _apply_updates(
|
def _apply_updates(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from deerflow.agents.memory.prompt import format_conversation_for_update
|
from deerflow.agents.memory.prompt import format_conversation_for_update
|
||||||
from deerflow.agents.memory.updater import (
|
from deerflow.agents.memory.updater import (
|
||||||
MemoryUpdater,
|
MemoryUpdater,
|
||||||
_extract_text,
|
_extract_text,
|
||||||
|
_run_async_update_sync,
|
||||||
clear_memory_data,
|
clear_memory_data,
|
||||||
create_memory_fact,
|
create_memory_fact,
|
||||||
delete_memory_fact,
|
delete_memory_fact,
|
||||||
@@ -523,15 +527,16 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
model = MagicMock()
|
model = MagicMock()
|
||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.content = content
|
response.content = content
|
||||||
model.invoke.return_value = response
|
model.ainvoke = AsyncMock(return_value=response)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def test_string_response_parses(self):
|
def test_string_response_parses(self):
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
model = self._make_mock_model(valid_json)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)),
|
patch.object(updater, "_get_model", return_value=model),
|
||||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||||
@@ -546,6 +551,7 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = updater.update_memory([msg, ai_msg])
|
result = updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
model.ainvoke.assert_awaited_once()
|
||||||
|
|
||||||
def test_list_content_response_parses(self):
|
def test_list_content_response_parses(self):
|
||||||
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
"""LLM response as list-of-blocks should be extracted, not repr'd."""
|
||||||
@@ -570,6 +576,29 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
def test_async_update_memory_uses_ainvoke(self):
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
model = self._make_mock_model(valid_json)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(updater, "_get_model", return_value=model),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi there"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
result = asyncio.run(updater.aupdate_memory([msg, ai_msg]))
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
model.ainvoke.assert_awaited_once()
|
||||||
|
|
||||||
def test_correction_hint_injected_when_detected(self):
|
def test_correction_hint_injected_when_detected(self):
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
@@ -592,7 +621,7 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
result = updater.update_memory([msg, ai_msg], correction_detected=True)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.invoke.call_args[0][0]
|
prompt = model.ainvoke.await_args.args[0]
|
||||||
assert "Explicit correction signals were detected" in prompt
|
assert "Explicit correction signals were detected" in prompt
|
||||||
|
|
||||||
def test_correction_hint_empty_when_not_detected(self):
|
def test_correction_hint_empty_when_not_detected(self):
|
||||||
@@ -617,9 +646,89 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
result = updater.update_memory([msg, ai_msg], correction_detected=False)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.invoke.call_args[0][0]
|
prompt = model.ainvoke.await_args.args[0]
|
||||||
assert "Explicit correction signals were detected" not in prompt
|
assert "Explicit correction signals were detected" not in prompt
|
||||||
|
|
||||||
|
def test_sync_update_memory_wrapper_works_in_running_loop(self):
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
|
model = self._make_mock_model(valid_json)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(updater, "_get_model", return_value=model),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||||
|
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello from loop"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
|
||||||
|
async def run_in_loop():
|
||||||
|
return updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
|
result = asyncio.run(run_in_loop())
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
model.ainvoke.assert_awaited_once()
|
||||||
|
|
||||||
|
def test_sync_update_memory_returns_false_when_bridge_submit_fails(self):
|
||||||
|
updater = MemoryUpdater()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
|
||||||
|
side_effect=RuntimeError("executor down"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.type = "human"
|
||||||
|
msg.content = "Hello from loop"
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.type = "ai"
|
||||||
|
ai_msg.content = "Hi"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
|
||||||
|
async def run_in_loop():
|
||||||
|
return updater.update_memory([msg, ai_msg])
|
||||||
|
|
||||||
|
result = asyncio.run(run_in_loop())
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunAsyncUpdateSync:
|
||||||
|
def test_closes_unawaited_awaitable_when_bridge_fails_before_handoff(self):
|
||||||
|
class CloseableAwaitable:
|
||||||
|
def __init__(self):
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def __await__(self):
|
||||||
|
pytest.fail("awaitable should not have been awaited")
|
||||||
|
yield
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
awaitable = CloseableAwaitable()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"deerflow.agents.memory.updater._SYNC_MEMORY_UPDATER_EXECUTOR.submit",
|
||||||
|
side_effect=RuntimeError("executor down"),
|
||||||
|
):
|
||||||
|
|
||||||
|
async def run_in_loop():
|
||||||
|
return _run_async_update_sync(awaitable)
|
||||||
|
|
||||||
|
result = asyncio.run(run_in_loop())
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert awaitable.closed is True
|
||||||
|
|
||||||
|
|
||||||
class TestFactDeduplicationCaseInsensitive:
|
class TestFactDeduplicationCaseInsensitive:
|
||||||
"""Tests that fact deduplication is case-insensitive."""
|
"""Tests that fact deduplication is case-insensitive."""
|
||||||
@@ -694,7 +803,7 @@ class TestReinforcementHint:
|
|||||||
model = MagicMock()
|
model = MagicMock()
|
||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.content = f"```json\n{json_response}\n```"
|
response.content = f"```json\n{json_response}\n```"
|
||||||
model.invoke.return_value = response
|
model.ainvoke = AsyncMock(return_value=response)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def test_reinforcement_hint_injected_when_detected(self):
|
def test_reinforcement_hint_injected_when_detected(self):
|
||||||
@@ -719,7 +828,7 @@ class TestReinforcementHint:
|
|||||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.invoke.call_args[0][0]
|
prompt = model.ainvoke.await_args.args[0]
|
||||||
assert "Positive reinforcement signals were detected" in prompt
|
assert "Positive reinforcement signals were detected" in prompt
|
||||||
|
|
||||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||||
@@ -744,7 +853,7 @@ class TestReinforcementHint:
|
|||||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.invoke.call_args[0][0]
|
prompt = model.ainvoke.await_args.args[0]
|
||||||
assert "Positive reinforcement signals were detected" not in prompt
|
assert "Positive reinforcement signals were detected" not in prompt
|
||||||
|
|
||||||
def test_both_hints_present_when_both_detected(self):
|
def test_both_hints_present_when_both_detected(self):
|
||||||
@@ -769,6 +878,6 @@ class TestReinforcementHint:
|
|||||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
prompt = model.invoke.call_args[0][0]
|
prompt = model.ainvoke.await_args.args[0]
|
||||||
assert "Explicit correction signals were detected" in prompt
|
assert "Explicit correction signals were detected" in prompt
|
||||||
assert "Positive reinforcement signals were detected" in prompt
|
assert "Positive reinforcement signals were detected" in prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user