fix(middleware): avoid rescuing non-skill tool outputs during summarization (#2458)

* fix(middelware): narrow skill rescue to skill-related tool outputs

* fix(summarization): address skill rescue review feedback

* fix: wire summarization skill rescue config

* fix: remove dead skill tool helper

* fix(lint): fix format

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Nan Gao
2026-04-24 15:19:46 +02:00
committed by GitHub
parent c2332bb790
commit f9ff3a698d
7 changed files with 629 additions and 9 deletions
@@ -207,3 +207,27 @@ def test_create_summarization_middleware_registers_memory_flush_hook_when_memory
lead_agent_module._create_summarization_middleware()
assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
def test_create_summarization_middleware_passes_skill_read_tool_names(monkeypatch):
app_config = _make_app_config([_make_model("default-model", supports_thinking=False)])
monkeypatch.setattr(
lead_agent_module,
"get_summarization_config",
lambda: SummarizationConfig(enabled=True, skill_file_read_tool_names=["read_file", "cat"]),
)
monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config)
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: object())
captured: dict[str, object] = {}
def _fake_middleware(**kwargs):
captured.update(kwargs)
return kwargs
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", _fake_middleware)
lead_agent_module._create_summarization_middleware()
assert captured["skill_file_read_tool_names"] == ["read_file", "cat"]
+325 -2
View File
@@ -4,7 +4,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
from deerflow.agents.memory.summarization_hook import memory_flush_hook
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
@@ -29,7 +29,16 @@ def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None)
return SimpleNamespace(context=context)
def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
def _middleware(
*,
before_summarization=None,
trigger=("messages", 4),
keep=("messages", 2),
skill_file_read_tool_names=None,
preserve_recent_skill_count: int = 0,
preserve_recent_skill_tokens: int = 0,
preserve_recent_skill_tokens_per_skill: int = 0,
) -> DeerFlowSummarizationMiddleware:
model = MagicMock()
model.invoke.return_value = SimpleNamespace(text="compressed summary")
return DeerFlowSummarizationMiddleware(
@@ -38,9 +47,34 @@ def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("me
keep=keep,
token_counter=len,
before_summarization=before_summarization,
skill_file_read_tool_names=skill_file_read_tool_names,
preserve_recent_skill_count=preserve_recent_skill_count,
preserve_recent_skill_tokens=preserve_recent_skill_tokens,
preserve_recent_skill_tokens_per_skill=preserve_recent_skill_tokens_per_skill,
)
def _skill_read_call(tool_id: str, skill: str) -> dict:
return {
"name": "read_file",
"id": tool_id,
"args": {"path": f"/mnt/skills/public/{skill}/SKILL.md"},
}
def _skill_conversation() -> list:
return [
HumanMessage(content="u1"),
AIMessage(content="", tool_calls=[_skill_read_call("t1", "alpha")]),
ToolMessage(content="alpha skill body", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="", tool_calls=[_skill_read_call("t2", "beta")]),
ToolMessage(content="beta skill body", tool_call_id="t2"),
HumanMessage(content="u3"),
AIMessage(content="final"),
]
def test_before_summarization_hook_receives_messages_before_compression() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(before_summarization=[captured.append])
@@ -167,6 +201,295 @@ def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: p
assert add_kwargs["reinforcement_detected"] is False
def test_skill_rescue_keeps_recent_skill_reads_out_of_summary() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
result = middleware.before_model({"messages": _skill_conversation()}, _runtime())
assert len(captured) == 1
summarized_ids = {id(m) for m in captured[0].messages_to_summarize}
preserved = captured[0].preserved_messages
# Both skill-read bundles should be rescued into preserved_messages,
# tool_call ↔ tool_result pairs stay intact.
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved)
for m in preserved:
if isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"}:
assert id(m) not in summarized_ids
# Preserved output order: rescued bundles first, then the tail kept by parent cutoff.
contents = [getattr(m, "content", None) for m in preserved]
assert contents[-2:] == ["u3", "final"]
# The final emitted state should start with RemoveMessage + summary, then preserved messages.
emitted = result["messages"]
assert isinstance(emitted[0], RemoveMessage)
assert emitted[1].content.startswith("Here is a summary")
assert list(emitted[-2:]) == list(preserved[-2:])
def test_skill_rescue_respects_count_budget() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=1,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware.before_model({"messages": _skill_conversation()}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
# Newest skill (beta) rescued; older skill (alpha) falls into summary.
assert any(isinstance(m, ToolMessage) and m.content == "beta skill body" for m in preserved)
assert not any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in summarized)
def test_skill_rescue_uses_injected_skills_container_path() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware._skills_container_path = "/custom/skills"
messages = [
HumanMessage(content="u1"),
AIMessage(content="", tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]),
ToolMessage(content="demo skill body", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="final"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved)
def test_skill_rescue_uses_configured_skill_read_tool_names() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
skill_file_read_tool_names=["custom_read"],
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware._skills_container_path = "/custom/skills"
messages = [
HumanMessage(content="u1"),
AIMessage(content="", tool_calls=[{"name": "custom_read", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}]),
ToolMessage(content="demo skill body", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="final"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
assert any(isinstance(m, ToolMessage) and m.content == "demo skill body" for m in preserved)
def test_skill_rescue_respects_per_skill_token_cap() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
# token_counter=len counts one token per message; per-skill cap of 0 rejects every bundle.
preserve_recent_skill_tokens_per_skill=0,
)
middleware.before_model({"messages": _skill_conversation()}, _runtime())
preserved = captured[0].preserved_messages
assert not any(isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"} for m in preserved)
def test_skill_rescue_disabled_when_count_zero() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=0,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
middleware.before_model({"messages": _skill_conversation()}, _runtime())
preserved = captured[0].preserved_messages
assert not any(isinstance(m, ToolMessage) for m in preserved)
def test_skill_rescue_ignores_non_skill_tool_reads() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="",
tool_calls=[{"name": "read_file", "id": "t1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}],
),
ToolMessage(content="user notes", tool_call_id="t1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved)
def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
{"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
],
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
ToolMessage(content="user notes", tool_call_id="file-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"]
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert not any(isinstance(m, ToolMessage) and m.content == "user notes" for m in preserved)
assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized)
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="reading skill and notes",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
{"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
],
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
ToolMessage(content="user notes", tool_call_id="file-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert preserved_ai.content == ""
assert summarized_ai.content == "reading skill and notes"
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
captured: list[SummarizationEvent] = []
middleware = _middleware(
before_summarization=[captured.append],
trigger=("messages", 4),
keep=("messages", 2),
preserve_recent_skill_count=5,
preserve_recent_skill_tokens=10_000,
preserve_recent_skill_tokens_per_skill=10_000,
)
messages = [
HumanMessage(content="u1"),
AIMessage(
content="",
tool_calls=[
_skill_read_call("skill-1", "alpha"),
_skill_read_call("skill-2", "beta"),
],
),
ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
HumanMessage(content="u2"),
AIMessage(content="done"),
]
middleware.before_model({"messages": messages}, _runtime())
preserved = captured[0].preserved_messages
summarized = captured[0].messages_to_summarize
preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls)
summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls)
assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"]
assert [tc["id"] for tc in summarized_ai.tool_calls] == ["skill-2"]
assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
assert not any(isinstance(m, ToolMessage) and getattr(m, "tool_call_id", None) == "skill-2" for m in preserved)
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))