fix(harness): preserve dynamic context across summarization (#2823)
This commit is contained in:
@@ -45,6 +45,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
||||||
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
||||||
|
_SUMMARY_MESSAGE_NAME = "summary"
|
||||||
|
|
||||||
|
|
||||||
def _extract_date(content: str) -> str | None:
|
def _extract_date(content: str) -> str | None:
|
||||||
@@ -72,6 +73,16 @@ def _last_injected_date(messages: list) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_dynamic_context_reminder(message: object) -> bool:
    """Tell whether *message* is a hidden dynamic-context reminder.

    A reminder is a ``HumanMessage`` whose ``additional_kwargs`` holds a truthy
    value under ``_DYNAMIC_CONTEXT_REMINDER_KEY``. Any other object, including
    other message types, yields ``False``.
    """
    if not isinstance(message, HumanMessage):
        return False
    marker = message.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY)
    return bool(marker)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_user_injection_target(message: object) -> bool:
    """Tell whether *message* may receive a dynamic-context reminder.

    Only a ``HumanMessage`` qualifies, and only when it is neither a
    dynamic-context reminder itself nor named ``_SUMMARY_MESSAGE_NAME``.
    """
    if not isinstance(message, HumanMessage):
        return False
    # A reminder must never become the anchor for another reminder.
    if is_dynamic_context_reminder(message):
        return False
    # Messages carrying the summary name are excluded as well.
    return message.name != _SUMMARY_MESSAGE_NAME
|
||||||
|
|
||||||
|
|
||||||
class DynamicContextMiddleware(AgentMiddleware):
|
class DynamicContextMiddleware(AgentMiddleware):
|
||||||
"""Inject memory and current date into HumanMessages as a <system-reminder>.
|
"""Inject memory and current date into HumanMessages as a <system-reminder>.
|
||||||
|
|
||||||
@@ -163,7 +174,7 @@ class DynamicContextMiddleware(AgentMiddleware):
|
|||||||
|
|
||||||
if last_date is None:
|
if last_date is None:
|
||||||
# ── First turn: inject full reminder as a separate HumanMessage ─────
|
# ── First turn: inject full reminder as a separate HumanMessage ─────
|
||||||
first_idx = next((i for i, m in enumerate(messages) if isinstance(m, HumanMessage)), None)
|
first_idx = next((i for i, m in enumerate(messages) if _is_user_injection_target(m)), None)
|
||||||
if first_idx is None:
|
if first_idx is None:
|
||||||
return None
|
return None
|
||||||
full_reminder = self._build_full_reminder()
|
full_reminder = self._build_full_reminder()
|
||||||
@@ -181,7 +192,7 @@ class DynamicContextMiddleware(AgentMiddleware):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
|
# ── Midnight crossed: inject date-update reminder as a separate HumanMessage ──
|
||||||
last_human_idx = next((i for i in reversed(range(len(messages))) if isinstance(messages[i], HumanMessage)), None)
|
last_human_idx = next((i for i in reversed(range(len(messages))) if _is_user_injection_target(messages[i])), None)
|
||||||
if last_human_idx is None:
|
if last_human_idx is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from langgraph.config import get_config
|
|||||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
||||||
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -135,6 +136,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||||
|
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
||||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||||
summary = self._create_summary(messages_to_summarize)
|
summary = self._create_summary(messages_to_summarize)
|
||||||
new_messages = self._build_new_messages(summary)
|
new_messages = self._build_new_messages(summary)
|
||||||
@@ -160,6 +162,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||||
|
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
||||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||||
summary = await self._acreate_summary(messages_to_summarize)
|
summary = await self._acreate_summary(messages_to_summarize)
|
||||||
new_messages = self._build_new_messages(summary)
|
new_messages = self._build_new_messages(summary)
|
||||||
@@ -179,6 +182,24 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|||||||
"""
|
"""
|
||||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||||
|
|
||||||
|
def _preserve_dynamic_context_reminders(
    self,
    messages_to_summarize: list[AnyMessage],
    preserved_messages: list[AnyMessage],
) -> tuple[list[AnyMessage], list[AnyMessage]]:
    """Move hidden dynamic-context reminders out of the batch to be summarized.

    These reminders carry the current date and optional memory. If summarization
    removes them, DynamicContextMiddleware can mistake the summary HumanMessage
    for the first user message and inject the reminder in the wrong place.

    Returns:
        ``(messages_to_summarize, preserved_messages)`` untouched when no
        reminder is present; otherwise the non-reminder messages paired with
        the reminders placed ahead of ``preserved_messages``.
    """
    rescued: list[AnyMessage] = []
    compressible: list[AnyMessage] = []
    # Single pass: route each message into exactly one of the two buckets,
    # keeping the relative order inside each bucket.
    for candidate in messages_to_summarize:
        bucket = rescued if is_dynamic_context_reminder(candidate) else compressible
        bucket.append(candidate)

    if rescued:
        return compressible, rescued + preserved_messages
    # Nothing to rescue: hand back the caller's own lists.
    return messages_to_summarize, preserved_messages
|
||||||
|
|
||||||
def _partition_with_skill_rescue(
|
def _partition_with_skill_rescue(
|
||||||
self,
|
self,
|
||||||
messages: list[AnyMessage],
|
messages: list[AnyMessage],
|
||||||
|
|||||||
@@ -139,6 +139,30 @@ def test_injects_only_into_first_human_message_not_later_ones():
|
|||||||
assert all(m.id != "msg-2" for m in msgs)
|
assert all(m.id != "msg-2" for m in msgs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_summary_human_message_is_not_used_as_injection_target():
    """After summarization, the synthetic summary HumanMessage is not a user turn.

    The reminder is expected to attach to the later "Follow-up" message
    (id ``msg-2``) rather than to the leading message named ``summary``.
    """
    mw = _make_middleware()
    summary_msg = HumanMessage(content="Here is a summary of the conversation to date:\n\n...", id="summary-1", name="summary")
    earlier_reply = AIMessage(content="Earlier reply")
    followup = HumanMessage(content="Follow-up", id="msg-2")
    state = {"messages": [summary_msg, earlier_reply, followup]}

    memory_patch = mock.patch("deerflow.agents.lead_agent.prompt._get_memory_context", return_value="")
    datetime_patch = mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime")
    with memory_patch, datetime_patch as mock_dt:
        mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
        result = mw.before_agent(state, _fake_runtime())

    assert result is not None
    msgs = result["messages"]
    assert len(msgs) == 2

    reminder_msg = msgs[0]
    user_msg = msgs[1]
    # The reminder takes over the original id of the follow-up message ...
    assert reminder_msg.id == "msg-2"
    assert reminder_msg.additional_kwargs.get(_DYNAMIC_CONTEXT_REMINDER_KEY) is True
    # ... and the user's text is re-emitted under a derived id.
    assert user_msg.id == "msg-2__user"
    assert user_msg.content == "Follow-up"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Edge cases
|
# Edge cases
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||||
|
|
||||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||||
|
from deerflow.agents.middlewares.dynamic_context_middleware import _DYNAMIC_CONTEXT_REMINDER_KEY, DynamicContextMiddleware
|
||||||
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
@@ -20,6 +22,14 @@ def _messages() -> list:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
    """Build a HumanMessage flagged as a hidden dynamic-context reminder.

    The message carries a fixed ``<current_date>`` payload and is marked with
    both ``hide_from_ui`` and ``_DYNAMIC_CONTEXT_REMINDER_KEY``.
    """
    reminder_text = "<system-reminder>\n<current_date>2026-05-08, Friday</current_date>\n</system-reminder>"
    hidden_flags = {"hide_from_ui": True, _DYNAMIC_CONTEXT_REMINDER_KEY: True}
    return HumanMessage(content=reminder_text, id=msg_id, additional_kwargs=hidden_flags)
|
||||||
|
|
||||||
|
|
||||||
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
|
||||||
context = {}
|
context = {}
|
||||||
if thread_id is not None:
|
if thread_id is not None:
|
||||||
@@ -98,6 +108,38 @@ def test_before_summarization_hook_receives_messages_before_compression() -> Non
|
|||||||
assert result["messages"][1].content.startswith("Here is a summary")
|
assert result["messages"][1].content.startswith("Here is a summary")
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_context_reminder_is_preserved_across_summarization() -> None:
    """A reminder at the head of the history survives summarization.

    Checks that the reminder is excluded from the summarized batch, leads the
    preserved messages, is re-emitted right after the summary message, and that
    a follow-up ``before_agent`` call on the resulting state returns ``None``.
    """
    captured: list[SummarizationEvent] = []
    middleware = _middleware(before_summarization=[captured.append])
    reminder = _dynamic_context_reminder()
    history = [
        reminder,
        HumanMessage(content="user-1"),
        AIMessage(content="assistant-1"),
        HumanMessage(content="user-2"),
    ]

    result = middleware.before_model({"messages": history}, _runtime())

    assert len(captured) == 1
    event = captured[0]
    assert [message.content for message in event.messages_to_summarize] == ["user-1"]
    assert event.preserved_messages[0] is reminder

    emitted = result["messages"]
    assert isinstance(emitted[0], RemoveMessage)
    assert emitted[1].name == "summary"
    assert emitted[2] is reminder

    # Replay the compacted history plus a new user turn through the
    # dynamic-context middleware with the same mocked date.
    followup_state = {"messages": [*emitted[1:], HumanMessage(content="Follow-up", id="msg-2")]}
    with mock.patch("deerflow.agents.middlewares.dynamic_context_middleware.datetime") as mock_dt:
        mock_dt.now.return_value.strftime.return_value = "2026-05-08, Friday"
        outcome = DynamicContextMiddleware().before_agent(followup_state, _runtime())
        assert outcome is None
|
||||||
|
|
||||||
|
|
||||||
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
|
def test_before_summarization_hook_not_called_when_threshold_not_met() -> None:
|
||||||
captured: list[SummarizationEvent] = []
|
captured: list[SummarizationEvent] = []
|
||||||
middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
|
middleware = _middleware(before_summarization=[captured.append], trigger=("messages", 10))
|
||||||
|
|||||||
Reference in New Issue
Block a user