mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 01:45:58 +00:00
fix(memory): parse wrapped memory update json responses (#3252)
* fix(memory): parse wrapped memory update json responses
* test(memory): format wrapped response coverage
* fix(memory): guard malformed nested memory facts
* fix(memory): require full update object when parsing responses
* fix(memory): fail closed on unsafe partial removals
* style(memory): format updater tests
This commit is contained in:
@@ -227,6 +227,110 @@ def _extract_text(content: Any) -> str:
|
|||||||
return str(content)
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
# Top-level keys a decoded JSON object must all contain before the response
# parser treats it as a memory update rather than an unrelated object.
_REQUIRED_MEMORY_UPDATE_TOP_LEVEL_KEYS = frozenset({"user", "history", "newFacts", "factsToRemove"})
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_memory_update_fact(fact: Any) -> dict[str, Any] | None:
|
||||||
|
"""Normalize a single fact entry from a model-produced memory update."""
|
||||||
|
if not isinstance(fact, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_content = fact.get("content")
|
||||||
|
if not isinstance(raw_content, str):
|
||||||
|
return None
|
||||||
|
content = raw_content.strip()
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_category = fact.get("category")
|
||||||
|
category = raw_category.strip() if isinstance(raw_category, str) and raw_category.strip() else "context"
|
||||||
|
|
||||||
|
raw_confidence = fact.get("confidence", 0.5)
|
||||||
|
if isinstance(raw_confidence, bool):
|
||||||
|
return None
|
||||||
|
if isinstance(raw_confidence, str):
|
||||||
|
raw_confidence = raw_confidence.strip()
|
||||||
|
if not raw_confidence:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
raw_confidence = float(raw_confidence)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
elif isinstance(raw_confidence, (int, float)):
|
||||||
|
raw_confidence = float(raw_confidence)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not math.isfinite(raw_confidence):
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized_fact = {
|
||||||
|
"content": content,
|
||||||
|
"category": category,
|
||||||
|
"confidence": raw_confidence,
|
||||||
|
}
|
||||||
|
source_error = fact.get("sourceError")
|
||||||
|
if isinstance(source_error, str):
|
||||||
|
normalized_source_error = source_error.strip()
|
||||||
|
if normalized_source_error:
|
||||||
|
normalized_fact["sourceError"] = normalized_source_error
|
||||||
|
|
||||||
|
return normalized_fact
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_memory_update_data(update_data: dict[str, Any]) -> dict[str, Any]:
    """Coerce parsed memory update data into the shape consumed by _apply_updates.

    ``user`` and ``history`` fall back to empty dicts when they are not dicts,
    ``factsToRemove`` keeps only its string entries (or becomes empty when it is
    not a list), and ``newFacts`` keeps only the entries that
    ``_normalize_memory_update_fact`` accepts.

    Raises:
        json.JSONDecodeError: When at least one removal id survives but
            ``newFacts`` is not a list or any of its entries was rejected, so a
            remove-and-replace update is never applied as a removal only.
    """
    raw_removals = update_data.get("factsToRemove")
    removal_ids = []
    if isinstance(raw_removals, list):
        removal_ids = [entry for entry in raw_removals if isinstance(entry, str)]

    raw_facts = update_data.get("newFacts")
    accepted_facts = []
    if isinstance(raw_facts, list):
        candidates = [_normalize_memory_update_fact(entry) for entry in raw_facts]
        accepted_facts = [candidate for candidate in candidates if candidate is not None]
        # Any rejected entry shortens the accepted list relative to the input.
        lost_a_fact = len(accepted_facts) != len(candidates)
    else:
        lost_a_fact = True

    if lost_a_fact and removal_ids:
        raise json.JSONDecodeError(
            "Unsafe partial memory update: factsToRemove with malformed newFacts",
            json.dumps(update_data, ensure_ascii=False),
            0,
        )

    user_section = update_data.get("user")
    if not isinstance(user_section, dict):
        user_section = {}

    history_section = update_data.get("history")
    if not isinstance(history_section, dict):
        history_section = {}

    return {
        "user": user_section,
        "history": history_section,
        "newFacts": accepted_facts,
        "factsToRemove": removal_ids,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_memory_update_response(response_content: Any) -> dict[str, Any]:
    """Parse the first valid memory-update JSON object from an LLM response.

    Providers sometimes surround the JSON with thinking traces, prose, or
    markdown fences even when asked for JSON only. Every ``{`` in the response
    text is tried as the start of a JSON document, and the first decoded dict
    that contains all required top-level keys is normalized and returned.
    Truncated or malformed JSON is never repaired.

    Raises:
        json.JSONDecodeError: When no candidate decodes to a dict with all the
            required keys, or when normalization rejects the selected object.
    """
    text = _extract_text(response_content).strip()
    json_decoder = json.JSONDecoder()

    for brace in re.finditer(r"\{", text):
        tail = text[brace.start() :]
        try:
            candidate, _consumed = json_decoder.raw_decode(tail)
        except json.JSONDecodeError:
            # Not a JSON document at this brace; try the next one.
            continue

        if not isinstance(candidate, dict):
            continue
        if _REQUIRED_MEMORY_UPDATE_TOP_LEVEL_KEYS.issubset(candidate):
            return _normalize_memory_update_data(candidate)

    raise json.JSONDecodeError("No valid memory update JSON object found", text, 0)
|
||||||
|
|
||||||
|
|
||||||
# Matches sentences that describe a file-upload *event* rather than general
|
# Matches sentences that describe a file-upload *event* rather than general
|
||||||
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
||||||
# such as "User works with CSV files" or "prefers PDF export".
|
# such as "User works with CSV files" or "prefers PDF export".
|
||||||
@@ -353,13 +457,7 @@ class MemoryUpdater:
|
|||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Parse the model response, apply updates, and persist memory."""
|
"""Parse the model response, apply updates, and persist memory."""
|
||||||
response_text = _extract_text(response_content).strip()
|
update_data = _parse_memory_update_response(response_content)
|
||||||
|
|
||||||
if response_text.startswith("```"):
|
|
||||||
lines = response_text.split("\n")
|
|
||||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
|
||||||
|
|
||||||
update_data = json.loads(response_text)
|
|
||||||
# Deep-copy before in-place mutation so a subsequent save() failure
|
# Deep-copy before in-place mutation so a subsequent save() failure
|
||||||
# cannot corrupt the still-cached original object reference.
|
# cannot corrupt the still-cached original object reference.
|
||||||
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
||||||
|
|||||||
@@ -563,6 +563,28 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
model.invoke = MagicMock(return_value=response)
|
model.invoke = MagicMock(return_value=response)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _run_update_with_response(self, content):
    """Run ``update_memory`` once with the model mocked to reply ``content``.

    The model, memory config, stored memory and storage backend are all
    patched, so only parsing and applying the response is exercised.

    Returns:
        A ``(result, mock_storage)`` tuple: the value returned by
        ``update_memory`` and the storage mock whose ``save`` calls can be
        inspected.
    """
    mock_storage = MagicMock()
    mock_storage.save = MagicMock(return_value=True)
    updater = MemoryUpdater()

    # A minimal two-message conversation: one human turn and one AI reply.
    human_turn = MagicMock(type="human", content="Remember that I prefer concise updates.")
    ai_turn = MagicMock(type="ai", content="Got it.", tool_calls=[])

    with (
        patch.object(updater, "_get_model", return_value=self._make_mock_model(content)),
        patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True, fact_confidence_threshold=0.7, max_facts=100)),
        patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
        patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage),
    ):
        outcome = updater.update_memory([human_turn, ai_turn], thread_id="thread-memory")

    return outcome, mock_storage
|
||||||
|
|
||||||
def test_string_response_parses(self):
|
def test_string_response_parses(self):
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||||
@@ -609,6 +631,82 @@ class TestUpdateMemoryStructuredResponse:
|
|||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
def test_wrapped_json_responses_parse(self):
    """Memory update should tolerate provider wrappers around valid JSON.

    Covers a closed think tag, an unclosed think tag, leading prose, trailing
    prose, and a fenced ``json`` code block around the same payload.
    """
    payload = '{"user": {}, "history": {}, "newFacts": [{"content": "User prefers concise updates", "category": "preference", "confidence": 0.9}], "factsToRemove": []}'
    wrapped_replies = (
        f"<think>Analyze the conversation first.</think>\n{payload}",
        f"<think>Analyze the conversation first.\n{payload}",
        f"Here is the memory update:\n{payload}",
        f"{payload}\nDone.",
        f"```json\n{payload}\n```",
    )

    for reply in wrapped_replies:
        outcome, storage = self._run_update_with_response(reply)

        assert outcome is True
        persisted = storage.save.call_args.args[0]
        first_fact = persisted["facts"][0]
        assert first_fact["content"] == "User prefers concise updates"
|
||||||
|
|
||||||
|
def test_ignores_unrelated_json_before_memory_update(self):
    """Parser should not select unrelated JSON objects before the memory update.

    The reply starts with a small object that has only a ``user`` key, then
    contains the full memory update; the saved fact must come from the latter.
    """
    payload = '{"user": {}, "history": {}, "newFacts": [{"content": "Remember the actual update", "category": "context", "confidence": 0.9}], "factsToRemove": []}'
    decoy_prefix = 'Example object: {"user": "alice"}\nActual memory update:\n'

    outcome, storage = self._run_update_with_response(decoy_prefix + payload)

    assert outcome is True
    persisted = storage.save.call_args.args[0]
    first_fact = persisted["facts"][0]
    assert first_fact["content"] == "Remember the actual update"
|
||||||
|
|
||||||
|
def test_invalid_json_response_is_skipped_without_saving(self):
    """Truncated JSON should remain a safe skipped update, not guessed repair."""
    truncated_reply = '{"user": {}, "history": {}, "newFacts": ['

    outcome, storage = self._run_update_with_response(truncated_reply)

    # The update reports failure and nothing is persisted.
    assert outcome is False
    storage.save.assert_not_called()
|
||||||
|
|
||||||
|
def test_schema_guard_ignores_invalid_update_fields(self):
    """Parsed JSON with bad field types should not break the memory update.

    ``user``, ``history`` and ``factsToRemove`` carry the wrong types and one
    ``newFacts`` entry is a bare string; only the well-formed fact is saved.
    """
    reply = '{"user": "bad", "history": [], "newFacts": ["bad", {"content": "User works on DeerFlow", "category": "context", "confidence": 0.91}], "factsToRemove": "bad"}'

    outcome, storage = self._run_update_with_response(reply)

    assert outcome is True
    persisted = storage.save.call_args.args[0]
    saved_contents = [entry["content"] for entry in persisted["facts"]]
    assert saved_contents == ["User works on DeerFlow"]
|
||||||
|
|
||||||
|
def test_fact_schema_guard_coerces_and_filters_nested_fields(self):
    """Malformed fact entries should be normalized per fact, not fail the whole update.

    The first fact has padded strings, a non-string category and a numeric
    string confidence and is expected to be saved in normalized form. The
    other three (non-numeric confidence, non-string content, blank content)
    are expected to be dropped.
    """
    reply = "".join(
        [
            '{"user": {}, "history": {}, "newFacts": [',
            '{"content": " User likes async updates ", "category": 9, "confidence": "0.91", "sourceError": " parse issue "}, ',
            '{"content": "skip invalid confidence", "category": "context", "confidence": "high"}, ',
            '{"content": 12, "category": "context", "confidence": 0.9}, ',
            '{"content": " ", "category": "context", "confidence": 0.9}',
            '], "factsToRemove": []}',
        ]
    )

    outcome, storage = self._run_update_with_response(reply)

    assert outcome is True
    persisted = storage.save.call_args.args[0]
    assert len(persisted["facts"]) == 1
    only_fact = persisted["facts"][0]
    assert only_fact["content"] == "User likes async updates"
    assert only_fact["category"] == "context"
    assert only_fact["confidence"] == 0.91
    assert only_fact["sourceError"] == "parse issue"
|
||||||
|
|
||||||
|
def test_malformed_replacement_update_fails_closed(self):
    """Malformed replacement facts should not turn remove+add into delete-only.

    The reply removes ``fact_old`` while its only replacement fact has a
    non-numeric confidence, so the whole update must be rejected.
    """
    reply = '{"user": {}, "history": {}, "newFacts": [{"content": "replacement fact", "category": "context", "confidence": "bad"}], "factsToRemove": ["fact_old"]}'

    outcome, storage = self._run_update_with_response(reply)

    assert outcome is False
    storage.save.assert_not_called()
|
||||||
|
|
||||||
def test_async_update_memory_delegates_to_sync(self):
|
def test_async_update_memory_delegates_to_sync(self):
|
||||||
"""aupdate_memory should delegate to sync _do_update_memory_sync via to_thread."""
|
"""aupdate_memory should delegate to sync _do_update_memory_sync via to_thread."""
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
|
|||||||
Reference in New Issue
Block a user