mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 15:11:09 +00:00
refactor(backend): simplify DeerFlowClient streaming helpers (#1969)
Post-review cleanup for the token-level streaming fix. No behavior change for correct inputs; one efficiency regression fixed.

Fix: chat() O(n²) accumulator
-----------------------------

`chat()` accumulated per-id text via `buffers[id] = buffers.get(id, "") + delta`, which is O(n) per concat → O(n²) total over a streamed response. At ~2 KB cumulative text this becomes user-visible; at 50 KB / 5000 chunks it costs roughly 100-300 ms of pure copying. Switched to `dict[str, list[str]]` + `"".join()` once at return.

Cleanup
-------

- Extract `_serialize_tool_calls`, `_ai_text_event`, `_ai_tool_calls_event`, and `_tool_message_event` static helpers. The messages-mode and values-mode branches previously repeated four inline dict literals each; they now call the same builders.
- `StreamEvent.type` is now typed as `Literal["values", "messages-tuple", "custom", "end"]` via a `StreamEventType` alias. This makes the closed set explicit and catches typos at type-check time.
- Direct attribute access on `AIMessage`/`AIMessageChunk`: `.usage_metadata`, `.tool_calls`, and `.id` all have default values on the base class, so the `getattr(..., None)` fallbacks were dead code. Removed from the hot path.
- `_account_usage` parameter type loosened to `Any` so that LangChain's `UsageMetadata` TypedDict is accepted under strict type checking.
- Trimmed narrating comments on `seen_ids` / `streamed_ids` / the values-synthesis skip block; kept the non-obvious ones that document the cross-mode dedup invariant.

Net diff: -15 lines. All 132 unit tests + harness boundary test still pass; `ruff check` and `ruff format` pass.
This commit is contained in:
@@ -25,7 +25,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Sequence
|
from collections.abc import Generator, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
@@ -55,6 +55,9 @@ from deerflow.uploads.manager import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Closed set of event-type tags that ``StreamEvent.type`` may carry.
StreamEventType = Literal["values", "messages-tuple", "custom", "end"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StreamEvent:
|
class StreamEvent:
|
||||||
"""A single event from the streaming agent response.
|
"""A single event from the streaming agent response.
|
||||||
@@ -69,7 +72,7 @@ class StreamEvent:
|
|||||||
data: Event payload. Contents vary by type.
|
data: Event payload. Contents vary by type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str
|
type: StreamEventType
|
||||||
data: dict[str, Any] = field(default_factory=dict)
|
data: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@@ -254,13 +257,53 @@ class DeerFlowClient:
|
|||||||
|
|
||||||
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_tool_calls(tool_calls) -> list[dict]:
|
||||||
|
"""Reshape LangChain tool_calls into the wire format used in events."""
|
||||||
|
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
|
||||||
|
|
||||||
|
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
    """Build a ``messages-tuple`` AI text event, attaching usage when present.

    Args:
        msg_id: Id of the AI message the text belongs to, or ``None``.
        text: Text placed under the event's ``content`` key.
        usage: Usage dict to attach under ``usage_metadata``. It is omitted
            from the payload when ``None`` or empty.

    Returns:
        A ``StreamEvent`` of type ``messages-tuple`` whose data has
        ``type="ai"``, ``content`` and ``id``, plus ``usage_metadata``
        when *usage* is truthy.
    """
    data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
    if usage:
        data["usage_metadata"] = usage
    return StreamEvent(type="messages-tuple", data=data)
|
||||||
|
|
||||||
|
@staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
    """Build a ``messages-tuple`` AI tool-calls event.

    Args:
        msg_id: Id of the AI message that issued the tool calls, or ``None``.
        tool_calls: Tool-call mappings, reshaped through
            ``DeerFlowClient._serialize_tool_calls`` before being attached.

    Returns:
        A ``StreamEvent`` of type ``messages-tuple`` whose data has
        ``type="ai"``, an empty ``content`` string, ``id`` and the
        serialized ``tool_calls`` list.
    """
    return StreamEvent(
        type="messages-tuple",
        data={
            "type": "ai",
            # Tool-call events carry no text; content is always empty here.
            "content": "",
            "id": msg_id,
            "tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
        },
    )
|
||||||
|
|
||||||
|
@staticmethod
def _tool_message_event(msg) -> "StreamEvent":
    """Build a ``messages-tuple`` tool-result event from a ToolMessage.

    Args:
        msg: The tool message. Its ``content`` is passed through
            ``DeerFlowClient._extract_text``; ``name``, ``tool_call_id``
            and ``id`` are read if present.

    Returns:
        A ``StreamEvent`` of type ``messages-tuple`` whose data has
        ``type="tool"``, the extracted ``content``, and ``name``,
        ``tool_call_id`` and ``id`` (each ``None`` when the attribute is
        missing on *msg*).
    """
    return StreamEvent(
        type="messages-tuple",
        data={
            "type": "tool",
            "content": DeerFlowClient._extract_text(msg.content),
            "name": getattr(msg, "name", None),
            "tool_call_id": getattr(msg, "tool_call_id", None),
            "id": getattr(msg, "id", None),
        },
    )
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_message(msg) -> dict:
|
def _serialize_message(msg) -> dict:
|
||||||
"""Serialize a LangChain message to a plain dict for values events."""
|
"""Serialize a LangChain message to a plain dict for values events."""
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
|
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls]
|
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
|
||||||
if getattr(msg, "usage_metadata", None):
|
if getattr(msg, "usage_metadata", None):
|
||||||
d["usage_metadata"] = msg.usage_metadata
|
d["usage_metadata"] = msg.usage_metadata
|
||||||
return d
|
return d
|
||||||
@@ -409,28 +452,24 @@ class DeerFlowClient:
|
|||||||
if self._agent_name:
|
if self._agent_name:
|
||||||
context["agent_name"] = self._agent_name
|
context["agent_name"] = self._agent_name
|
||||||
|
|
||||||
# ids already emitted as a complete message via the ``values``
|
|
||||||
# snapshot path — used by the values path itself to avoid
|
|
||||||
# duplicate per-message synthesis when the same message appears
|
|
||||||
# in consecutive snapshots.
|
|
||||||
seen_ids: set[str] = set()
|
seen_ids: set[str] = set()
|
||||||
# ids whose text / tool_calls have already been streamed via the
|
# Cross-mode handoff: ids already streamed via LangGraph ``messages``
|
||||||
# LangGraph ``messages`` mode. The ``values`` path uses this set
|
# mode so the ``values`` path skips re-synthesis of the same message.
|
||||||
# to skip re-emitting synthesized messages-tuple events for the
|
|
||||||
# same message.
|
|
||||||
streamed_ids: set[str] = set()
|
streamed_ids: set[str] = set()
|
||||||
# ids whose ``usage_metadata`` has already been counted into
|
# The same message id carries identical cumulative ``usage_metadata``
|
||||||
# ``cumulative_usage``. The same message id shows up both in
|
# in both the final ``messages`` chunk and the values snapshot —
|
||||||
# ``messages`` chunks (last chunk carries usage) and in ``values``
|
# count it only on whichever arrives first.
|
||||||
# snapshots (final AIMessage carries the same usage) — count once.
|
|
||||||
counted_usage_ids: set[str] = set()
|
counted_usage_ids: set[str] = set()
|
||||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||||
|
|
||||||
def _account_usage(msg_id: str | None, usage: dict | None) -> dict | None:
|
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
|
||||||
"""Add *usage* to cumulative totals if this id has not been counted.
|
"""Add *usage* to cumulative totals if this id has not been counted.
|
||||||
|
|
||||||
Returns the normalized usage dict (for attaching to an event)
|
``usage`` is a ``langchain_core.messages.UsageMetadata`` TypedDict
|
||||||
when we accepted it, otherwise ``None``.
|
or ``None``; typed as ``Any`` because TypedDicts are not
|
||||||
|
structurally assignable to plain ``dict`` under strict type
|
||||||
|
checking. Returns the normalized usage dict (for attaching
|
||||||
|
to an event) when we accepted it, otherwise ``None``.
|
||||||
"""
|
"""
|
||||||
if not usage:
|
if not usage:
|
||||||
return None
|
return None
|
||||||
@@ -467,11 +506,7 @@ class DeerFlowClient:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if mode == "messages":
|
if mode == "messages":
|
||||||
# LangGraph emits ``(message_chunk, metadata_dict)`` for
|
# LangGraph ``messages`` mode emits ``(message_chunk, metadata)``.
|
||||||
# each LLM delta and each tool message. ``message_chunk``
|
|
||||||
# is typically an ``AIMessageChunk`` (subclass of
|
|
||||||
# ``AIMessage``) during LLM streaming; for tool nodes it
|
|
||||||
# is a ``ToolMessage``.
|
|
||||||
if isinstance(chunk, tuple) and len(chunk) == 2:
|
if isinstance(chunk, tuple) and len(chunk) == 2:
|
||||||
msg_chunk, _metadata = chunk
|
msg_chunk, _metadata = chunk
|
||||||
else:
|
else:
|
||||||
@@ -481,44 +516,22 @@ class DeerFlowClient:
|
|||||||
|
|
||||||
if isinstance(msg_chunk, AIMessage):
|
if isinstance(msg_chunk, AIMessage):
|
||||||
text = self._extract_text(msg_chunk.content)
|
text = self._extract_text(msg_chunk.content)
|
||||||
usage = getattr(msg_chunk, "usage_metadata", None)
|
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
|
||||||
counted_usage = _account_usage(msg_id, usage)
|
|
||||||
|
|
||||||
if text:
|
if text:
|
||||||
if msg_id:
|
if msg_id:
|
||||||
streamed_ids.add(msg_id)
|
streamed_ids.add(msg_id)
|
||||||
event_data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||||
if counted_usage:
|
|
||||||
event_data["usage_metadata"] = counted_usage
|
|
||||||
yield StreamEvent(type="messages-tuple", data=event_data)
|
|
||||||
|
|
||||||
tool_calls = getattr(msg_chunk, "tool_calls", None)
|
if msg_chunk.tool_calls:
|
||||||
if tool_calls:
|
|
||||||
if msg_id:
|
if msg_id:
|
||||||
streamed_ids.add(msg_id)
|
streamed_ids.add(msg_id)
|
||||||
yield StreamEvent(
|
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
|
||||||
type="messages-tuple",
|
|
||||||
data={
|
|
||||||
"type": "ai",
|
|
||||||
"content": "",
|
|
||||||
"id": msg_id,
|
|
||||||
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(msg_chunk, ToolMessage):
|
elif isinstance(msg_chunk, ToolMessage):
|
||||||
if msg_id:
|
if msg_id:
|
||||||
streamed_ids.add(msg_id)
|
streamed_ids.add(msg_id)
|
||||||
yield StreamEvent(
|
yield self._tool_message_event(msg_chunk)
|
||||||
type="messages-tuple",
|
|
||||||
data={
|
|
||||||
"type": "tool",
|
|
||||||
"content": self._extract_text(msg_chunk.content),
|
|
||||||
"name": getattr(msg_chunk, "name", None),
|
|
||||||
"tool_call_id": getattr(msg_chunk, "tool_call_id", None),
|
|
||||||
"id": msg_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# mode == "values"
|
# mode == "values"
|
||||||
@@ -531,50 +544,25 @@ class DeerFlowClient:
|
|||||||
if msg_id:
|
if msg_id:
|
||||||
seen_ids.add(msg_id)
|
seen_ids.add(msg_id)
|
||||||
|
|
||||||
# Already streamed through ``messages`` mode — capture
|
# Already streamed via ``messages`` mode; only (defensively)
|
||||||
# usage once more (defensive: the final AIMessage in the
|
# capture usage here and skip re-synthesizing the event.
|
||||||
# snapshot may carry usage_metadata that the streamed
|
|
||||||
# chunks did not) but skip synthesizing messages-tuple
|
|
||||||
# events, which would duplicate what the consumer already
|
|
||||||
# received.
|
|
||||||
if msg_id and msg_id in streamed_ids:
|
if msg_id and msg_id in streamed_ids:
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
usage = getattr(msg, "usage_metadata", None)
|
counted_usage = _account_usage(msg_id, msg.usage_metadata)
|
||||||
counted_usage = _account_usage(msg_id, usage)
|
|
||||||
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
yield StreamEvent(
|
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
|
||||||
type="messages-tuple",
|
|
||||||
data={
|
|
||||||
"type": "ai",
|
|
||||||
"content": "",
|
|
||||||
"id": msg_id,
|
|
||||||
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
text = self._extract_text(msg.content)
|
text = self._extract_text(msg.content)
|
||||||
if text:
|
if text:
|
||||||
event_data = {"type": "ai", "content": text, "id": msg_id}
|
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||||
if counted_usage:
|
|
||||||
event_data["usage_metadata"] = counted_usage
|
|
||||||
yield StreamEvent(type="messages-tuple", data=event_data)
|
|
||||||
|
|
||||||
elif isinstance(msg, ToolMessage):
|
elif isinstance(msg, ToolMessage):
|
||||||
yield StreamEvent(
|
yield self._tool_message_event(msg)
|
||||||
type="messages-tuple",
|
|
||||||
data={
|
|
||||||
"type": "tool",
|
|
||||||
"content": self._extract_text(msg.content),
|
|
||||||
"name": getattr(msg, "name", None),
|
|
||||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
|
||||||
"id": msg_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Emit a values event for each state snapshot
|
# Emit a values event for each state snapshot
|
||||||
yield StreamEvent(
|
yield StreamEvent(
|
||||||
@@ -607,21 +595,18 @@ class DeerFlowClient:
|
|||||||
The accumulated text of the last AI message, or empty string
|
The accumulated text of the last AI message, or empty string
|
||||||
if no AI text was produced.
|
if no AI text was produced.
|
||||||
"""
|
"""
|
||||||
# Accumulator keyed by message id. Token-level streaming yields
|
# Per-id delta lists joined once at the end — avoids the O(n²) cost
|
||||||
# multiple ``messages-tuple`` events sharing the same id, each
|
# of repeated ``str + str`` on a growing buffer for long responses.
|
||||||
# carrying a delta that must be concatenated. Non-streaming mock
|
chunks: dict[str, list[str]] = {}
|
||||||
# sources that emit a single event per id are a degenerate case
|
|
||||||
# of the same logic.
|
|
||||||
buffers: dict[str, str] = {}
|
|
||||||
last_id: str = ""
|
last_id: str = ""
|
||||||
for event in self.stream(message, thread_id=thread_id, **kwargs):
|
for event in self.stream(message, thread_id=thread_id, **kwargs):
|
||||||
if event.type == "messages-tuple" and event.data.get("type") == "ai":
|
if event.type == "messages-tuple" and event.data.get("type") == "ai":
|
||||||
msg_id = event.data.get("id") or ""
|
msg_id = event.data.get("id") or ""
|
||||||
delta = event.data.get("content", "")
|
delta = event.data.get("content", "")
|
||||||
if delta:
|
if delta:
|
||||||
buffers[msg_id] = buffers.get(msg_id, "") + delta
|
chunks.setdefault(msg_id, []).append(delta)
|
||||||
last_id = msg_id
|
last_id = msg_id
|
||||||
return buffers.get(last_id, "")
|
return "".join(chunks.get(last_id, ()))
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Public API — configuration queries
|
# Public API — configuration queries
|
||||||
|
|||||||
Reference in New Issue
Block a user