"""Summarization middleware extensions for DeerFlow.""" from __future__ import annotations import logging from dataclasses import dataclass from typing import Protocol, runtime_checkable from langchain.agents import AgentState from langchain.agents.middleware import SummarizationMiddleware from langchain_core.messages import AnyMessage, RemoveMessage from langgraph.config import get_config from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @dataclass(frozen=True) class SummarizationEvent: """Context emitted before conversation history is summarized away.""" messages_to_summarize: tuple[AnyMessage, ...] preserved_messages: tuple[AnyMessage, ...] thread_id: str | None agent_name: str | None runtime: Runtime @runtime_checkable class BeforeSummarizationHook(Protocol): """Hook invoked before summarization removes messages from state.""" def __call__(self, event: SummarizationEvent) -> None: ... def _resolve_thread_id(runtime: Runtime) -> str | None: """Resolve the current thread ID from runtime context or LangGraph config.""" thread_id = runtime.context.get("thread_id") if runtime.context else None if thread_id is None: try: config_data = get_config() except RuntimeError: return None thread_id = config_data.get("configurable", {}).get("thread_id") return thread_id def _resolve_agent_name(runtime: Runtime) -> str | None: """Resolve the current agent name from runtime context or LangGraph config.""" agent_name = runtime.context.get("agent_name") if runtime.context else None if agent_name is None: try: config_data = get_config() except RuntimeError: return None agent_name = config_data.get("configurable", {}).get("agent_name") return agent_name class DeerFlowSummarizationMiddleware(SummarizationMiddleware): """Summarization middleware with pre-compression hook dispatch.""" def __init__( self, *args, before_summarization: list[BeforeSummarizationHook] | None = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) self._before_summarization_hooks = before_summarization or [] def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._maybe_summarize(state, runtime) async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None: return await self._amaybe_summarize(state, runtime) def _maybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None: messages = state["messages"] self._ensure_message_ids(messages) total_tokens = self.token_counter(messages) if not self._should_summarize(messages, total_tokens): return None cutoff_index = self._determine_cutoff_index(messages) if cutoff_index <= 0: return None messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index) self._fire_hooks(messages_to_summarize, preserved_messages, runtime) summary = self._create_summary(messages_to_summarize) new_messages = self._build_new_messages(summary) return { "messages": [ RemoveMessage(id=REMOVE_ALL_MESSAGES), *new_messages, *preserved_messages, ] } async def _amaybe_summarize(self, state: AgentState, runtime: Runtime) -> dict | None: messages = state["messages"] self._ensure_message_ids(messages) total_tokens = self.token_counter(messages) if not self._should_summarize(messages, total_tokens): return None cutoff_index = self._determine_cutoff_index(messages) if cutoff_index <= 0: return None messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index) self._fire_hooks(messages_to_summarize, preserved_messages, runtime) summary = await self._acreate_summary(messages_to_summarize) new_messages = self._build_new_messages(summary) return { "messages": [ RemoveMessage(id=REMOVE_ALL_MESSAGES), *new_messages, *preserved_messages, ] } def _fire_hooks( self, messages_to_summarize: list[AnyMessage], preserved_messages: list[AnyMessage], runtime: Runtime, ) -> None: if not self._before_summarization_hooks: return event = SummarizationEvent( messages_to_summarize=tuple(messages_to_summarize), preserved_messages=tuple(preserved_messages), thread_id=_resolve_thread_id(runtime), agent_name=_resolve_agent_name(runtime), runtime=runtime, ) for hook in self._before_summarization_hooks: try: hook(event) except Exception: hook_name = getattr(hook, "__name__", None) or type(hook).__name__ logger.exception("before_summarization hook %s failed", hook_name)