feat(memory): thread user_id through memory updater layer
Add `user_id` keyword-only parameter to all public updater functions (_save_memory_to_file, get_memory_data, reload_memory_data, import_memory_data, clear_memory_data, create/delete/update_memory_fact) and regular keyword param to MemoryUpdater.update_memory + update_memory_from_conversation, propagating it to every storage load/save/reload call. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,27 +27,28 @@ def _create_empty_memory() -> dict[str, Any]:
|
|||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
|
|
||||||
|
|
||||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||||
return get_memory_storage().save(memory_data, agent_name)
|
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Get the current memory data via storage provider."""
|
"""Get the current memory data via storage provider."""
|
||||||
return get_memory_storage().load(agent_name)
|
return get_memory_storage().load(agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Reload memory data via storage provider."""
|
"""Reload memory data via storage provider."""
|
||||||
return get_memory_storage().reload(agent_name)
|
return get_memory_storage().reload(agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Persist imported memory data via storage provider.
|
"""Persist imported memory data via storage provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_data: Full memory payload to persist.
|
memory_data: Full memory payload to persist.
|
||||||
agent_name: If provided, imports into per-agent memory.
|
agent_name: If provided, imports into per-agent memory.
|
||||||
|
user_id: If provided, scopes memory to a specific user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The saved memory data after storage normalization.
|
The saved memory data after storage normalization.
|
||||||
@@ -56,15 +57,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
|
|||||||
OSError: If persisting the imported memory fails.
|
OSError: If persisting the imported memory fails.
|
||||||
"""
|
"""
|
||||||
storage = get_memory_storage()
|
storage = get_memory_storage()
|
||||||
if not storage.save(memory_data, agent_name):
|
if not storage.save(memory_data, agent_name, user_id=user_id):
|
||||||
raise OSError("Failed to save imported memory data")
|
raise OSError("Failed to save imported memory data")
|
||||||
return storage.load(agent_name)
|
return storage.load(agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Clear all stored memory data and persist an empty structure."""
|
"""Clear all stored memory data and persist an empty structure."""
|
||||||
cleared_memory = create_empty_memory()
|
cleared_memory = create_empty_memory()
|
||||||
if not _save_memory_to_file(cleared_memory, agent_name):
|
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
|
||||||
raise OSError("Failed to save cleared memory data")
|
raise OSError("Failed to save cleared memory data")
|
||||||
return cleared_memory
|
return cleared_memory
|
||||||
|
|
||||||
@@ -81,6 +82,8 @@ def create_memory_fact(
|
|||||||
category: str = "context",
|
category: str = "context",
|
||||||
confidence: float = 0.5,
|
confidence: float = 0.5,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create a new fact and persist the updated memory data."""
|
"""Create a new fact and persist the updated memory data."""
|
||||||
normalized_content = content.strip()
|
normalized_content = content.strip()
|
||||||
@@ -90,7 +93,7 @@ def create_memory_fact(
|
|||||||
normalized_category = category.strip() or "context"
|
normalized_category = category.strip() or "context"
|
||||||
validated_confidence = _validate_confidence(confidence)
|
validated_confidence = _validate_confidence(confidence)
|
||||||
now = utc_now_iso_z()
|
now = utc_now_iso_z()
|
||||||
memory_data = get_memory_data(agent_name)
|
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
facts = list(memory_data.get("facts", []))
|
facts = list(memory_data.get("facts", []))
|
||||||
facts.append(
|
facts.append(
|
||||||
@@ -105,15 +108,15 @@ def create_memory_fact(
|
|||||||
)
|
)
|
||||||
updated_memory["facts"] = facts
|
updated_memory["facts"] = facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name):
|
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||||
raise OSError("Failed to save memory data after creating fact")
|
raise OSError("Failed to save memory data after creating fact")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
|
|
||||||
|
|
||||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
|
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Delete a fact by its id and persist the updated memory data."""
|
"""Delete a fact by its id and persist the updated memory data."""
|
||||||
memory_data = get_memory_data(agent_name)
|
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||||
facts = memory_data.get("facts", [])
|
facts = memory_data.get("facts", [])
|
||||||
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
||||||
if len(updated_facts) == len(facts):
|
if len(updated_facts) == len(facts):
|
||||||
@@ -122,7 +125,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
|
|||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
updated_memory["facts"] = updated_facts
|
updated_memory["facts"] = updated_facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name):
|
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||||
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
@@ -134,9 +137,11 @@ def update_memory_fact(
|
|||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
confidence: float | None = None,
|
confidence: float | None = None,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Update an existing fact and persist the updated memory data."""
|
"""Update an existing fact and persist the updated memory data."""
|
||||||
memory_data = get_memory_data(agent_name)
|
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
updated_facts: list[dict[str, Any]] = []
|
updated_facts: list[dict[str, Any]] = []
|
||||||
found = False
|
found = False
|
||||||
@@ -163,7 +168,7 @@ def update_memory_fact(
|
|||||||
|
|
||||||
updated_memory["facts"] = updated_facts
|
updated_memory["facts"] = updated_facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name):
|
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
@@ -276,6 +281,7 @@ class MemoryUpdater:
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Update memory based on conversation messages.
|
"""Update memory based on conversation messages.
|
||||||
|
|
||||||
@@ -285,6 +291,7 @@ class MemoryUpdater:
|
|||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
|
user_id: If provided, scopes memory to a specific user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if update was successful, False otherwise.
|
True if update was successful, False otherwise.
|
||||||
@@ -298,7 +305,7 @@ class MemoryUpdater:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get current memory
|
# Get current memory
|
||||||
current_memory = get_memory_data(agent_name)
|
current_memory = get_memory_data(agent_name, user_id=user_id)
|
||||||
|
|
||||||
# Format conversation for prompt
|
# Format conversation for prompt
|
||||||
conversation_text = format_conversation_for_update(messages)
|
conversation_text = format_conversation_for_update(messages)
|
||||||
@@ -353,7 +360,7 @@ class MemoryUpdater:
|
|||||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||||
|
|
||||||
# Save
|
# Save
|
||||||
return get_memory_storage().save(updated_memory, agent_name)
|
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||||
@@ -455,6 +462,7 @@ def update_memory_from_conversation(
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Convenience function to update memory from a conversation.
|
"""Convenience function to update memory from a conversation.
|
||||||
|
|
||||||
@@ -464,9 +472,10 @@ def update_memory_from_conversation(
|
|||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
|
user_id: If provided, scopes memory to a specific user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if successful, False otherwise.
|
True if successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
|
||||||
|
|||||||
@@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
|||||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
result = import_memory_data(imported_memory)
|
result = import_memory_data(imported_memory)
|
||||||
|
|
||||||
mock_storage.save.assert_called_once_with(imported_memory, None)
|
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
|
||||||
mock_storage.load.assert_called_once_with(None)
|
mock_storage.load.assert_called_once_with(None, user_id=None)
|
||||||
assert result == imported_memory
|
assert result == imported_memory
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Tests for user_id propagation in memory updater."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_memory_data_passes_user_id():
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.load.return_value = {"version": "1.0"}
|
||||||
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
|
get_memory_data(user_id="alice")
|
||||||
|
mock_storage.load.assert_called_once_with(None, user_id="alice")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_memory_passes_user_id():
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.save.return_value = True
|
||||||
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
|
_save_memory_to_file({"version": "1.0"}, user_id="bob")
|
||||||
|
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_memory_data_passes_user_id():
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.save.return_value = True
|
||||||
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
|
clear_memory_data(user_id="charlie")
|
||||||
|
# Verify save was called with user_id
|
||||||
|
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
|
||||||
Reference in New Issue
Block a user