mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
c995c3a394
Fixes #2804 The useStream hook tracks messages-tuple mode which captures ALL LLM events, including the internal summarization LLM call. This caused the English summary text to briefly appear as an AI message in the UI before the REMOVE_ALL_MESSAGES state update replaced it. Fix by overriding _create_summary/_acreate_summary to pass callbacks=[] when invoking the summary model, preventing LangGraph from forwarding the internal LLM events to the frontend stream. Also add hide_from_ui=True to the summary HumanMessage's additional_kwargs as a belt-and-suspenders safety net alongside the existing name=summary check.
440 lines
17 KiB
Python
440 lines
17 KiB
Python
"""Summarization middleware extensions for DeerFlow."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Collection
|
|
from dataclasses import dataclass
|
|
from typing import Any, Protocol, override, runtime_checkable
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import SummarizationMiddleware
|
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
|
from langchain_core.messages.utils import get_buffer_string
|
|
from langgraph.config import get_config
|
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
|
from langgraph.runtime import Runtime
|
|
|
|
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
|
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SummarizationEvent:
|
|
"""Context emitted before conversation history is summarized away."""
|
|
|
|
messages_to_summarize: tuple[AnyMessage, ...]
|
|
preserved_messages: tuple[AnyMessage, ...]
|
|
thread_id: str | None
|
|
agent_name: str | None
|
|
runtime: Runtime
|
|
|
|
|
|
@runtime_checkable
|
|
class BeforeSummarizationHook(Protocol):
|
|
"""Hook invoked before summarization removes messages from state."""
|
|
|
|
def __call__(self, event: SummarizationEvent) -> None: ...
|
|
|
|
|
|
def _resolve_thread_id(runtime: Runtime) -> str | None:
|
|
"""Resolve the current thread ID from runtime context or LangGraph config."""
|
|
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
|
if thread_id is None:
|
|
try:
|
|
config_data = get_config()
|
|
except RuntimeError:
|
|
return None
|
|
thread_id = config_data.get("configurable", {}).get("thread_id")
|
|
return thread_id
|
|
|
|
|
|
def _resolve_agent_name(runtime: Runtime) -> str | None:
|
|
"""Resolve the current agent name from runtime context or LangGraph config."""
|
|
agent_name = runtime.context.get("agent_name") if runtime.context else None
|
|
if agent_name is None:
|
|
try:
|
|
config_data = get_config()
|
|
except RuntimeError:
|
|
return None
|
|
agent_name = config_data.get("configurable", {}).get("agent_name")
|
|
return agent_name
|
|
|
|
|
|
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
|
|
"""Best-effort extraction of a file path argument from a read_file-like tool call."""
|
|
args = tool_call.get("args") or {}
|
|
if not isinstance(args, dict):
|
|
return None
|
|
for key in ("path", "file_path", "filepath"):
|
|
value = args.get(key)
|
|
if isinstance(value, str) and value:
|
|
return value
|
|
return None
|
|
|
|
|
|
def _clone_ai_message(
|
|
message: AIMessage,
|
|
tool_calls: list[dict[str, Any]],
|
|
*,
|
|
content: Any | None = None,
|
|
) -> AIMessage:
|
|
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
|
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
|
|
|
|
|
|
@dataclass
|
|
class _SkillBundle:
|
|
"""Skill-related tool calls and tool results associated with one AIMessage."""
|
|
|
|
ai_index: int
|
|
skill_tool_indices: tuple[int, ...]
|
|
skill_tool_call_ids: frozenset[str]
|
|
skill_tool_tokens: int
|
|
skill_key: str
|
|
|
|
|
|
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|
"""Summarization middleware with pre-compression hook dispatch and skill rescue."""
|
|
|
|
def __init__(
|
|
self,
|
|
*args,
|
|
skills_container_path: str | None = None,
|
|
skill_file_read_tool_names: Collection[str] | None = None,
|
|
before_summarization: list[BeforeSummarizationHook] | None = None,
|
|
preserve_recent_skill_count: int = 5,
|
|
preserve_recent_skill_tokens: int = 25_000,
|
|
preserve_recent_skill_tokens_per_skill: int = 5_000,
|
|
**kwargs,
|
|
) -> None:
|
|
super().__init__(*args, **kwargs)
|
|
self._skills_container_path = skills_container_path or "/mnt/skills"
|
|
self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
|
|
self._before_summarization_hooks = before_summarization or []
|
|
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
|
|
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
|
|
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
|
|
|
|
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
return self._maybe_summarize(state, runtime)
|
|
|
|
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
return await self._amaybe_summarize(state, runtime)
|
|
|
|
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
messages = state["messages"]
|
|
self._ensure_message_ids(messages)
|
|
|
|
total_tokens = self.token_counter(messages)
|
|
if not self._should_summarize(messages, total_tokens):
|
|
return None
|
|
|
|
cutoff_index = self._determine_cutoff_index(messages)
|
|
if cutoff_index <= 0:
|
|
return None
|
|
|
|
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
|
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
|
summary = self._create_summary(messages_to_summarize)
|
|
new_messages = self._build_new_messages(summary)
|
|
|
|
return {
|
|
"messages": [
|
|
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
|
*new_messages,
|
|
*preserved_messages,
|
|
]
|
|
}
|
|
|
|
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
messages = state["messages"]
|
|
self._ensure_message_ids(messages)
|
|
|
|
total_tokens = self.token_counter(messages)
|
|
if not self._should_summarize(messages, total_tokens):
|
|
return None
|
|
|
|
cutoff_index = self._determine_cutoff_index(messages)
|
|
if cutoff_index <= 0:
|
|
return None
|
|
|
|
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
|
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
|
summary = await self._acreate_summary(messages_to_summarize)
|
|
new_messages = self._build_new_messages(summary)
|
|
|
|
return {
|
|
"messages": [
|
|
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
|
*new_messages,
|
|
*preserved_messages,
|
|
]
|
|
}
|
|
|
|
@override
|
|
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
"""Generate summary without emitting streaming events to the client.
|
|
|
|
Suppresses callbacks to prevent the internal summarization LLM call from
|
|
producing visible AI message chunks in the frontend's ``messages-tuple``
|
|
stream (issue #2804).
|
|
"""
|
|
if not messages_to_summarize:
|
|
return "No previous conversation history."
|
|
|
|
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
|
if not trimmed:
|
|
return "Previous conversation was too long to summarize."
|
|
|
|
formatted = get_buffer_string(trimmed)
|
|
|
|
try:
|
|
response = self.model.invoke(
|
|
self.summary_prompt.format(messages=formatted).rstrip(),
|
|
config={
|
|
"metadata": {"lc_source": "summarization"},
|
|
"callbacks": [],
|
|
},
|
|
)
|
|
return response.text.strip()
|
|
except Exception as e:
|
|
return f"Error generating summary: {e!s}"
|
|
|
|
@override
|
|
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
"""Generate summary without emitting streaming events to the client.
|
|
|
|
Suppresses callbacks to prevent the internal summarization LLM call from
|
|
producing visible AI message chunks in the frontend's ``messages-tuple``
|
|
stream (issue #2804).
|
|
"""
|
|
if not messages_to_summarize:
|
|
return "No previous conversation history."
|
|
|
|
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
|
if not trimmed:
|
|
return "Previous conversation was too long to summarize."
|
|
|
|
formatted = get_buffer_string(trimmed)
|
|
|
|
try:
|
|
response = await self.model.ainvoke(
|
|
self.summary_prompt.format(messages=formatted).rstrip(),
|
|
config={
|
|
"metadata": {"lc_source": "summarization"},
|
|
"callbacks": [],
|
|
},
|
|
)
|
|
return response.text.strip()
|
|
except Exception as e:
|
|
return f"Error generating summary: {e!s}"
|
|
|
|
@override
|
|
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
|
"""Override the base implementation to let the human message with the special name 'summary'.
|
|
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
|
"""
|
|
return [
|
|
HumanMessage(
|
|
content=f"Here is a summary of the conversation to date:\n\n{summary}",
|
|
name="summary",
|
|
additional_kwargs={"hide_from_ui": True},
|
|
)
|
|
]
|
|
|
|
def _preserve_dynamic_context_reminders(
|
|
self,
|
|
messages_to_summarize: list[AnyMessage],
|
|
preserved_messages: list[AnyMessage],
|
|
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
|
"""Keep hidden dynamic-context reminders out of summary compression.
|
|
|
|
These reminders carry the current date and optional memory. If summarization
|
|
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
|
|
for the first user message and inject the reminder in the wrong place.
|
|
"""
|
|
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
|
|
if not reminders:
|
|
return messages_to_summarize, preserved_messages
|
|
|
|
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
|
|
return remaining, reminders + preserved_messages
|
|
|
|
def _partition_with_skill_rescue(
|
|
self,
|
|
messages: list[AnyMessage],
|
|
cutoff_index: int,
|
|
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
|
"""Partition like the parent, then rescue recently-loaded skill bundles."""
|
|
to_summarize, preserved = self._partition_messages(messages, cutoff_index)
|
|
|
|
if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
|
|
return to_summarize, preserved
|
|
|
|
try:
|
|
bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
|
|
except Exception:
|
|
logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
|
|
return to_summarize, preserved
|
|
|
|
if not bundles:
|
|
return to_summarize, preserved
|
|
|
|
rescue_bundles = self._select_bundles_to_rescue(bundles)
|
|
if not rescue_bundles:
|
|
return to_summarize, preserved
|
|
|
|
bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
|
|
rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
|
|
rescued: list[AnyMessage] = []
|
|
remaining: list[AnyMessage] = []
|
|
for i, msg in enumerate(to_summarize):
|
|
bundle = bundles_by_ai_index.get(i)
|
|
if bundle is not None and isinstance(msg, AIMessage):
|
|
rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
|
|
remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
|
|
|
|
if rescued_tool_calls:
|
|
rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
|
|
if remaining_tool_calls or msg.content:
|
|
remaining.append(_clone_ai_message(msg, remaining_tool_calls))
|
|
continue
|
|
|
|
if i in rescue_tool_indices:
|
|
rescued.append(msg)
|
|
continue
|
|
|
|
remaining.append(msg)
|
|
|
|
return remaining, rescued + preserved
|
|
|
|
def _find_skill_bundles(
|
|
self,
|
|
messages: list[AnyMessage],
|
|
skills_root: str,
|
|
) -> list[_SkillBundle]:
|
|
"""Locate AIMessage + paired ToolMessage groups that load skill files."""
|
|
bundles: list[_SkillBundle] = []
|
|
n = len(messages)
|
|
i = 0
|
|
while i < n:
|
|
msg = messages[i]
|
|
if not (isinstance(msg, AIMessage) and msg.tool_calls):
|
|
i += 1
|
|
continue
|
|
|
|
tool_calls = list(msg.tool_calls)
|
|
skill_paths_by_id: dict[str, str] = {}
|
|
for tc in tool_calls:
|
|
if self._is_skill_tool_call(tc, skills_root):
|
|
tc_id = tc.get("id")
|
|
path = _tool_call_path(tc)
|
|
if tc_id and path:
|
|
skill_paths_by_id[tc_id] = path
|
|
|
|
if not skill_paths_by_id:
|
|
i += 1
|
|
continue
|
|
|
|
skill_tool_tokens = 0
|
|
skill_key_parts: list[str] = []
|
|
skill_tool_indices: list[int] = []
|
|
matched_skill_call_ids: set[str] = set()
|
|
|
|
j = i + 1
|
|
while j < n and isinstance(messages[j], ToolMessage):
|
|
j += 1
|
|
|
|
for k in range(i + 1, j):
|
|
tool_msg = messages[k]
|
|
if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
|
|
skill_tool_tokens += self.token_counter([tool_msg])
|
|
skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
|
|
skill_tool_indices.append(k)
|
|
matched_skill_call_ids.add(tool_msg.tool_call_id)
|
|
|
|
if not skill_tool_indices:
|
|
i = j
|
|
continue
|
|
|
|
bundles.append(
|
|
_SkillBundle(
|
|
ai_index=i,
|
|
skill_tool_indices=tuple(skill_tool_indices),
|
|
skill_tool_call_ids=frozenset(matched_skill_call_ids),
|
|
skill_tool_tokens=skill_tool_tokens,
|
|
skill_key="|".join(sorted(skill_key_parts)),
|
|
)
|
|
)
|
|
i = j
|
|
|
|
return bundles
|
|
|
|
def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
|
|
"""Pick bundles to keep, walking newest-first under count/token budgets."""
|
|
selected: list[_SkillBundle] = []
|
|
if not bundles:
|
|
return selected
|
|
|
|
seen_skill_keys: set[str] = set()
|
|
total_tokens = 0
|
|
kept = 0
|
|
|
|
for bundle in reversed(bundles):
|
|
if kept >= self._preserve_recent_skill_count:
|
|
break
|
|
if bundle.skill_key in seen_skill_keys:
|
|
continue
|
|
if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
|
|
continue
|
|
if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
|
|
continue
|
|
|
|
selected.append(bundle)
|
|
total_tokens += bundle.skill_tool_tokens
|
|
kept += 1
|
|
seen_skill_keys.add(bundle.skill_key)
|
|
|
|
selected.reverse()
|
|
return selected
|
|
|
|
def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
|
|
"""Return True when ``tool_call`` reads a file under the configured skills root."""
|
|
name = tool_call.get("name") or ""
|
|
if name not in self._skill_file_read_tool_names:
|
|
return False
|
|
path = _tool_call_path(tool_call)
|
|
if not path:
|
|
return False
|
|
normalized_root = skills_root.rstrip("/")
|
|
return path == normalized_root or path.startswith(normalized_root + "/")
|
|
|
|
def _fire_hooks(
|
|
self,
|
|
messages_to_summarize: list[AnyMessage],
|
|
preserved_messages: list[AnyMessage],
|
|
runtime: Runtime,
|
|
) -> None:
|
|
if not self._before_summarization_hooks:
|
|
return
|
|
|
|
event = SummarizationEvent(
|
|
messages_to_summarize=tuple(messages_to_summarize),
|
|
preserved_messages=tuple(preserved_messages),
|
|
thread_id=_resolve_thread_id(runtime),
|
|
agent_name=_resolve_agent_name(runtime),
|
|
runtime=runtime,
|
|
)
|
|
|
|
for hook in self._before_summarization_hooks:
|
|
try:
|
|
hook(event)
|
|
except Exception:
|
|
hook_name = getattr(hook, "__name__", None) or type(hook).__name__
|
|
logger.exception("before_summarization hook %s failed", hook_name)
|