mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
feat: support memory import and export (#1521)
* feat: support memory import and export * fix(memory): address review feedback * style: format memory settings page --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from deerflow.agents.memory.updater import (
|
||||
create_memory_fact,
|
||||
delete_memory_fact,
|
||||
get_memory_data,
|
||||
import_memory_data,
|
||||
reload_memory_data,
|
||||
update_memory_fact,
|
||||
)
|
||||
@@ -248,6 +249,34 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/memory/export",
|
||||
response_model=MemoryResponse,
|
||||
summary="Export Memory Data",
|
||||
description="Export the current global memory data as JSON for backup or transfer.",
|
||||
)
|
||||
async def export_memory() -> MemoryResponse:
|
||||
"""Export the current memory data."""
|
||||
memory_data = get_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/memory/import",
|
||||
response_model=MemoryResponse,
|
||||
summary="Import Memory Data",
|
||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||
)
|
||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||
"""Import and persist memory data."""
|
||||
try:
|
||||
memory_data = import_memory_data(request.model_dump())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/memory/config",
|
||||
response_model=MemoryConfigResponse,
|
||||
|
||||
@@ -39,6 +39,25 @@ def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
return get_memory_storage().reload(agent_name)
|
||||
|
||||
|
||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Persist imported memory data via storage provider.
|
||||
|
||||
Args:
|
||||
memory_data: Full memory payload to persist.
|
||||
agent_name: If provided, imports into per-agent memory.
|
||||
|
||||
Returns:
|
||||
The saved memory data after storage normalization.
|
||||
|
||||
Raises:
|
||||
OSError: If persisting the imported memory fails.
|
||||
"""
|
||||
storage = get_memory_storage()
|
||||
if not storage.save(memory_data, agent_name):
|
||||
raise OSError("Failed to save imported memory data")
|
||||
return storage.load(agent_name)
|
||||
|
||||
|
||||
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Clear all stored memory data and persist an empty structure."""
|
||||
cleared_memory = create_empty_memory()
|
||||
|
||||
@@ -507,6 +507,18 @@ class DeerFlowClient:
|
||||
|
||||
return get_memory_data()
|
||||
|
||||
def export_memory(self) -> dict:
|
||||
"""Export current memory data for backup or transfer."""
|
||||
from deerflow.agents.memory.updater import get_memory_data
|
||||
|
||||
return get_memory_data()
|
||||
|
||||
def import_memory(self, memory_data: dict) -> dict:
|
||||
"""Import and persist full memory data."""
|
||||
from deerflow.agents.memory.updater import import_memory_data
|
||||
|
||||
return import_memory_data(memory_data)
|
||||
|
||||
def get_model(self, name: str) -> dict | None:
|
||||
"""Get a specific model's configuration by name.
|
||||
|
||||
|
||||
@@ -145,6 +145,13 @@ class TestConfigQueries:
|
||||
mock_mem.assert_called_once()
|
||||
assert result == memory
|
||||
|
||||
def test_export_memory(self, client):
|
||||
memory = {"version": "1.0", "facts": []}
|
||||
with patch("deerflow.agents.memory.updater.get_memory_data", return_value=memory) as mock_mem:
|
||||
result = client.export_memory()
|
||||
mock_mem.assert_called_once()
|
||||
assert result == memory
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stream / chat
|
||||
@@ -661,6 +668,14 @@ class TestSkillsManagement:
|
||||
|
||||
|
||||
class TestMemoryManagement:
|
||||
def test_import_memory(self, client):
|
||||
imported = {"version": "1.0", "facts": []}
|
||||
with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import:
|
||||
result = client.import_memory(imported)
|
||||
|
||||
mock_import.assert_called_once_with(imported)
|
||||
assert result == imported
|
||||
|
||||
def test_reload_memory(self, client):
|
||||
data = {"version": "1.0", "facts": []}
|
||||
with patch("deerflow.agents.memory.updater.reload_memory_data", return_value=data):
|
||||
|
||||
@@ -24,6 +24,54 @@ def _sample_memory(facts: list[dict] | None = None) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def test_export_memory_route_returns_current_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
exported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_export",
|
||||
"content": "User prefers concise responses.",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2026-03-20T00:00:00Z",
|
||||
"source": "thread-1",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with patch("app.gateway.routers.memory.get_memory_data", return_value=exported_memory):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/memory/export")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["facts"] == exported_memory["facts"]
|
||||
|
||||
|
||||
def test_import_memory_route_returns_imported_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
imported_memory = _sample_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_import",
|
||||
"content": "User works on DeerFlow.",
|
||||
"category": "context",
|
||||
"confidence": 0.87,
|
||||
"createdAt": "2026-03-20T00:00:00Z",
|
||||
"source": "manual",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with patch("app.gateway.routers.memory.import_memory_data", return_value=imported_memory):
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/memory/import", json=imported_memory)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["facts"] == imported_memory["facts"]
|
||||
|
||||
|
||||
def test_clear_memory_route_returns_cleared_memory() -> None:
|
||||
app = FastAPI()
|
||||
app.include_router(memory.router)
|
||||
|
||||
@@ -7,6 +7,7 @@ from deerflow.agents.memory.updater import (
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
delete_memory_fact,
|
||||
import_memory_data,
|
||||
update_memory_fact,
|
||||
)
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
@@ -233,6 +234,31 @@ def test_delete_memory_fact_raises_for_unknown_id() -> None:
|
||||
raise AssertionError("Expected KeyError for missing fact id")
|
||||
|
||||
|
||||
def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
||||
imported_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_import",
|
||||
"content": "User works on DeerFlow.",
|
||||
"category": "context",
|
||||
"confidence": 0.87,
|
||||
"createdAt": "2026-03-20T00:00:00Z",
|
||||
"source": "manual",
|
||||
}
|
||||
]
|
||||
)
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save.return_value = True
|
||||
mock_storage.load.return_value = imported_memory
|
||||
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
result = import_memory_data(imported_memory)
|
||||
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None)
|
||||
mock_storage.load.assert_called_once_with(None)
|
||||
assert result == imported_memory
|
||||
|
||||
|
||||
def test_update_memory_fact_updates_only_matching_fact() -> None:
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
@@ -349,7 +375,7 @@ def test_update_memory_fact_rejects_invalid_confidence() -> None:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_text — LLM response content normalization
|
||||
# _extract_text - LLM response content normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -409,7 +435,7 @@ class TestExtractText:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# format_conversation_for_update — handles mixed list content
|
||||
# format_conversation_for_update - handles mixed list content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -439,7 +465,7 @@ class TestFormatConversationForUpdate:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_memory — structured LLM response handling
|
||||
# update_memory - structured LLM response handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user