feat(memory): thread user_id through memory updater layer

Add `user_id` keyword-only parameter to all public updater functions
(_save_memory_to_file, get_memory_data, reload_memory_data, import_memory_data,
clear_memory_data, create/delete/update_memory_fact) and regular keyword param
to MemoryUpdater.update_memory + update_memory_from_conversation, propagating
it to every storage load/save/reload call.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-12 13:37:08 +08:00
parent 3877aabcfd
commit dfa9fc47b3
3 changed files with 61 additions and 23 deletions
@@ -27,27 +27,28 @@ def _create_empty_memory() -> dict[str, Any]:
return create_empty_memory() return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool: def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path.""" """Backward-compatible wrapper around the configured memory storage save path."""
return get_memory_storage().save(memory_data, agent_name) return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]: def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider.""" """Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name) return get_memory_storage().load(agent_name, user_id=user_id)
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]: def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider.""" """Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name) return get_memory_storage().reload(agent_name, user_id=user_id)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]: def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider. """Persist imported memory data via storage provider.
Args: Args:
memory_data: Full memory payload to persist. memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory. agent_name: If provided, imports into per-agent memory.
user_id: If provided, scopes memory to a specific user.
Returns: Returns:
The saved memory data after storage normalization. The saved memory data after storage normalization.
@@ -56,15 +57,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
OSError: If persisting the imported memory fails. OSError: If persisting the imported memory fails.
""" """
storage = get_memory_storage() storage = get_memory_storage()
if not storage.save(memory_data, agent_name): if not storage.save(memory_data, agent_name, user_id=user_id):
raise OSError("Failed to save imported memory data") raise OSError("Failed to save imported memory data")
return storage.load(agent_name) return storage.load(agent_name, user_id=user_id)
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]: def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure.""" """Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory() cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name): if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
raise OSError("Failed to save cleared memory data") raise OSError("Failed to save cleared memory data")
return cleared_memory return cleared_memory
@@ -81,6 +82,8 @@ def create_memory_fact(
category: str = "context", category: str = "context",
confidence: float = 0.5, confidence: float = 0.5,
agent_name: str | None = None, agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create a new fact and persist the updated memory data.""" """Create a new fact and persist the updated memory data."""
normalized_content = content.strip() normalized_content = content.strip()
@@ -90,7 +93,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context" normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence) validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z() now = utc_now_iso_z()
memory_data = get_memory_data(agent_name) memory_data = get_memory_data(agent_name, user_id=user_id)
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", [])) facts = list(memory_data.get("facts", []))
facts.append( facts.append(
@@ -105,15 +108,15 @@ def create_memory_fact(
) )
updated_memory["facts"] = facts updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name): if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError("Failed to save memory data after creating fact") raise OSError("Failed to save memory data after creating fact")
return updated_memory return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]: def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data.""" """Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name) memory_data = get_memory_data(agent_name, user_id=user_id)
facts = memory_data.get("facts", []) facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id] updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts): if len(updated_facts) == len(facts):
@@ -122,7 +125,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name): if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'") raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory return updated_memory
@@ -134,9 +137,11 @@ def update_memory_fact(
category: str | None = None, category: str | None = None,
confidence: float | None = None, confidence: float | None = None,
agent_name: str | None = None, agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data.""" """Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name) memory_data = get_memory_data(agent_name, user_id=user_id)
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = [] updated_facts: list[dict[str, Any]] = []
found = False found = False
@@ -163,7 +168,7 @@ def update_memory_fact(
updated_memory["facts"] = updated_facts updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name): if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'") raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory return updated_memory
@@ -276,6 +281,7 @@ class MemoryUpdater:
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool: ) -> bool:
"""Update memory based on conversation messages. """Update memory based on conversation messages.
@@ -285,6 +291,7 @@ class MemoryUpdater:
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns: Returns:
True if update was successful, False otherwise. True if update was successful, False otherwise.
@@ -298,7 +305,7 @@ class MemoryUpdater:
try: try:
# Get current memory # Get current memory
current_memory = get_memory_data(agent_name) current_memory = get_memory_data(agent_name, user_id=user_id)
# Format conversation for prompt # Format conversation for prompt
conversation_text = format_conversation_for_update(messages) conversation_text = format_conversation_for_update(messages)
@@ -353,7 +360,7 @@ class MemoryUpdater:
updated_memory = _strip_upload_mentions_from_memory(updated_memory) updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save # Save
return get_memory_storage().save(updated_memory, agent_name) return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e) logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -455,6 +462,7 @@ def update_memory_from_conversation(
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool: ) -> bool:
"""Convenience function to update memory from a conversation. """Convenience function to update memory from a conversation.
@@ -464,9 +472,10 @@ def update_memory_from_conversation(
agent_name: If provided, updates per-agent memory. If None, updates global memory. agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal. correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal. reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns: Returns:
True if successful, False otherwise. True if successful, False otherwise.
""" """
updater = MemoryUpdater() updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected) return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
+2 -2
View File
@@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage): with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
result = import_memory_data(imported_memory) result = import_memory_data(imported_memory)
mock_storage.save.assert_called_once_with(imported_memory, None) mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
mock_storage.load.assert_called_once_with(None) mock_storage.load.assert_called_once_with(None, user_id=None)
assert result == imported_memory assert result == imported_memory
@@ -0,0 +1,29 @@
"""Tests for user_id propagation in memory updater."""
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
def test_get_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.load.return_value = {"version": "1.0"}
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
get_memory_data(user_id="alice")
mock_storage.load.assert_called_once_with(None, user_id="alice")
def test_save_memory_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
_save_memory_to_file({"version": "1.0"}, user_id="bob")
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
def test_clear_memory_data_passes_user_id():
mock_storage = MagicMock()
mock_storage.save.return_value = True
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
clear_memory_data(user_id="charlie")
# Verify save was called with user_id
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"