"""Summarization middleware extensions for DeerFlow.""" from __future__ import annotations import logging from collections.abc import Collection from dataclasses import dataclass from typing import Any, Protocol, override, runtime_checkable from langchain.agents import AgentState from langchain.agents.middleware import SummarizationMiddleware from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage from langgraph.config import get_config from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @dataclass(frozen=True) class SummarizationEvent: """Context emitted before conversation history is summarized away.""" messages_to_summarize: tuple[AnyMessage, ...] preserved_messages: tuple[AnyMessage, ...] thread_id: str | None agent_name: str | None runtime: Runtime @runtime_checkable class BeforeSummarizationHook(Protocol): """Hook invoked before summarization removes messages from state.""" def __call__(self, event: SummarizationEvent) -> None: ... 
def _lookup_context_value(runtime: Runtime, key: str) -> Any | None:
    """Look up ``key`` in the runtime context, falling back to the LangGraph config.

    The runtime context is queried with ``.get`` when it exposes one (dict-like
    contexts) and with attribute access otherwise, so typed context objects do not
    raise. If the context yields ``None``, ``config["configurable"][key]`` from the
    active LangGraph config is used instead.

    Returns:
        The resolved value, or ``None`` when neither source has it or when called
        outside a runnable context (``get_config`` raises ``RuntimeError`` there).
    """
    context = getattr(runtime, "context", None)
    value: Any | None = None
    if context:
        getter = getattr(context, "get", None)
        value = getter(key) if callable(getter) else getattr(context, key, None)
    if value is not None:
        return value

    try:
        config_data = get_config()
    except RuntimeError:
        return None
    # ``configurable`` may be missing or explicitly None.
    configurable = config_data.get("configurable") or {}
    return configurable.get(key)


def _resolve_thread_id(runtime: Runtime) -> str | None:
    """Resolve the current thread ID from runtime context or LangGraph config."""
    return _lookup_context_value(runtime, "thread_id")


def _resolve_agent_name(runtime: Runtime) -> str | None:
    """Resolve the current agent name from runtime context or LangGraph config."""
    return _lookup_context_value(runtime, "agent_name")


def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
    """Best-effort extraction of a file path argument from a read_file-like tool call.

    Checks the ``path``, ``file_path`` and ``filepath`` args in that order and
    returns the first non-empty string, or ``None`` if there is none.
    """
    args = tool_call.get("args") or {}
    if not isinstance(args, dict):
        return None
    for key in ("path", "file_path", "filepath"):
        value = args.get(key)
        if isinstance(value, str) and value:
            return value
    return None


def _clone_ai_message(
    message: AIMessage,
    tool_calls: list[dict[str, Any]],
    *,
    content: Any | None = None,
    invalid_tool_calls: list[Any] | None = None,
) -> AIMessage:
    """Clone an AIMessage while replacing its tool calls and, optionally, its content.

    Args:
        message: Source message; it is not modified.
        tool_calls: Replacement ``tool_calls`` list.
        content: Replacement content; ``None`` keeps the original content.
        invalid_tool_calls: Replacement ``invalid_tool_calls``; ``None`` keeps the
            original list.

    Returns:
        A ``model_copy`` of ``message`` (same ``id``) with the requested fields replaced.
    """
    update: dict[str, Any] = {"tool_calls": tool_calls}
    if content is not None:
        update["content"] = content
    if invalid_tool_calls is not None:
        update["invalid_tool_calls"] = invalid_tool_calls
    return message.model_copy(update=update)


@dataclass
class _SkillBundle:
    """Skill-related tool calls and tool results associated with one AIMessage."""

    # Index of the AIMessage inside the list that was scanned.
    ai_index: int
    # Indices of the ToolMessages answering this AIMessage's skill-file reads.
    skill_tool_indices: tuple[int, ...]
    # tool_call ids of the skill reads that actually have a ToolMessage result.
    skill_tool_call_ids: frozenset[str]
    # Token count (via the middleware's token_counter) of those ToolMessages.
    skill_tool_tokens: int
    # Sorted, "|"-joined file paths; used to de-duplicate repeated reads.
    skill_key: str


class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
    """Summarization middleware with pre-compression hook dispatch and skill rescue.

    In addition to the base ``SummarizationMiddleware`` behavior, this class:

    * calls every ``before_summarization`` hook with a ``SummarizationEvent``
      before history is replaced (hook errors are logged, never raised), and
    * "rescues" recent skill-file reads (an AIMessage read-like tool call whose
      path lies under ``skills_container_path``, plus its ToolMessage result) out
      of the span being summarized, keeping them verbatim within count and token
      budgets.
    """

    def __init__(
        self,
        *args,
        skills_container_path: str | None = None,
        skill_file_read_tool_names: Collection[str] | None = None,
        before_summarization: list[BeforeSummarizationHook] | None = None,
        preserve_recent_skill_count: int = 5,
        preserve_recent_skill_tokens: int = 25_000,
        preserve_recent_skill_tokens_per_skill: int = 5_000,
        **kwargs,
    ) -> None:
        """Create the middleware.

        Args:
            *args: Forwarded to ``SummarizationMiddleware``.
            skills_container_path: Root directory whose files count as skill files.
                Defaults to ``"/mnt/skills"`` when ``None`` or empty.
            skill_file_read_tool_names: Tool names treated as file reads. Defaults to
                ``read_file``/``read``/``view``/``cat`` when ``None`` or empty. A
                single string is treated as one tool name.
            before_summarization: Hooks invoked before history is summarized.
            preserve_recent_skill_count: Maximum number of skill bundles to rescue.
            preserve_recent_skill_tokens: Total token budget across rescued bundles.
            preserve_recent_skill_tokens_per_skill: Per-bundle token cap. Larger
                bundles are never rescued.
            **kwargs: Forwarded to ``SummarizationMiddleware``.

        Negative budgets are clamped to 0, and any 0 budget disables skill rescue.
        """
        super().__init__(*args, **kwargs)
        self._skills_container_path = skills_container_path or "/mnt/skills"
        if isinstance(skill_file_read_tool_names, str):
            # frozenset("read_file") would split the name into single characters.
            skill_file_read_tool_names = (skill_file_read_tool_names,)
        self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
        self._before_summarization_hooks = before_summarization or []
        self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
        self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
        self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)

    def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._maybe_summarize(state, runtime)

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return await self._amaybe_summarize(state, runtime)

    def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Synchronously summarize old history when the base trigger fires.

        Returns a state update replacing all messages, or ``None`` when no
        summarization is needed.
        """
        plan = self._plan_summarization(state, runtime)
        if plan is None:
            return None
        messages_to_summarize, preserved_messages = plan
        summary = self._create_summary(messages_to_summarize)
        return self._summary_state_update(summary, preserved_messages)

    async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async counterpart of ``_maybe_summarize``. Hooks are still invoked synchronously."""
        plan = self._plan_summarization(state, runtime)
        if plan is None:
            return None
        messages_to_summarize, preserved_messages = plan
        summary = await self._acreate_summary(messages_to_summarize)
        return self._summary_state_update(summary, preserved_messages)

    def _plan_summarization(
        self,
        state: AgentState,
        runtime: Runtime,
    ) -> tuple[list[AnyMessage], list[AnyMessage]] | None:
        """Decide whether to summarize, partition the history, and fire the hooks.

        Returns ``(messages_to_summarize, preserved_messages)``, or ``None`` when
        the trigger has not fired or there is nothing before the cutoff.
        """
        messages = state["messages"]
        self._ensure_message_ids(messages)

        total_tokens = self.token_counter(messages)
        if not self._should_summarize(messages, total_tokens):
            return None

        cutoff_index = self._determine_cutoff_index(messages)
        if cutoff_index <= 0:
            return None

        messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
        self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
        return messages_to_summarize, preserved_messages

    def _summary_state_update(self, summary: str, preserved_messages: list[AnyMessage]) -> dict:
        """Build the state update that replaces all messages with summary + preserved."""
        new_messages = self._build_new_messages(summary)
        return {
            "messages": [
                RemoveMessage(id=REMOVE_ALL_MESSAGES),
                *new_messages,
                *preserved_messages,
            ]
        }

    @override
    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
        """Wrap the summary in a HumanMessage carrying the special name ``"summary"``.

        The frontend hides messages with this name, but the model still receives
        them as context.
        """
        return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]

    def _partition_with_skill_rescue(
        self,
        messages: list[AnyMessage],
        cutoff_index: int,
    ) -> tuple[list[AnyMessage], list[AnyMessage]]:
        """Partition like the parent, then rescue recently-loaded skill bundles.

        Rescued AIMessages are split. A clone holding only the rescued skill tool
        calls (empty content) is moved, together with its ToolMessages, in front
        of the preserved messages. A clone holding the remaining tool calls and
        content stays in the summarized span.

        Falls back to the parent's partition when rescue is disabled, finds
        nothing, fails, or would leave nothing to summarize.
        """
        to_summarize, preserved = self._partition_messages(messages, cutoff_index)
        if (
            self._preserve_recent_skill_count == 0
            or self._preserve_recent_skill_tokens == 0
            or self._preserve_recent_skill_tokens_per_skill == 0
            or not to_summarize
        ):
            return to_summarize, preserved

        try:
            bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
        except Exception:
            logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
            return to_summarize, preserved

        rescue_bundles = self._select_bundles_to_rescue(bundles)
        if not rescue_bundles:
            return to_summarize, preserved

        bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
        rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}

        rescued: list[AnyMessage] = []
        remaining: list[AnyMessage] = []
        for i, msg in enumerate(to_summarize):
            bundle = bundles_by_ai_index.get(i)
            if bundle is not None and isinstance(msg, AIMessage):
                rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
                remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
                if rescued_tool_calls:
                    # Drop invalid_tool_calls: any results for them stay in the
                    # summarized span, so keeping them would leave unanswered calls.
                    rescued.append(_clone_ai_message(msg, rescued_tool_calls, content="", invalid_tool_calls=[]))
                if remaining_tool_calls or msg.content:
                    remaining.append(_clone_ai_message(msg, remaining_tool_calls))
                continue
            if i in rescue_tool_indices:
                rescued.append(msg)
                continue
            remaining.append(msg)

        if not remaining:
            # Everything before the cutoff is a rescued bundle. Summarizing an empty
            # span would not shrink the history, so use the default partition instead.
            logger.debug("Skill rescue would leave nothing to summarize; using default partition")
            return to_summarize, preserved

        return remaining, rescued + preserved

    def _find_skill_bundles(
        self,
        messages: list[AnyMessage],
        skills_root: str,
    ) -> list[_SkillBundle]:
        """Locate AIMessage + paired ToolMessage groups that load skill files.

        An AIMessage's results are taken to be the contiguous run of ToolMessages
        that immediately follows it. Only skill reads that have a matching
        ToolMessage in that run are included in a bundle.
        """
        bundles: list[_SkillBundle] = []
        n = len(messages)
        i = 0
        while i < n:
            msg = messages[i]
            if not (isinstance(msg, AIMessage) and msg.tool_calls):
                i += 1
                continue

            # tool_call_id -> path for every skill-file read issued by this message.
            skill_paths_by_id: dict[str, str] = {}
            for tc in msg.tool_calls:
                if not self._is_skill_tool_call(tc, skills_root):
                    continue
                tc_id = tc.get("id")
                path = _tool_call_path(tc)
                if tc_id and path:
                    skill_paths_by_id[tc_id] = path
            if not skill_paths_by_id:
                i += 1
                continue

            j = i + 1
            while j < n and isinstance(messages[j], ToolMessage):
                j += 1

            skill_tool_tokens = 0
            skill_key_parts: list[str] = []
            skill_tool_indices: list[int] = []
            matched_skill_call_ids: set[str] = set()
            for k in range(i + 1, j):
                tool_msg = messages[k]
                if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
                    skill_tool_tokens += self.token_counter([tool_msg])
                    skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
                    skill_tool_indices.append(k)
                    matched_skill_call_ids.add(tool_msg.tool_call_id)

            if skill_tool_indices:
                bundles.append(
                    _SkillBundle(
                        ai_index=i,
                        skill_tool_indices=tuple(skill_tool_indices),
                        skill_tool_call_ids=frozenset(matched_skill_call_ids),
                        skill_tool_tokens=skill_tool_tokens,
                        skill_key="|".join(sorted(skill_key_parts)),
                    )
                )
            i = j
        return bundles

    def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
        """Pick bundles to keep, walking newest-first under count/token budgets.

        Older reads of an already-selected skill key are skipped. Bundles that
        would exceed the per-skill or total token budget are skipped, so older,
        smaller bundles may still fit. The result is returned oldest-first.
        """
        selected: list[_SkillBundle] = []
        seen_skill_keys: set[str] = set()
        total_tokens = 0
        for bundle in reversed(bundles):
            if len(selected) >= self._preserve_recent_skill_count:
                break
            if bundle.skill_key in seen_skill_keys:
                continue
            if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
                continue
            if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
                continue
            selected.append(bundle)
            total_tokens += bundle.skill_tool_tokens
            seen_skill_keys.add(bundle.skill_key)
        selected.reverse()
        return selected

    def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
        """Return True when ``tool_call`` reads a file under the configured skills root."""
        name = tool_call.get("name") or ""
        if name not in self._skill_file_read_tool_names:
            return False
        path = _tool_call_path(tool_call)
        if not path:
            return False
        normalized_root = skills_root.rstrip("/")
        return path == normalized_root or path.startswith(normalized_root + "/")

    def _fire_hooks(
        self,
        messages_to_summarize: list[AnyMessage],
        preserved_messages: list[AnyMessage],
        runtime: Runtime,
    ) -> None:
        """Invoke every before-summarization hook with a ``SummarizationEvent``.

        Hooks run in registration order. An exception from one hook is logged
        and does not stop the others or the summarization.
        """
        if not self._before_summarization_hooks:
            return

        event = SummarizationEvent(
            messages_to_summarize=tuple(messages_to_summarize),
            preserved_messages=tuple(preserved_messages),
            thread_id=_resolve_thread_id(runtime),
            agent_name=_resolve_agent_name(runtime),
            runtime=runtime,
        )
        for hook in self._before_summarization_hooks:
            try:
                hook(event)
            except Exception:
                hook_name = getattr(hook, "__name__", None) or type(hook).__name__
                logger.exception("before_summarization hook %s failed", hook_name)