mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 01:15:58 +00:00
d133b1119a
* fix(summarization): tag summary LLM calls nostream to stop phantom stream messages (#2503) The SummarizationMiddleware runs its summary LLM call inside a before_model hook. Without a nostream tag the summary tokens were captured by LangGraph's messages-tuple stream callback and broadcast to the frontend as a phantom AI message. Generate a dedicated summary model copy tagged with "nostream" (merged on top of any existing tags such as "middleware:summarize" so RunJournal attribution is preserved) and override _create_summary / _acreate_summary to invoke it directly. This avoids temporarily swapping the shared self.model, which would otherwise leak the RunnableBinding across concurrent runs and break parent logic that inspects the raw model (profile / _get_ls_params). Add regression tests covering nostream tagging, concurrent-run isolation, raw model preservation, and existing-tag merge. * fix(summarization): address nostream review feedback
444 lines
18 KiB
Python
444 lines
18 KiB
Python
"""Summarization middleware extensions for DeerFlow."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Collection
|
|
from dataclasses import dataclass
|
|
from typing import Any, Protocol, override, runtime_checkable
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import SummarizationMiddleware
|
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage, get_buffer_string
|
|
from langgraph.config import get_config
|
|
from langgraph.constants import TAG_NOSTREAM
|
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
|
from langgraph.runtime import Runtime
|
|
|
|
from deerflow.agents.middlewares.dynamic_context_middleware import is_dynamic_context_reminder
|
|
from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SummarizationEvent:
|
|
"""Context emitted before conversation history is summarized away."""
|
|
|
|
messages_to_summarize: tuple[AnyMessage, ...]
|
|
preserved_messages: tuple[AnyMessage, ...]
|
|
thread_id: str | None
|
|
agent_name: str | None
|
|
runtime: Runtime
|
|
|
|
|
|
@runtime_checkable
|
|
class BeforeSummarizationHook(Protocol):
|
|
"""Hook invoked before summarization removes messages from state."""
|
|
|
|
def __call__(self, event: SummarizationEvent) -> None: ...
|
|
|
|
|
|
def _resolve_thread_id(runtime: Runtime) -> str | None:
|
|
"""Resolve the current thread ID from runtime context or LangGraph config."""
|
|
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
|
if thread_id is None:
|
|
try:
|
|
config_data = get_config()
|
|
except RuntimeError:
|
|
return None
|
|
thread_id = config_data.get("configurable", {}).get("thread_id")
|
|
return thread_id
|
|
|
|
|
|
def _resolve_agent_name(runtime: Runtime) -> str | None:
|
|
"""Resolve the current agent name from runtime context or LangGraph config."""
|
|
agent_name = runtime.context.get("agent_name") if runtime.context else None
|
|
if agent_name is None:
|
|
try:
|
|
config_data = get_config()
|
|
except RuntimeError:
|
|
return None
|
|
agent_name = config_data.get("configurable", {}).get("agent_name")
|
|
return agent_name
|
|
|
|
|
|
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
|
|
"""Best-effort extraction of a file path argument from a read_file-like tool call."""
|
|
args = tool_call.get("args") or {}
|
|
if not isinstance(args, dict):
|
|
return None
|
|
for key in ("path", "file_path", "filepath"):
|
|
value = args.get(key)
|
|
if isinstance(value, str) and value:
|
|
return value
|
|
return None
|
|
|
|
|
|
def _clone_ai_message(
|
|
message: AIMessage,
|
|
tool_calls: list[dict[str, Any]],
|
|
*,
|
|
content: Any | None = None,
|
|
) -> AIMessage:
|
|
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
|
|
return clone_ai_message_with_tool_calls(message, tool_calls, content=content)
|
|
|
|
|
|
@dataclass
|
|
class _SkillBundle:
|
|
"""Skill-related tool calls and tool results associated with one AIMessage."""
|
|
|
|
ai_index: int
|
|
skill_tool_indices: tuple[int, ...]
|
|
skill_tool_call_ids: frozenset[str]
|
|
skill_tool_tokens: int
|
|
skill_key: str
|
|
|
|
|
|
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
|
"""Summarization middleware with pre-compression hook dispatch and skill rescue."""
|
|
|
|
def __init__(
|
|
self,
|
|
*args,
|
|
skills_container_path: str | None = None,
|
|
skill_file_read_tool_names: Collection[str] | None = None,
|
|
before_summarization: list[BeforeSummarizationHook] | None = None,
|
|
preserve_recent_skill_count: int = 5,
|
|
preserve_recent_skill_tokens: int = 25_000,
|
|
preserve_recent_skill_tokens_per_skill: int = 5_000,
|
|
**kwargs,
|
|
) -> None:
|
|
super().__init__(*args, **kwargs)
|
|
self._skills_container_path = skills_container_path or "/mnt/skills"
|
|
self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
|
|
self._before_summarization_hooks = before_summarization or []
|
|
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
|
|
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
|
|
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
|
|
# The summary LLM call runs inside a LangGraph middleware hook, so its token
|
|
# stream would otherwise be captured by the messages-tuple stream callback and
|
|
# broadcast to the frontend as a phantom AI message. Tag a dedicated model copy
|
|
# with TAG_NOSTREAM so the streaming handler skips it.
|
|
# Keep self.model untagged so the parent's profile / ls_params inspection still works.
|
|
#
|
|
# Preserve any tags already bound on the model (e.g. "middleware:summarize" set in
|
|
# lead_agent/agent.py for RunJournal attribution): RunnableBinding.with_config does a
|
|
# shallow merge that would otherwise overwrite the existing tags list entirely.
|
|
existing_tags = list((getattr(self.model, "config", None) or {}).get("tags") or [])
|
|
merged_tags = [*existing_tags, TAG_NOSTREAM] if TAG_NOSTREAM not in existing_tags else existing_tags
|
|
self._summary_model = self.model.with_config(tags=merged_tags)
|
|
|
|
@override
|
|
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
return self._summarize_with(messages_to_summarize)
|
|
|
|
@override
|
|
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
return await self._asummarize_with(messages_to_summarize)
|
|
|
|
def _summarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
"""Mirror the parent ``_create_summary`` but invoke the nostream-tagged model.
|
|
|
|
We do not swap ``self.model`` at the instance level: the agent/middleware is
|
|
cached and reused across concurrent runs, so a temporary swap would leak the
|
|
``RunnableBinding`` to other coroutines during ``await`` and break parent logic
|
|
that inspects the raw model (``profile`` / ``_get_ls_params``).
|
|
"""
|
|
if not messages_to_summarize:
|
|
return "No previous conversation history."
|
|
prompt = self._build_summary_prompt(messages_to_summarize)
|
|
if prompt is None:
|
|
return "Previous conversation was too long to summarize."
|
|
try:
|
|
response = self._summary_model.invoke(
|
|
prompt,
|
|
config={"metadata": {"lc_source": "summarization"}},
|
|
)
|
|
return response.text.strip()
|
|
except Exception as e:
|
|
return f"Error generating summary: {e!s}"
|
|
|
|
async def _asummarize_with(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
"""Async counterpart of :meth:`_summarize_with` using the nostream model."""
|
|
if not messages_to_summarize:
|
|
return "No previous conversation history."
|
|
prompt = self._build_summary_prompt(messages_to_summarize)
|
|
if prompt is None:
|
|
return "Previous conversation was too long to summarize."
|
|
try:
|
|
response = await self._summary_model.ainvoke(
|
|
prompt,
|
|
config={"metadata": {"lc_source": "summarization"}},
|
|
)
|
|
return response.text.strip()
|
|
except Exception as e:
|
|
return f"Error generating summary: {e!s}"
|
|
|
|
def _build_summary_prompt(self, messages_to_summarize: list[AnyMessage]) -> str | None:
|
|
"""Build the summary prompt, returning ``None`` when trimming leaves nothing."""
|
|
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
|
|
if not trimmed_messages:
|
|
return None
|
|
# Format messages to avoid token inflation from metadata when str() is called on
|
|
# message objects.
|
|
formatted_messages = get_buffer_string(trimmed_messages)
|
|
return self.summary_prompt.format(messages=formatted_messages).rstrip()
|
|
|
|
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
return self._maybe_summarize(state, runtime)
|
|
|
|
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
return await self._amaybe_summarize(state, runtime)
|
|
|
|
def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
messages = state["messages"]
|
|
self._ensure_message_ids(messages)
|
|
|
|
total_tokens = self.token_counter(messages)
|
|
if not self._should_summarize(messages, total_tokens):
|
|
return None
|
|
|
|
cutoff_index = self._determine_cutoff_index(messages)
|
|
if cutoff_index <= 0:
|
|
return None
|
|
|
|
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
|
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
|
summary = self._create_summary(messages_to_summarize)
|
|
new_messages = self._build_new_messages(summary)
|
|
|
|
return {
|
|
"messages": [
|
|
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
|
*new_messages,
|
|
*preserved_messages,
|
|
]
|
|
}
|
|
|
|
async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
|
|
messages = state["messages"]
|
|
self._ensure_message_ids(messages)
|
|
|
|
total_tokens = self.token_counter(messages)
|
|
if not self._should_summarize(messages, total_tokens):
|
|
return None
|
|
|
|
cutoff_index = self._determine_cutoff_index(messages)
|
|
if cutoff_index <= 0:
|
|
return None
|
|
|
|
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
|
|
messages_to_summarize, preserved_messages = self._preserve_dynamic_context_reminders(messages_to_summarize, preserved_messages)
|
|
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
|
|
summary = await self._acreate_summary(messages_to_summarize)
|
|
new_messages = self._build_new_messages(summary)
|
|
|
|
return {
|
|
"messages": [
|
|
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
|
*new_messages,
|
|
*preserved_messages,
|
|
]
|
|
}
|
|
|
|
@override
|
|
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
|
"""Override the base implementation to let the human message with the special name 'summary'.
|
|
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
|
"""
|
|
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
|
|
|
def _preserve_dynamic_context_reminders(
|
|
self,
|
|
messages_to_summarize: list[AnyMessage],
|
|
preserved_messages: list[AnyMessage],
|
|
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
|
"""Keep hidden dynamic-context reminders out of summary compression.
|
|
|
|
These reminders carry the current date and optional memory. If summarization
|
|
removes them, DynamicContextMiddleware can mistake the summary HumanMessage
|
|
for the first user message and inject the reminder in the wrong place.
|
|
"""
|
|
reminders = [msg for msg in messages_to_summarize if is_dynamic_context_reminder(msg)]
|
|
if not reminders:
|
|
return messages_to_summarize, preserved_messages
|
|
|
|
remaining = [msg for msg in messages_to_summarize if not is_dynamic_context_reminder(msg)]
|
|
return remaining, reminders + preserved_messages
|
|
|
|
def _partition_with_skill_rescue(
|
|
self,
|
|
messages: list[AnyMessage],
|
|
cutoff_index: int,
|
|
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
|
"""Partition like the parent, then rescue recently-loaded skill bundles."""
|
|
to_summarize, preserved = self._partition_messages(messages, cutoff_index)
|
|
|
|
if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
|
|
return to_summarize, preserved
|
|
|
|
try:
|
|
bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
|
|
except Exception:
|
|
logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
|
|
return to_summarize, preserved
|
|
|
|
if not bundles:
|
|
return to_summarize, preserved
|
|
|
|
rescue_bundles = self._select_bundles_to_rescue(bundles)
|
|
if not rescue_bundles:
|
|
return to_summarize, preserved
|
|
|
|
bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
|
|
rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
|
|
rescued: list[AnyMessage] = []
|
|
remaining: list[AnyMessage] = []
|
|
for i, msg in enumerate(to_summarize):
|
|
bundle = bundles_by_ai_index.get(i)
|
|
if bundle is not None and isinstance(msg, AIMessage):
|
|
rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
|
|
remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
|
|
|
|
if rescued_tool_calls:
|
|
rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
|
|
if remaining_tool_calls or msg.content:
|
|
remaining.append(_clone_ai_message(msg, remaining_tool_calls))
|
|
continue
|
|
|
|
if i in rescue_tool_indices:
|
|
rescued.append(msg)
|
|
continue
|
|
|
|
remaining.append(msg)
|
|
|
|
return remaining, rescued + preserved
|
|
|
|
def _find_skill_bundles(
|
|
self,
|
|
messages: list[AnyMessage],
|
|
skills_root: str,
|
|
) -> list[_SkillBundle]:
|
|
"""Locate AIMessage + paired ToolMessage groups that load skill files."""
|
|
bundles: list[_SkillBundle] = []
|
|
n = len(messages)
|
|
i = 0
|
|
while i < n:
|
|
msg = messages[i]
|
|
if not (isinstance(msg, AIMessage) and msg.tool_calls):
|
|
i += 1
|
|
continue
|
|
|
|
tool_calls = list(msg.tool_calls)
|
|
skill_paths_by_id: dict[str, str] = {}
|
|
for tc in tool_calls:
|
|
if self._is_skill_tool_call(tc, skills_root):
|
|
tc_id = tc.get("id")
|
|
path = _tool_call_path(tc)
|
|
if tc_id and path:
|
|
skill_paths_by_id[tc_id] = path
|
|
|
|
if not skill_paths_by_id:
|
|
i += 1
|
|
continue
|
|
|
|
skill_tool_tokens = 0
|
|
skill_key_parts: list[str] = []
|
|
skill_tool_indices: list[int] = []
|
|
matched_skill_call_ids: set[str] = set()
|
|
|
|
j = i + 1
|
|
while j < n and isinstance(messages[j], ToolMessage):
|
|
j += 1
|
|
|
|
for k in range(i + 1, j):
|
|
tool_msg = messages[k]
|
|
if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
|
|
skill_tool_tokens += self.token_counter([tool_msg])
|
|
skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
|
|
skill_tool_indices.append(k)
|
|
matched_skill_call_ids.add(tool_msg.tool_call_id)
|
|
|
|
if not skill_tool_indices:
|
|
i = j
|
|
continue
|
|
|
|
bundles.append(
|
|
_SkillBundle(
|
|
ai_index=i,
|
|
skill_tool_indices=tuple(skill_tool_indices),
|
|
skill_tool_call_ids=frozenset(matched_skill_call_ids),
|
|
skill_tool_tokens=skill_tool_tokens,
|
|
skill_key="|".join(sorted(skill_key_parts)),
|
|
)
|
|
)
|
|
i = j
|
|
|
|
return bundles
|
|
|
|
def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
|
|
"""Pick bundles to keep, walking newest-first under count/token budgets."""
|
|
selected: list[_SkillBundle] = []
|
|
if not bundles:
|
|
return selected
|
|
|
|
seen_skill_keys: set[str] = set()
|
|
total_tokens = 0
|
|
kept = 0
|
|
|
|
for bundle in reversed(bundles):
|
|
if kept >= self._preserve_recent_skill_count:
|
|
break
|
|
if bundle.skill_key in seen_skill_keys:
|
|
continue
|
|
if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
|
|
continue
|
|
if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
|
|
continue
|
|
|
|
selected.append(bundle)
|
|
total_tokens += bundle.skill_tool_tokens
|
|
kept += 1
|
|
seen_skill_keys.add(bundle.skill_key)
|
|
|
|
selected.reverse()
|
|
return selected
|
|
|
|
def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
|
|
"""Return True when ``tool_call`` reads a file under the configured skills root."""
|
|
name = tool_call.get("name") or ""
|
|
if name not in self._skill_file_read_tool_names:
|
|
return False
|
|
path = _tool_call_path(tool_call)
|
|
if not path:
|
|
return False
|
|
normalized_root = skills_root.rstrip("/")
|
|
return path == normalized_root or path.startswith(normalized_root + "/")
|
|
|
|
def _fire_hooks(
|
|
self,
|
|
messages_to_summarize: list[AnyMessage],
|
|
preserved_messages: list[AnyMessage],
|
|
runtime: Runtime,
|
|
) -> None:
|
|
if not self._before_summarization_hooks:
|
|
return
|
|
|
|
event = SummarizationEvent(
|
|
messages_to_summarize=tuple(messages_to_summarize),
|
|
preserved_messages=tuple(preserved_messages),
|
|
thread_id=_resolve_thread_id(runtime),
|
|
agent_name=_resolve_agent_name(runtime),
|
|
runtime=runtime,
|
|
)
|
|
|
|
for hook in self._before_summarization_hooks:
|
|
try:
|
|
hook(event)
|
|
except Exception:
|
|
hook_name = getattr(hook, "__name__", None) or type(hook).__name__
|
|
logger.exception("before_summarization hook %s failed", hook_name)
|