fix(memory): case-insensitive fact deduplication and positive reinforcement detection (#1804)

* fix(memory): case-insensitive fact deduplication and positive reinforcement detection

Two fixes to the memory system:

1. _fact_content_key() now lowercases content before comparison, preventing
   semantically duplicate facts like "User prefers Python" and "user prefers
   python" from being stored separately.

2. Adds detect_reinforcement() to MemoryMiddleware (closes #1719), mirroring
   detect_correction(). When users signal approval ("yes exactly", "perfect",
   "完全正确", etc.), the memory updater now receives reinforcement_detected=True
   and injects a hint prompting the LLM to record confirmed preferences and
   behaviors with high confidence.

   Changes across the full signal path:
   - memory_middleware.py: _REINFORCEMENT_PATTERNS + detect_reinforcement()
   - queue.py: reinforcement_detected field in ConversationContext and add()
   - updater.py: reinforcement_detected param in update_memory() and
     update_memory_from_conversation(); builds reinforcement_hint alongside
     the existing correction_hint

Tests: 11 new tests covering deduplication, hint injection, and signal
detection (Chinese + English patterns, window boundary, conflict with correction).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(memory): address Copilot review comments on reinforcement detection

- Tighten _REINFORCEMENT_PATTERNS: remove 很好, require punctuation/end-of-string boundaries on remaining patterns, split this-is-good into stricter variants
- Suppress reinforcement_detected when correction_detected is true to avoid mixed-signal noise
- Use casefold() instead of lower() for Unicode-aware fact deduplication
- Add missing test coverage for reinforcement_detected OR merge and forwarding in queue

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
thefoolgy
2026-04-05 16:23:00 +08:00
committed by GitHub
parent 9ca68ffaaa
commit 8049785de6
6 changed files with 326 additions and 3 deletions
@@ -21,6 +21,7 @@ class ConversationContext:
timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
class MemoryUpdateQueue:
@@ -44,6 +45,7 @@ class MemoryUpdateQueue:
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
@@ -52,6 +54,7 @@ class MemoryUpdateQueue:
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
if not config.enabled:
@@ -63,11 +66,13 @@ class MemoryUpdateQueue:
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
# Check if this thread already has a pending update
@@ -130,6 +135,7 @@ class MemoryUpdateQueue:
thread_id=context.thread_id,
agent_name=context.agent_name,
correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)
@@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
stripped = content.strip()
if not stripped:
return None
return stripped
return stripped.casefold()
class MemoryUpdater:
@@ -272,6 +272,7 @@ class MemoryUpdater:
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
@@ -280,6 +281,7 @@ class MemoryUpdater:
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if update was successful, False otherwise.
@@ -310,6 +312,14 @@ class MemoryUpdater:
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
@@ -441,6 +451,7 @@ def update_memory_from_conversation(
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation.
@@ -449,9 +460,10 @@ def update_memory_from_conversation(
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
@@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = (
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool:
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
return None