mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix(memory): case-insensitive fact deduplication and positive reinforcement detection (#1804)
* fix(memory): case-insensitive fact deduplication and positive reinforcement detection Two fixes to the memory system: 1. _fact_content_key() now lowercases content before comparison, preventing semantically duplicate facts like "User prefers Python" and "user prefers python" from being stored separately. 2. Adds detect_reinforcement() to MemoryMiddleware (closes #1719), mirroring detect_correction(). When users signal approval ("yes exactly", "perfect", "完全正确", etc.), the memory updater now receives reinforcement_detected=True and injects a hint prompting the LLM to record confirmed preferences and behaviors with high confidence. Changes across the full signal path: - memory_middleware.py: _REINFORCEMENT_PATTERNS + detect_reinforcement() - queue.py: reinforcement_detected field in ConversationContext and add() - updater.py: reinforcement_detected param in update_memory() and update_memory_from_conversation(); builds reinforcement_hint alongside the existing correction_hint Tests: 11 new tests covering deduplication, hint injection, and signal detection (Chinese + English patterns, window boundary, conflict with correction). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(memory): address Copilot review comments on reinforcement detection - Tighten _REINFORCEMENT_PATTERNS: remove 很好, require punctuation/end-of-string boundaries on remaining patterns, split this-is-good into stricter variants - Suppress reinforcement_detected when correction_detected is true to avoid mixed-signal noise - Use casefold() instead of lower() for Unicode-aware fact deduplication - Add missing test coverage for reinforcement_detected OR merge and forwarding in queue --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
thread_id="thread-1",
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
reinforcement_detected=False,
|
||||
)
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
|
||||
|
||||
assert len(queue._queue) == 1
|
||||
assert queue._queue[0].messages == ["second"]
|
||||
assert queue._queue[0].reinforcement_detected is True
|
||||
|
||||
|
||||
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
messages=["conversation"],
|
||||
agent_name="lead_agent",
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
]
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.update_memory.return_value = True
|
||||
|
||||
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||
queue._process_queue()
|
||||
|
||||
mock_updater.update_memory.assert_called_once_with(
|
||||
messages=["conversation"],
|
||||
thread_id="thread-1",
|
||||
agent_name="lead_agent",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
|
||||
@@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse:
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
|
||||
class TestFactDeduplicationCaseInsensitive:
|
||||
"""Tests that fact deduplication is case-insensitive."""
|
||||
|
||||
def test_duplicate_fact_different_case_not_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_1",
|
||||
"content": "User prefers Python",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2026-01-01T00:00:00Z",
|
||||
"source": "thread-a",
|
||||
},
|
||||
]
|
||||
)
|
||||
# Same fact with different casing should be treated as duplicate
|
||||
update_data = {
|
||||
"factsToRemove": [],
|
||||
"newFacts": [
|
||||
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
# Should still have only 1 fact (duplicate rejected)
|
||||
assert len(result["facts"]) == 1
|
||||
assert result["facts"][0]["content"] == "User prefers Python"
|
||||
|
||||
def test_unique_fact_different_case_and_content_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_1",
|
||||
"content": "User prefers Python",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2026-01-01T00:00:00Z",
|
||||
"source": "thread-a",
|
||||
},
|
||||
]
|
||||
)
|
||||
update_data = {
|
||||
"factsToRemove": [],
|
||||
"newFacts": [
|
||||
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
assert len(result["facts"]) == 2
|
||||
|
||||
|
||||
class TestReinforcementHint:
|
||||
"""Tests that reinforcement_detected injects the correct hint into the prompt."""
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_model(json_response: str):
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = f"```json\n{json_response}\n```"
|
||||
model.invoke.return_value = response
|
||||
return model
|
||||
|
||||
def test_reinforcement_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Yes, exactly! That's what I needed."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Great to hear!"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Tell me more."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Sure."
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" not in prompt
|
||||
|
||||
def test_both_hints_present_when_both_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "No wait, that's wrong. Actually yes, exactly right."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Got it."
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
@@ -10,7 +10,7 @@ persisting in long-term memory:
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory:
|
||||
mem = {"user": {}, "history": {}, "facts": []}
|
||||
result = _strip_upload_mentions_from_memory(mem)
|
||||
assert result == {"user": {}, "history": {}, "facts": []}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# detect_reinforcement
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestDetectReinforcement:
|
||||
def test_detects_english_reinforcement_signal(self):
|
||||
msgs = [
|
||||
_human("Can you summarise it in bullet points?"),
|
||||
_ai("Here are the key points: ..."),
|
||||
_human("Yes, exactly! That's what I needed."),
|
||||
_ai("Glad it helped."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_detects_perfect_signal(self):
|
||||
msgs = [
|
||||
_human("Write it more concisely."),
|
||||
_ai("Here is the concise version."),
|
||||
_human("Perfect."),
|
||||
_ai("Great!"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_detects_chinese_reinforcement_signal(self):
|
||||
msgs = [
|
||||
_human("帮我用要点来总结"),
|
||||
_ai("好的,要点如下:..."),
|
||||
_human("完全正确,就是这个意思"),
|
||||
_ai("很高兴能帮到你"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_returns_false_without_signal(self):
|
||||
msgs = [
|
||||
_human("What does this function do?"),
|
||||
_ai("It processes the input data."),
|
||||
_human("Can you show me an example?"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
def test_only_checks_recent_messages(self):
|
||||
# Reinforcement signal buried beyond the -6 window should not trigger
|
||||
msgs = [
|
||||
_human("Yes, exactly right."),
|
||||
_ai("Noted."),
|
||||
_human("Let's discuss tests."),
|
||||
_ai("Sure."),
|
||||
_human("What about linting?"),
|
||||
_ai("Use ruff."),
|
||||
_human("And formatting?"),
|
||||
_ai("Use make format."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
def test_does_not_conflict_with_correction(self):
|
||||
# A message can trigger correction but not reinforcement
|
||||
msgs = [
|
||||
_human("That's wrong, try again."),
|
||||
_ai("Corrected."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
Reference in New Issue
Block a user