mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
fix(middleware): avoid rescuing non-skill tool outputs during summarization (#2458)
* fix(middelware): narrow skill rescue to skill-related tool outputs * fix(summarization): address skill rescue review feedback * fix: wire summarization skill rescue config * fix: remove dead skill tool helper * fix(lint): fix format --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -84,7 +84,24 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
|
||||
if get_memory_config().enabled:
|
||||
hooks.append(memory_flush_hook)
|
||||
|
||||
return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks)
|
||||
# The logic below relies on two assumptions holding true: this factory is
|
||||
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
|
||||
# config is not expected to change after startup.
|
||||
try:
|
||||
skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve skills container path; falling back to default")
|
||||
skills_container_path = "/mnt/skills"
|
||||
|
||||
return DeerFlowSummarizationMiddleware(
|
||||
**kwargs,
|
||||
skills_container_path=skills_container_path,
|
||||
skill_file_read_tool_names=config.skill_file_read_tool_names,
|
||||
before_summarization=hooks,
|
||||
preserve_recent_skill_count=config.preserve_recent_skill_count,
|
||||
preserve_recent_skill_tokens=config.preserve_recent_skill_tokens,
|
||||
preserve_recent_skill_tokens_per_skill=config.preserve_recent_skill_tokens_per_skill,
|
||||
)
|
||||
|
||||
|
||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Collection
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, runtime_checkable
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
from langchain_core.messages import AnyMessage, RemoveMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, RemoveMessage, ToolMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
@@ -58,17 +59,63 @@ def _resolve_agent_name(runtime: Runtime) -> str | None:
|
||||
return agent_name
|
||||
|
||||
|
||||
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
|
||||
"""Best-effort extraction of a file path argument from a read_file-like tool call."""
|
||||
args = tool_call.get("args") or {}
|
||||
if not isinstance(args, dict):
|
||||
return None
|
||||
for key in ("path", "file_path", "filepath"):
|
||||
value = args.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _clone_ai_message(
|
||||
message: AIMessage,
|
||||
tool_calls: list[dict[str, Any]],
|
||||
*,
|
||||
content: Any | None = None,
|
||||
) -> AIMessage:
|
||||
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
||||
update: dict[str, Any] = {"tool_calls": tool_calls}
|
||||
if content is not None:
|
||||
update["content"] = content
|
||||
return message.model_copy(update=update)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SkillBundle:
|
||||
"""Skill-related tool calls and tool results associated with one AIMessage."""
|
||||
|
||||
ai_index: int
|
||||
skill_tool_indices: tuple[int, ...]
|
||||
skill_tool_call_ids: frozenset[str]
|
||||
skill_tool_tokens: int
|
||||
skill_key: str
|
||||
|
||||
|
||||
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
"""Summarization middleware with pre-compression hook dispatch."""
|
||||
"""Summarization middleware with pre-compression hook dispatch and skill rescue."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
skills_container_path: str | None = None,
|
||||
skill_file_read_tool_names: Collection[str] | None = None,
|
||||
before_summarization: list[BeforeSummarizationHook] | None = None,
|
||||
preserve_recent_skill_count: int = 5,
|
||||
preserve_recent_skill_tokens: int = 25_000,
|
||||
preserve_recent_skill_tokens_per_skill: int = 5_000,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._skills_container_path = skills_container_path or "/mnt/skills"
|
||||
self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
|
||||
self._before_summarization_hooks = before_summarization or []
|
||||
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
|
||||
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
|
||||
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
|
||||
|
||||
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._maybe_summarize(state, runtime)
|
||||
@@ -88,7 +135,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
if cutoff_index <= 0:
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
|
||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||
summary = self._create_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
@@ -113,7 +160,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
if cutoff_index <= 0:
|
||||
return None
|
||||
|
||||
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
|
||||
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
||||
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
||||
summary = await self._acreate_summary(messages_to_summarize)
|
||||
new_messages = self._build_new_messages(summary)
|
||||
@@ -126,6 +173,155 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
]
|
||||
}
|
||||
|
||||
def _partition_with_skill_rescue(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
cutoff_index: int,
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Partition like the parent, then rescue recently-loaded skill bundles."""
|
||||
to_summarize, preserved = self._partition_messages(messages, cutoff_index)
|
||||
|
||||
if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
|
||||
return to_summarize, preserved
|
||||
|
||||
try:
|
||||
bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
|
||||
except Exception:
|
||||
logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
|
||||
return to_summarize, preserved
|
||||
|
||||
if not bundles:
|
||||
return to_summarize, preserved
|
||||
|
||||
rescue_bundles = self._select_bundles_to_rescue(bundles)
|
||||
if not rescue_bundles:
|
||||
return to_summarize, preserved
|
||||
|
||||
bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
|
||||
rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
|
||||
rescued: list[AnyMessage] = []
|
||||
remaining: list[AnyMessage] = []
|
||||
for i, msg in enumerate(to_summarize):
|
||||
bundle = bundles_by_ai_index.get(i)
|
||||
if bundle is not None and isinstance(msg, AIMessage):
|
||||
rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
|
||||
remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
|
||||
|
||||
if rescued_tool_calls:
|
||||
rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
|
||||
if remaining_tool_calls or msg.content:
|
||||
remaining.append(_clone_ai_message(msg, remaining_tool_calls))
|
||||
continue
|
||||
|
||||
if i in rescue_tool_indices:
|
||||
rescued.append(msg)
|
||||
continue
|
||||
|
||||
remaining.append(msg)
|
||||
|
||||
return remaining, rescued + preserved
|
||||
|
||||
def _find_skill_bundles(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
skills_root: str,
|
||||
) -> list[_SkillBundle]:
|
||||
"""Locate AIMessage + paired ToolMessage groups that load skill files."""
|
||||
bundles: list[_SkillBundle] = []
|
||||
n = len(messages)
|
||||
i = 0
|
||||
while i < n:
|
||||
msg = messages[i]
|
||||
if not (isinstance(msg, AIMessage) and msg.tool_calls):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
tool_calls = list(msg.tool_calls)
|
||||
skill_paths_by_id: dict[str, str] = {}
|
||||
for tc in tool_calls:
|
||||
if self._is_skill_tool_call(tc, skills_root):
|
||||
tc_id = tc.get("id")
|
||||
path = _tool_call_path(tc)
|
||||
if tc_id and path:
|
||||
skill_paths_by_id[tc_id] = path
|
||||
|
||||
if not skill_paths_by_id:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
skill_tool_tokens = 0
|
||||
skill_key_parts: list[str] = []
|
||||
skill_tool_indices: list[int] = []
|
||||
matched_skill_call_ids: set[str] = set()
|
||||
|
||||
j = i + 1
|
||||
while j < n and isinstance(messages[j], ToolMessage):
|
||||
j += 1
|
||||
|
||||
for k in range(i + 1, j):
|
||||
tool_msg = messages[k]
|
||||
if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
|
||||
skill_tool_tokens += self.token_counter([tool_msg])
|
||||
skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
|
||||
skill_tool_indices.append(k)
|
||||
matched_skill_call_ids.add(tool_msg.tool_call_id)
|
||||
|
||||
if not skill_tool_indices:
|
||||
i = j
|
||||
continue
|
||||
|
||||
bundles.append(
|
||||
_SkillBundle(
|
||||
ai_index=i,
|
||||
skill_tool_indices=tuple(skill_tool_indices),
|
||||
skill_tool_call_ids=frozenset(matched_skill_call_ids),
|
||||
skill_tool_tokens=skill_tool_tokens,
|
||||
skill_key="|".join(sorted(skill_key_parts)),
|
||||
)
|
||||
)
|
||||
i = j
|
||||
|
||||
return bundles
|
||||
|
||||
def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
|
||||
"""Pick bundles to keep, walking newest-first under count/token budgets."""
|
||||
selected: list[_SkillBundle] = []
|
||||
if not bundles:
|
||||
return selected
|
||||
|
||||
seen_skill_keys: set[str] = set()
|
||||
total_tokens = 0
|
||||
kept = 0
|
||||
|
||||
for bundle in reversed(bundles):
|
||||
if kept >= self._preserve_recent_skill_count:
|
||||
break
|
||||
if bundle.skill_key in seen_skill_keys:
|
||||
continue
|
||||
if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
|
||||
continue
|
||||
if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
|
||||
continue
|
||||
|
||||
selected.append(bundle)
|
||||
total_tokens += bundle.skill_tool_tokens
|
||||
kept += 1
|
||||
seen_skill_keys.add(bundle.skill_key)
|
||||
|
||||
selected.reverse()
|
||||
return selected
|
||||
|
||||
def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
|
||||
"""Return True when ``tool_call`` reads a file under the configured skills root."""
|
||||
name = tool_call.get("name") or ""
|
||||
if name not in self._skill_file_read_tool_names:
|
||||
return False
|
||||
path = _tool_call_path(tool_call)
|
||||
if not path:
|
||||
return False
|
||||
normalized_root = skills_root.rstrip("/")
|
||||
return path == normalized_root or path.startswith(normalized_root + "/")
|
||||
|
||||
def _fire_hooks(
|
||||
self,
|
||||
messages_to_summarize: list[AnyMessage],
|
||||
|
||||
@@ -51,6 +51,25 @@ class SummarizationConfig(BaseModel):
|
||||
default=None,
|
||||
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
|
||||
)
|
||||
preserve_recent_skill_count: int = Field(
|
||||
default=5,
|
||||
ge=0,
|
||||
description="Number of most-recently-loaded skill files to exclude from summarization. Set to 0 to disable skill preservation.",
|
||||
)
|
||||
preserve_recent_skill_tokens: int = Field(
|
||||
default=25000,
|
||||
ge=0,
|
||||
description="Total token budget reserved for recently-loaded skill files that must be preserved across summarization.",
|
||||
)
|
||||
preserve_recent_skill_tokens_per_skill: int = Field(
|
||||
default=5000,
|
||||
ge=0,
|
||||
description="Per-skill token cap when preserving skill files across summarization. Skill reads above this size are not rescued.",
|
||||
)
|
||||
skill_file_read_tool_names: list[str] = Field(
|
||||
default_factory=lambda: ["read_file", "read", "view", "cat"],
|
||||
description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
|
||||
Reference in New Issue
Block a user