mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
fix(middleware): avoid rescuing non-skill tool outputs during summarization (#2458)
* fix(middleware): narrow skill rescue to skill-related tool outputs * fix(summarization): address skill rescue review feedback * fix: wire summarization skill rescue config * fix: remove dead skill tool helper * fix(lint): fix format --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -207,3 +207,27 @@ def test_create_summarization_middleware_registers_memory_flush_hook_when_memory
|
||||
lead_agent_module._create_summarization_middleware()
|
||||
|
||||
assert captured["before_summarization"] == [lead_agent_module.memory_flush_hook]
|
||||
|
||||
|
||||
def test_create_summarization_middleware_passes_skill_read_tool_names(monkeypatch):
    """The factory forwards ``skill_file_read_tool_names`` from the summarization config.

    Every collaborator of ``_create_summarization_middleware`` is stubbed so the
    only thing observed is the keyword arguments handed to the middleware class.
    """
    app_config = _make_app_config([_make_model("default-model", supports_thinking=False)])
    captured: dict[str, object] = {}

    def _summarization_config():
        return SummarizationConfig(enabled=True, skill_file_read_tool_names=["read_file", "cat"])

    def _fake_middleware(**kwargs):
        # Record the constructor kwargs instead of building a real middleware.
        captured.update(kwargs)
        return kwargs

    stubs = {
        "get_summarization_config": _summarization_config,
        "get_memory_config": lambda: MemoryConfig(enabled=False),
        "get_app_config": lambda: app_config,
        "create_chat_model": lambda **kwargs: object(),
        "DeerFlowSummarizationMiddleware": _fake_middleware,
    }
    for attribute, stub in stubs.items():
        monkeypatch.setattr(lead_agent_module, attribute, stub)

    lead_agent_module._create_summarization_middleware()

    assert captured["skill_file_read_tool_names"] == ["read_file", "cat"]
|
||||
|
||||
@@ -4,7 +4,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
from deerflow.agents.middlewares.summarization_middleware import DeerFlowSummarizationMiddleware, SummarizationEvent
|
||||
@@ -29,7 +29,16 @@ def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None)
|
||||
return SimpleNamespace(context=context)
|
||||
|
||||
|
||||
def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("messages", 2)) -> DeerFlowSummarizationMiddleware:
|
||||
def _middleware(
|
||||
*,
|
||||
before_summarization=None,
|
||||
trigger=("messages", 4),
|
||||
keep=("messages", 2),
|
||||
skill_file_read_tool_names=None,
|
||||
preserve_recent_skill_count: int = 0,
|
||||
preserve_recent_skill_tokens: int = 0,
|
||||
preserve_recent_skill_tokens_per_skill: int = 0,
|
||||
) -> DeerFlowSummarizationMiddleware:
|
||||
model = MagicMock()
|
||||
model.invoke.return_value = SimpleNamespace(text="compressed summary")
|
||||
return DeerFlowSummarizationMiddleware(
|
||||
@@ -38,9 +47,34 @@ def _middleware(*, before_summarization=None, trigger=("messages", 4), keep=("me
|
||||
keep=keep,
|
||||
token_counter=len,
|
||||
before_summarization=before_summarization,
|
||||
skill_file_read_tool_names=skill_file_read_tool_names,
|
||||
preserve_recent_skill_count=preserve_recent_skill_count,
|
||||
preserve_recent_skill_tokens=preserve_recent_skill_tokens,
|
||||
preserve_recent_skill_tokens_per_skill=preserve_recent_skill_tokens_per_skill,
|
||||
)
|
||||
|
||||
|
||||
def _skill_read_call(tool_id: str, skill: str) -> dict:
|
||||
return {
|
||||
"name": "read_file",
|
||||
"id": tool_id,
|
||||
"args": {"path": f"/mnt/skills/public/{skill}/SKILL.md"},
|
||||
}
|
||||
|
||||
|
||||
def _skill_conversation() -> list:
    """Return an eight-message transcript containing two skill reads.

    The order is: a user turn, the ``alpha`` skill read (call + result), a user
    turn, the ``beta`` skill read (call + result), then a closing user/assistant
    exchange.
    """
    alpha_turn = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=[_skill_read_call("t1", "alpha")]),
        ToolMessage(content="alpha skill body", tool_call_id="t1"),
    ]
    beta_turn = [
        HumanMessage(content="u2"),
        AIMessage(content="", tool_calls=[_skill_read_call("t2", "beta")]),
        ToolMessage(content="beta skill body", tool_call_id="t2"),
    ]
    closing_turn = [
        HumanMessage(content="u3"),
        AIMessage(content="final"),
    ]
    return [*alpha_turn, *beta_turn, *closing_turn]
|
||||
|
||||
|
||||
def test_before_summarization_hook_receives_messages_before_compression() -> None:
|
||||
captured: list[SummarizationEvent] = []
|
||||
middleware = _middleware(before_summarization=[captured.append])
|
||||
@@ -167,6 +201,295 @@ def test_memory_flush_hook_enqueues_filtered_messages_and_flushes(monkeypatch: p
|
||||
assert add_kwargs["reinforcement_detected"] is False
|
||||
|
||||
|
||||
def test_skill_rescue_keeps_recent_skill_reads_out_of_summary() -> None:
    """With generous budgets, both skill reads are preserved and not summarized."""
    events: list[SummarizationEvent] = []
    budgets = {
        "preserve_recent_skill_count": 5,
        "preserve_recent_skill_tokens": 10_000,
        "preserve_recent_skill_tokens_per_skill": 10_000,
    }
    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        **budgets,
    )

    result = middleware.before_model({"messages": _skill_conversation()}, _runtime())

    assert len(events) == 1
    event = events[0]
    summarized_ids = {id(m) for m in event.messages_to_summarize}
    preserved = event.preserved_messages

    # Both skill-read results must land in preserved_messages ...
    for body in ("alpha skill body", "beta skill body"):
        assert any(isinstance(m, ToolMessage) and m.content == body for m in preserved)
    # ... and none of those rescued results may also be handed to the summarizer.
    rescued_results = [m for m in preserved if isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"}]
    assert all(id(m) not in summarized_ids for m in rescued_results)

    # Rescued bundles come first; the tail kept by the parent cutoff stays last.
    tail_contents = [getattr(m, "content", None) for m in preserved[-2:]]
    assert tail_contents == ["u3", "final"]

    # Emitted state: RemoveMessage, then the summary, then the preserved messages.
    emitted = result["messages"]
    removal, summary = emitted[0], emitted[1]
    assert isinstance(removal, RemoveMessage)
    assert summary.content.startswith("Here is a summary")
    assert list(emitted[-2:]) == list(preserved[-2:])
|
||||
|
||||
|
||||
def test_skill_rescue_respects_count_budget() -> None:
    """A count budget of one keeps only the newest skill read out of the summary."""
    events: list[SummarizationEvent] = []

    def _has_tool_body(messages, body: str) -> bool:
        return any(isinstance(m, ToolMessage) and m.content == body for m in messages)

    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=1,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    middleware.before_model({"messages": _skill_conversation()}, _runtime())

    event = events[0]
    preserved = event.preserved_messages
    summarized = event.messages_to_summarize
    # The newest skill (beta) is rescued; the older one (alpha) is summarized.
    assert _has_tool_body(preserved, "beta skill body")
    assert not _has_tool_body(preserved, "alpha skill body")
    assert _has_tool_body(summarized, "alpha skill body")
|
||||
|
||||
|
||||
def test_skill_rescue_uses_injected_skills_container_path() -> None:
    """A read under an overridden ``_skills_container_path`` ends up in preserved_messages."""
    events: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    # Point the middleware at a non-default skills root.
    middleware._skills_container_path = "/custom/skills"

    skill_call = {"name": "read_file", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}
    messages = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=[skill_call]),
        ToolMessage(content="demo skill body", tool_call_id="t1"),
        HumanMessage(content="u2"),
        AIMessage(content="final"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved_tool_bodies = [m.content for m in events[0].preserved_messages if isinstance(m, ToolMessage)]
    assert "demo skill body" in preserved_tool_bodies
|
||||
|
||||
|
||||
def test_skill_rescue_uses_configured_skill_read_tool_names() -> None:
    """A tool listed in ``skill_file_read_tool_names`` has its skill read preserved."""
    events: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        skill_file_read_tool_names=["custom_read"],
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )
    middleware._skills_container_path = "/custom/skills"

    # The call uses the custom tool name rather than "read_file".
    custom_call = {"name": "custom_read", "id": "t1", "args": {"path": "/custom/skills/demo/SKILL.md"}}
    messages = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=[custom_call]),
        ToolMessage(content="demo skill body", tool_call_id="t1"),
        HumanMessage(content="u2"),
        AIMessage(content="final"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    preserved_tool_bodies = [m.content for m in events[0].preserved_messages if isinstance(m, ToolMessage)]
    assert "demo skill body" in preserved_tool_bodies
|
||||
|
||||
|
||||
def test_skill_rescue_respects_per_skill_token_cap() -> None:
    """A per-skill token cap of zero leaves no skill read in preserved_messages."""
    events: list[SummarizationEvent] = []
    budgets = {
        "preserve_recent_skill_count": 5,
        "preserve_recent_skill_tokens": 10_000,
        # token_counter=len counts one token per message; per-skill cap of 0 rejects every bundle.
        "preserve_recent_skill_tokens_per_skill": 0,
    }
    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        **budgets,
    )

    middleware.before_model({"messages": _skill_conversation()}, _runtime())

    skill_results = [m for m in events[0].preserved_messages if isinstance(m, ToolMessage) and m.content in {"alpha skill body", "beta skill body"}]
    assert not skill_results
|
||||
|
||||
|
||||
def test_skill_rescue_disabled_when_count_zero() -> None:
    """With ``preserve_recent_skill_count=0`` no tool message is preserved."""
    events: list[SummarizationEvent] = []
    budgets = {
        "preserve_recent_skill_count": 0,
        "preserve_recent_skill_tokens": 10_000,
        "preserve_recent_skill_tokens_per_skill": 10_000,
    }
    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        **budgets,
    )

    middleware.before_model({"messages": _skill_conversation()}, _runtime())

    assert all(not isinstance(m, ToolMessage) for m in events[0].preserved_messages)
|
||||
|
||||
|
||||
def test_skill_rescue_ignores_non_skill_tool_reads() -> None:
    """A ``read_file`` call on a workspace path is not preserved as a skill read."""
    events: list[SummarizationEvent] = []
    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    # Same tool name as a skill read, but the path is under user data.
    workspace_call = {"name": "read_file", "id": "t1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}
    messages = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=[workspace_call]),
        ToolMessage(content="user notes", tool_call_id="t1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    leaked = [m for m in events[0].preserved_messages if isinstance(m, ToolMessage) and m.content == "user notes"]
    assert not leaked
|
||||
|
||||
|
||||
def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls() -> None:
    """An AI message mixing skill and non-skill calls is split between the two buckets.

    The skill call and its result are preserved; the workspace read and its
    result go to the messages handed to the summarizer.
    """
    events: list[SummarizationEvent] = []

    def _has_tool_body(messages, body: str) -> bool:
        return any(isinstance(m, ToolMessage) and m.content == body for m in messages)

    def _first_ai_with_tool_calls(messages):
        return next(m for m in messages if isinstance(m, AIMessage) and m.tool_calls)

    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    mixed_calls = [
        _skill_read_call("skill-1", "alpha"),
        {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
    ]
    messages = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=mixed_calls),
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        ToolMessage(content="user notes", tool_call_id="file-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    event = events[0]
    preserved = event.preserved_messages
    summarized = event.messages_to_summarize

    preserved_ai = _first_ai_with_tool_calls(preserved)
    summarized_ai = _first_ai_with_tool_calls(summarized)

    assert [call["id"] for call in preserved_ai.tool_calls] == ["skill-1"]
    assert [call["id"] for call in summarized_ai.tool_calls] == ["file-1"]
    assert _has_tool_body(preserved, "alpha skill body")
    assert not _has_tool_body(preserved, "user notes")
    assert _has_tool_body(summarized, "user notes")
|
||||
|
||||
|
||||
def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None:
    """The preserved AI message has empty content; the summarized one keeps its text."""
    events: list[SummarizationEvent] = []

    def _first_ai_with_tool_calls(messages):
        return next(m for m in messages if isinstance(m, AIMessage) and m.tool_calls)

    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    mixed_calls = [
        _skill_read_call("skill-1", "alpha"),
        {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}},
    ]
    messages = [
        HumanMessage(content="u1"),
        # Non-empty content here is what the assertions below track.
        AIMessage(content="reading skill and notes", tool_calls=mixed_calls),
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        ToolMessage(content="user notes", tool_call_id="file-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    event = events[0]
    preserved = event.preserved_messages
    summarized = event.messages_to_summarize

    preserved_ai = _first_ai_with_tool_calls(preserved)
    summarized_ai = _first_ai_with_tool_calls(summarized)

    assert preserved_ai.content == ""
    assert summarized_ai.content == "reading skill and notes"
|
||||
|
||||
|
||||
def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None:
    """A skill call with no matching ToolMessage is summarized rather than preserved."""
    events: list[SummarizationEvent] = []

    def _first_ai_with_tool_calls(messages):
        return next(m for m in messages if isinstance(m, AIMessage) and m.tool_calls)

    middleware = _middleware(
        before_summarization=[events.append],
        trigger=("messages", 4),
        keep=("messages", 2),
        preserve_recent_skill_count=5,
        preserve_recent_skill_tokens=10_000,
        preserve_recent_skill_tokens_per_skill=10_000,
    )

    skill_calls = [
        _skill_read_call("skill-1", "alpha"),
        _skill_read_call("skill-2", "beta"),
    ]
    messages = [
        HumanMessage(content="u1"),
        AIMessage(content="", tool_calls=skill_calls),
        # Only skill-1 gets a result; skill-2 is left unanswered.
        ToolMessage(content="alpha skill body", tool_call_id="skill-1"),
        HumanMessage(content="u2"),
        AIMessage(content="done"),
    ]

    middleware.before_model({"messages": messages}, _runtime())

    event = events[0]
    preserved = event.preserved_messages
    summarized = event.messages_to_summarize

    preserved_ai = _first_ai_with_tool_calls(preserved)
    summarized_ai = _first_ai_with_tool_calls(summarized)

    assert [call["id"] for call in preserved_ai.tool_calls] == ["skill-1"]
    assert [call["id"] for call in summarized_ai.tool_calls] == ["skill-2"]
    assert any(isinstance(m, ToolMessage) and m.content == "alpha skill body" for m in preserved)
    assert not any(isinstance(m, ToolMessage) and getattr(m, "tool_call_id", None) == "skill-2" for m in preserved)
|
||||
|
||||
|
||||
def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
|
||||
Reference in New Issue
Block a user