refactor(backend): simplify DeerFlowClient streaming helpers (#1969)

Post-review cleanup for the token-level streaming fix. No behavior
change for correct inputs; one efficiency regression fixed.

Fix: chat() O(n²) accumulator
-----------------------------
`chat()` accumulated per-id text via `buffers[id] = buffers.get(id,"") + delta`,
which is O(n) per concat → O(n²) total over a streamed response. At
~2 KB cumulative text this becomes user-visible; at 50 KB / 5000 chunks
it costs roughly 100-300 ms of pure copying. Switched to
`dict[str, list[str]]` + `"".join()` once at return.

Cleanup
-------
- Extract `_serialize_tool_calls`, `_ai_text_event`, `_ai_tool_calls_event`,
  and `_tool_message_event` static helpers. The messages-mode and
  values-mode branches previously repeated four inline dict literals each;
  they now call the same builders.
- `StreamEvent.type` is now typed as `Literal["values", "messages-tuple",
  "custom", "end"]` via a `StreamEventType` alias. Makes the closed set
  explicit and catches typos at type-check time.
- Direct attribute access on `AIMessage`/`AIMessageChunk`: `.usage_metadata`,
  `.tool_calls`, `.id` all have default values on the base class, so the
  `getattr(..., None)` fallbacks were dead code. Removed from the hot
  path.
- `_account_usage` parameter type loosened to `Any` so that LangChain's
  `UsageMetadata` TypedDict is accepted under strict type checking.
- Trimmed narrating comments on `seen_ids` / `streamed_ids` / the
  values-synthesis skip block; kept the non-obvious ones that document
  the cross-mode dedup invariant.

Net diff: -15 lines. All 132 unit tests + harness boundary test still
pass; ruff check and ruff format pass.
This commit is contained in:
greatmengqi
2026-04-08 10:53:24 +08:00
parent f3486bb37d
commit 1f11ba10a9
+74 -89
View File
@@ -25,7 +25,7 @@ import uuid
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Literal
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
@@ -55,6 +55,9 @@ from deerflow.uploads.manager import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
StreamEventType = Literal["values", "messages-tuple", "custom", "end"]
@dataclass @dataclass
class StreamEvent: class StreamEvent:
"""A single event from the streaming agent response. """A single event from the streaming agent response.
@@ -69,7 +72,7 @@ class StreamEvent:
data: Event payload. Contents vary by type. data: Event payload. Contents vary by type.
""" """
type: str type: StreamEventType
data: dict[str, Any] = field(default_factory=dict) data: dict[str, Any] = field(default_factory=dict)
@@ -254,13 +257,53 @@ class DeerFlowClient:
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
@staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]:
"""Reshape LangChain tool_calls into the wire format used in events."""
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage:
data["usage_metadata"] = usage
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
"""Build a ``messages-tuple`` AI tool-calls event."""
return StreamEvent(
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
},
)
@staticmethod
def _tool_message_event(msg) -> "StreamEvent":
"""Build a ``messages-tuple`` tool-result event from a ToolMessage."""
return StreamEvent(
type="messages-tuple",
data={
"type": "tool",
"content": DeerFlowClient._extract_text(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": getattr(msg, "id", None),
},
)
@staticmethod @staticmethod
def _serialize_message(msg) -> dict: def _serialize_message(msg) -> dict:
"""Serialize a LangChain message to a plain dict for values events.""" """Serialize a LangChain message to a plain dict for values events."""
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)} d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
if msg.tool_calls: if msg.tool_calls:
d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls] d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
if getattr(msg, "usage_metadata", None): if getattr(msg, "usage_metadata", None):
d["usage_metadata"] = msg.usage_metadata d["usage_metadata"] = msg.usage_metadata
return d return d
@@ -409,28 +452,24 @@ class DeerFlowClient:
if self._agent_name: if self._agent_name:
context["agent_name"] = self._agent_name context["agent_name"] = self._agent_name
# ids already emitted as a complete message via the ``values``
# snapshot path — used by the values path itself to avoid
# duplicate per-message synthesis when the same message appears
# in consecutive snapshots.
seen_ids: set[str] = set() seen_ids: set[str] = set()
# ids whose text / tool_calls have already been streamed via the # Cross-mode handoff: ids already streamed via LangGraph ``messages``
# LangGraph ``messages`` mode. The ``values`` path uses this set # mode so the ``values`` path skips re-synthesis of the same message.
# to skip re-emitting synthesized messages-tuple events for the
# same message.
streamed_ids: set[str] = set() streamed_ids: set[str] = set()
# ids whose ``usage_metadata`` has already been counted into # The same message id carries identical cumulative ``usage_metadata``
# ``cumulative_usage``. The same message id shows up both in # in both the final ``messages`` chunk and the values snapshot —
# ``messages`` chunks (last chunk carries usage) and in ``values`` # count it only on whichever arrives first.
# snapshots (final AIMessage carries the same usage) — count once.
counted_usage_ids: set[str] = set() counted_usage_ids: set[str] = set()
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _account_usage(msg_id: str | None, usage: dict | None) -> dict | None: def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
"""Add *usage* to cumulative totals if this id has not been counted. """Add *usage* to cumulative totals if this id has not been counted.
Returns the normalized usage dict (for attaching to an event) ``usage`` is a ``langchain_core.messages.UsageMetadata`` TypedDict
when we accepted it, otherwise ``None``. or ``None``; typed as ``Any`` because TypedDicts are not
structurally assignable to plain ``dict`` under strict type
checking. Returns the normalized usage dict (for attaching
to an event) when we accepted it, otherwise ``None``.
""" """
if not usage: if not usage:
return None return None
@@ -467,11 +506,7 @@ class DeerFlowClient:
continue continue
if mode == "messages": if mode == "messages":
# LangGraph emits ``(message_chunk, metadata_dict)`` for # LangGraph ``messages`` mode emits ``(message_chunk, metadata)``.
# each LLM delta and each tool message. ``message_chunk``
# is typically an ``AIMessageChunk`` (subclass of
# ``AIMessage``) during LLM streaming; for tool nodes it
# is a ``ToolMessage``.
if isinstance(chunk, tuple) and len(chunk) == 2: if isinstance(chunk, tuple) and len(chunk) == 2:
msg_chunk, _metadata = chunk msg_chunk, _metadata = chunk
else: else:
@@ -481,44 +516,22 @@ class DeerFlowClient:
if isinstance(msg_chunk, AIMessage): if isinstance(msg_chunk, AIMessage):
text = self._extract_text(msg_chunk.content) text = self._extract_text(msg_chunk.content)
usage = getattr(msg_chunk, "usage_metadata", None) counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
counted_usage = _account_usage(msg_id, usage)
if text: if text:
if msg_id: if msg_id:
streamed_ids.add(msg_id) streamed_ids.add(msg_id)
event_data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id} yield self._ai_text_event(msg_id, text, counted_usage)
if counted_usage:
event_data["usage_metadata"] = counted_usage
yield StreamEvent(type="messages-tuple", data=event_data)
tool_calls = getattr(msg_chunk, "tool_calls", None) if msg_chunk.tool_calls:
if tool_calls:
if msg_id: if msg_id:
streamed_ids.add(msg_id) streamed_ids.add(msg_id)
yield StreamEvent( yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls],
},
)
elif isinstance(msg_chunk, ToolMessage): elif isinstance(msg_chunk, ToolMessage):
if msg_id: if msg_id:
streamed_ids.add(msg_id) streamed_ids.add(msg_id)
yield StreamEvent( yield self._tool_message_event(msg_chunk)
type="messages-tuple",
data={
"type": "tool",
"content": self._extract_text(msg_chunk.content),
"name": getattr(msg_chunk, "name", None),
"tool_call_id": getattr(msg_chunk, "tool_call_id", None),
"id": msg_id,
},
)
continue continue
# mode == "values" # mode == "values"
@@ -531,50 +544,25 @@ class DeerFlowClient:
if msg_id: if msg_id:
seen_ids.add(msg_id) seen_ids.add(msg_id)
# Already streamed through ``messages`` mode — capture # Already streamed via ``messages`` mode; only (defensively)
# usage once more (defensive: the final AIMessage in the # capture usage here and skip re-synthesizing the event.
# snapshot may carry usage_metadata that the streamed
# chunks did not) but skip synthesizing messages-tuple
# events, which would duplicate what the consumer already
# received.
if msg_id and msg_id in streamed_ids: if msg_id and msg_id in streamed_ids:
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
_account_usage(msg_id, getattr(msg, "usage_metadata", None)) _account_usage(msg_id, getattr(msg, "usage_metadata", None))
continue continue
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
usage = getattr(msg, "usage_metadata", None) counted_usage = _account_usage(msg_id, msg.usage_metadata)
counted_usage = _account_usage(msg_id, usage)
if msg.tool_calls: if msg.tool_calls:
yield StreamEvent( yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls],
},
)
text = self._extract_text(msg.content) text = self._extract_text(msg.content)
if text: if text:
event_data = {"type": "ai", "content": text, "id": msg_id} yield self._ai_text_event(msg_id, text, counted_usage)
if counted_usage:
event_data["usage_metadata"] = counted_usage
yield StreamEvent(type="messages-tuple", data=event_data)
elif isinstance(msg, ToolMessage): elif isinstance(msg, ToolMessage):
yield StreamEvent( yield self._tool_message_event(msg)
type="messages-tuple",
data={
"type": "tool",
"content": self._extract_text(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": msg_id,
},
)
# Emit a values event for each state snapshot # Emit a values event for each state snapshot
yield StreamEvent( yield StreamEvent(
@@ -607,21 +595,18 @@ class DeerFlowClient:
The accumulated text of the last AI message, or empty string The accumulated text of the last AI message, or empty string
if no AI text was produced. if no AI text was produced.
""" """
# Accumulator keyed by message id. Token-level streaming yields # Per-id delta lists joined once at the end — avoids the O(n²) cost
# multiple ``messages-tuple`` events sharing the same id, each # of repeated ``str + str`` on a growing buffer for long responses.
# carrying a delta that must be concatenated. Non-streaming mock chunks: dict[str, list[str]] = {}
# sources that emit a single event per id are a degenerate case
# of the same logic.
buffers: dict[str, str] = {}
last_id: str = "" last_id: str = ""
for event in self.stream(message, thread_id=thread_id, **kwargs): for event in self.stream(message, thread_id=thread_id, **kwargs):
if event.type == "messages-tuple" and event.data.get("type") == "ai": if event.type == "messages-tuple" and event.data.get("type") == "ai":
msg_id = event.data.get("id") or "" msg_id = event.data.get("id") or ""
delta = event.data.get("content", "") delta = event.data.get("content", "")
if delta: if delta:
buffers[msg_id] = buffers.get(msg_id, "") + delta chunks.setdefault(msg_id, []).append(delta)
last_id = msg_id last_id = msg_id
return buffers.get(last_id, "") return "".join(chunks.get(last_id, ()))
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API — configuration queries # Public API — configuration queries