mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix(backend): stream DeerFlowClient AI text as token deltas (#1969)
DeerFlowClient.stream() subscribed to LangGraph stream_mode=["values",
"custom"] which only delivers full-state snapshots at graph-node
boundaries, so AI replies were dumped as a single messages-tuple event
per node instead of streaming token-by-token. `client.stream("hello")`
looked identical to `client.chat("hello")` — the bug reported in #1969.
Subscribe to "messages" mode as well, forward AIMessageChunk deltas as
messages-tuple events with delta semantics (consumers accumulate by id),
and dedup the values-snapshot path so it does not re-synthesize AI
text that was already streamed. Introduce a per-id usage_metadata
counter so the final AIMessage in the values snapshot and the final
"messages" chunk — which carry the same cumulative usage — are not
double-counted.
chat() now accumulates per-id deltas and returns the last message's
full accumulated text. Non-streaming mock sources (single event per id)
are a degenerate case of the same logic, keeping existing callers and
tests backward compatible.
Verified end-to-end against a real LLM: a 15-number count emits 35
messages-tuple events with BPE subword boundaries clearly visible
("eleven" -> "ele" / "ven", "twelve" -> "tw" / "elve"), 476ms across
the window, end-event usage matches the values-snapshot usage exactly
(not doubled). tests/test_client_live.py::TestLiveStreaming passes.
New unit tests:
- test_messages_mode_emits_token_deltas: 3 AIMessageChunks produce 3
delta events with correct content/id/usage, values-snapshot does not
duplicate, usage counted once.
- test_chat_accumulates_streamed_deltas: chat() rebuilds full text
from deltas.
- test_messages_mode_tool_message: ToolMessage delivered via messages
mode is not duplicated by the values-snapshot synthesis path.
The stream() docstring now documents why this client does not reuse
Gateway's run_agent() / StreamBridge pipeline (sync vs async, raw
LangChain objects vs serialized dicts, single caller vs HTTP fan-out).
Fixes #1969
This commit is contained in:
+6
-5
@@ -395,11 +395,12 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
|
|||||||
**Architecture**: Imports the same `deerflow` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.
|
**Architecture**: Imports the same `deerflow` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.
|
||||||
|
|
||||||
**Agent Conversation** (replaces LangGraph Server):
|
**Agent Conversation** (replaces LangGraph Server):
|
||||||
- `chat(message, thread_id)` — synchronous, returns final text
|
- `chat(message, thread_id)` — synchronous, accumulates streaming deltas per message-id and returns the final AI text
|
||||||
- `stream(message, thread_id)` — yields `StreamEvent` aligned with LangGraph SSE protocol:
|
- `stream(message, thread_id)` — subscribes to LangGraph `stream_mode=["values", "messages", "custom"]` and yields `StreamEvent`:
|
||||||
- `"values"` — full state snapshot (title, messages, artifacts)
|
- `"values"` — full state snapshot (title, messages, artifacts); AI text already delivered via `messages` mode is **not** re-synthesized here to avoid duplicate deliveries
|
||||||
- `"messages-tuple"` — per-message update (AI text, tool calls, tool results)
|
- `"messages-tuple"` — per-chunk update: for AI text this is a **delta** (concat per `id` to rebuild the full message); tool calls and tool results are emitted once each
|
||||||
- `"end"` — stream finished
|
- `"custom"` — forwarded from `StreamWriter`
|
||||||
|
- `"end"` — stream finished (carries cumulative `usage` counted once per message id)
|
||||||
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
|
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
|
||||||
- Supports `checkpointer` parameter for state persistence across turns
|
- Supports `checkpointer` parameter for state persistence across turns
|
||||||
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
|
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
|
||||||
|
|||||||
@@ -336,6 +336,52 @@ class DeerFlowClient:
|
|||||||
consumers can switch between HTTP streaming and embedded mode
|
consumers can switch between HTTP streaming and embedded mode
|
||||||
without changing their event-handling logic.
|
without changing their event-handling logic.
|
||||||
|
|
||||||
|
Token-level streaming
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
This method subscribes to LangGraph's ``messages`` stream mode, so
|
||||||
|
``messages-tuple`` events for AI text are emitted as **deltas** as
|
||||||
|
the model generates tokens, not as one cumulative dump at node
|
||||||
|
completion. Each delta carries a stable ``id`` — consumers that
|
||||||
|
want the full text must accumulate ``content`` per ``id``.
|
||||||
|
``chat()`` already does this for you.
|
||||||
|
|
||||||
|
Tool calls and tool results are still emitted once per logical
|
||||||
|
message. ``values`` events continue to carry full state snapshots
|
||||||
|
after each graph node finishes; AI text already delivered via the
|
||||||
|
``messages`` stream is **not** re-synthesized from the snapshot to
|
||||||
|
avoid duplicate deliveries.
|
||||||
|
|
||||||
|
Why not reuse Gateway's ``run_agent``?
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
Gateway (``runtime/runs/worker.py``) has a complete streaming
|
||||||
|
pipeline: ``run_agent`` → ``StreamBridge`` → ``sse_consumer``. It
|
||||||
|
looks like this client duplicates that work, but the two paths
|
||||||
|
serve different audiences and **cannot** share execution:
|
||||||
|
|
||||||
|
* ``run_agent`` is ``async def`` and uses ``agent.astream()``;
|
||||||
|
this method is a sync generator using ``agent.stream()`` so
|
||||||
|
callers can write ``for event in client.stream(...)`` without
|
||||||
|
touching asyncio. Bridging the two would require spinning up
|
||||||
|
an event loop + thread per call.
|
||||||
|
* Gateway events are JSON-serialized by ``serialize()`` for SSE
|
||||||
|
wire transmission. In-process callers want the raw LangChain
|
||||||
|
objects (``AIMessage``, ``usage_metadata`` as dataclasses), not
|
||||||
|
dicts.
|
||||||
|
* ``StreamBridge`` is an asyncio-queue decoupling producers from
|
||||||
|
consumers across an HTTP boundary (``Last-Event-ID`` replay,
|
||||||
|
heartbeats, multi-subscriber fan-out). A single in-process
|
||||||
|
caller with a direct iterator needs none of that.
|
||||||
|
|
||||||
|
So ``DeerFlowClient.stream()`` is a parallel, sync, in-process
|
||||||
|
consumer of the same ``create_agent()`` factory — not a wrapper
|
||||||
|
around Gateway. The two paths **should** stay in sync on which
|
||||||
|
LangGraph stream modes they subscribe to; that invariant is
|
||||||
|
enforced by ``tests/test_client.py::test_messages_mode_emits_token_deltas``
|
||||||
|
rather than by a shared constant, because the three layers
|
||||||
|
(Graph, Platform SDK, HTTP) each use their own naming
|
||||||
|
(``messages`` vs ``messages-tuple``) and cannot literally share
|
||||||
|
a string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: User message text.
|
message: User message text.
|
||||||
thread_id: Thread ID for conversation context. Auto-generated if None.
|
thread_id: Thread ID for conversation context. Auto-generated if None.
|
||||||
@@ -346,8 +392,8 @@ class DeerFlowClient:
|
|||||||
StreamEvent with one of:
|
StreamEvent with one of:
|
||||||
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
|
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
|
||||||
- type="custom" data={...}
|
- type="custom" data={...}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str}
|
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str, "usage_metadata": {...}}
|
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
|
||||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||||
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
||||||
@@ -363,14 +409,52 @@ class DeerFlowClient:
|
|||||||
if self._agent_name:
|
if self._agent_name:
|
||||||
context["agent_name"] = self._agent_name
|
context["agent_name"] = self._agent_name
|
||||||
|
|
||||||
|
# ids already emitted as a complete message via the ``values``
|
||||||
|
# snapshot path — used by the values path itself to avoid
|
||||||
|
# duplicate per-message synthesis when the same message appears
|
||||||
|
# in consecutive snapshots.
|
||||||
seen_ids: set[str] = set()
|
seen_ids: set[str] = set()
|
||||||
|
# ids whose text / tool_calls have already been streamed via the
|
||||||
|
# LangGraph ``messages`` mode. The ``values`` path uses this set
|
||||||
|
# to skip re-emitting synthesized messages-tuple events for the
|
||||||
|
# same message.
|
||||||
|
streamed_ids: set[str] = set()
|
||||||
|
# ids whose ``usage_metadata`` has already been counted into
|
||||||
|
# ``cumulative_usage``. The same message id shows up both in
|
||||||
|
# ``messages`` chunks (last chunk carries usage) and in ``values``
|
||||||
|
# snapshots (final AIMessage carries the same usage) — count once.
|
||||||
|
counted_usage_ids: set[str] = set()
|
||||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||||
|
|
||||||
|
def _account_usage(msg_id: str | None, usage: dict | None) -> dict | None:
|
||||||
|
"""Add *usage* to cumulative totals if this id has not been counted.
|
||||||
|
|
||||||
|
Returns the normalized usage dict (for attaching to an event)
|
||||||
|
when we accepted it, otherwise ``None``.
|
||||||
|
"""
|
||||||
|
if not usage:
|
||||||
|
return None
|
||||||
|
if msg_id and msg_id in counted_usage_ids:
|
||||||
|
return None
|
||||||
|
if msg_id:
|
||||||
|
counted_usage_ids.add(msg_id)
|
||||||
|
input_tokens = usage.get("input_tokens", 0) or 0
|
||||||
|
output_tokens = usage.get("output_tokens", 0) or 0
|
||||||
|
total_tokens = usage.get("total_tokens", 0) or 0
|
||||||
|
cumulative_usage["input_tokens"] += input_tokens
|
||||||
|
cumulative_usage["output_tokens"] += output_tokens
|
||||||
|
cumulative_usage["total_tokens"] += total_tokens
|
||||||
|
return {
|
||||||
|
"input_tokens": input_tokens,
|
||||||
|
"output_tokens": output_tokens,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
for item in self._agent.stream(
|
for item in self._agent.stream(
|
||||||
state,
|
state,
|
||||||
config=config,
|
config=config,
|
||||||
context=context,
|
context=context,
|
||||||
stream_mode=["values", "custom"],
|
stream_mode=["values", "messages", "custom"],
|
||||||
):
|
):
|
||||||
if isinstance(item, tuple) and len(item) == 2:
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
mode, chunk = item
|
mode, chunk = item
|
||||||
@@ -382,6 +466,62 @@ class DeerFlowClient:
|
|||||||
yield StreamEvent(type="custom", data=chunk)
|
yield StreamEvent(type="custom", data=chunk)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if mode == "messages":
|
||||||
|
# LangGraph emits ``(message_chunk, metadata_dict)`` for
|
||||||
|
# each LLM delta and each tool message. ``message_chunk``
|
||||||
|
# is typically an ``AIMessageChunk`` (subclass of
|
||||||
|
# ``AIMessage``) during LLM streaming; for tool nodes it
|
||||||
|
# is a ``ToolMessage``.
|
||||||
|
if isinstance(chunk, tuple) and len(chunk) == 2:
|
||||||
|
msg_chunk, _metadata = chunk
|
||||||
|
else:
|
||||||
|
msg_chunk = chunk
|
||||||
|
|
||||||
|
msg_id = getattr(msg_chunk, "id", None)
|
||||||
|
|
||||||
|
if isinstance(msg_chunk, AIMessage):
|
||||||
|
text = self._extract_text(msg_chunk.content)
|
||||||
|
usage = getattr(msg_chunk, "usage_metadata", None)
|
||||||
|
counted_usage = _account_usage(msg_id, usage)
|
||||||
|
|
||||||
|
if text:
|
||||||
|
if msg_id:
|
||||||
|
streamed_ids.add(msg_id)
|
||||||
|
event_data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
||||||
|
if counted_usage:
|
||||||
|
event_data["usage_metadata"] = counted_usage
|
||||||
|
yield StreamEvent(type="messages-tuple", data=event_data)
|
||||||
|
|
||||||
|
tool_calls = getattr(msg_chunk, "tool_calls", None)
|
||||||
|
if tool_calls:
|
||||||
|
if msg_id:
|
||||||
|
streamed_ids.add(msg_id)
|
||||||
|
yield StreamEvent(
|
||||||
|
type="messages-tuple",
|
||||||
|
data={
|
||||||
|
"type": "ai",
|
||||||
|
"content": "",
|
||||||
|
"id": msg_id,
|
||||||
|
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(msg_chunk, ToolMessage):
|
||||||
|
if msg_id:
|
||||||
|
streamed_ids.add(msg_id)
|
||||||
|
yield StreamEvent(
|
||||||
|
type="messages-tuple",
|
||||||
|
data={
|
||||||
|
"type": "tool",
|
||||||
|
"content": self._extract_text(msg_chunk.content),
|
||||||
|
"name": getattr(msg_chunk, "name", None),
|
||||||
|
"tool_call_id": getattr(msg_chunk, "tool_call_id", None),
|
||||||
|
"id": msg_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# mode == "values"
|
||||||
messages = chunk.get("messages", [])
|
messages = chunk.get("messages", [])
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@@ -391,13 +531,20 @@ class DeerFlowClient:
|
|||||||
if msg_id:
|
if msg_id:
|
||||||
seen_ids.add(msg_id)
|
seen_ids.add(msg_id)
|
||||||
|
|
||||||
|
# Already streamed through ``messages`` mode — capture
|
||||||
|
# usage once more (defensive: the final AIMessage in the
|
||||||
|
# snapshot may carry usage_metadata that the streamed
|
||||||
|
# chunks did not) but skip synthesizing messages-tuple
|
||||||
|
# events, which would duplicate what the consumer already
|
||||||
|
# received.
|
||||||
|
if msg_id and msg_id in streamed_ids:
|
||||||
|
if isinstance(msg, AIMessage):
|
||||||
|
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
# Track token usage from AI messages
|
|
||||||
usage = getattr(msg, "usage_metadata", None)
|
usage = getattr(msg, "usage_metadata", None)
|
||||||
if usage:
|
counted_usage = _account_usage(msg_id, usage)
|
||||||
cumulative_usage["input_tokens"] += usage.get("input_tokens", 0) or 0
|
|
||||||
cumulative_usage["output_tokens"] += usage.get("output_tokens", 0) or 0
|
|
||||||
cumulative_usage["total_tokens"] += usage.get("total_tokens", 0) or 0
|
|
||||||
|
|
||||||
if msg.tool_calls:
|
if msg.tool_calls:
|
||||||
yield StreamEvent(
|
yield StreamEvent(
|
||||||
@@ -412,13 +559,9 @@ class DeerFlowClient:
|
|||||||
|
|
||||||
text = self._extract_text(msg.content)
|
text = self._extract_text(msg.content)
|
||||||
if text:
|
if text:
|
||||||
event_data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
event_data = {"type": "ai", "content": text, "id": msg_id}
|
||||||
if usage:
|
if counted_usage:
|
||||||
event_data["usage_metadata"] = {
|
event_data["usage_metadata"] = counted_usage
|
||||||
"input_tokens": usage.get("input_tokens", 0) or 0,
|
|
||||||
"output_tokens": usage.get("output_tokens", 0) or 0,
|
|
||||||
"total_tokens": usage.get("total_tokens", 0) or 0,
|
|
||||||
}
|
|
||||||
yield StreamEvent(type="messages-tuple", data=event_data)
|
yield StreamEvent(type="messages-tuple", data=event_data)
|
||||||
|
|
||||||
elif isinstance(msg, ToolMessage):
|
elif isinstance(msg, ToolMessage):
|
||||||
@@ -448,10 +591,12 @@ class DeerFlowClient:
|
|||||||
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
|
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
|
||||||
"""Send a message and return the final text response.
|
"""Send a message and return the final text response.
|
||||||
|
|
||||||
Convenience wrapper around :meth:`stream` that returns only the
|
Convenience wrapper around :meth:`stream` that accumulates delta
|
||||||
**last** AI text from ``messages-tuple`` events. If the agent emits
|
``messages-tuple`` events per ``id`` and returns the text of the
|
||||||
multiple text segments in one turn, intermediate segments are
|
**last** AI message to complete. Intermediate AI messages (e.g.
|
||||||
discarded. Use :meth:`stream` directly to capture all events.
|
planner drafts) are discarded — only the final id's accumulated
|
||||||
|
text is returned. Use :meth:`stream` directly if you need every
|
||||||
|
delta as it arrives.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: User message text.
|
message: User message text.
|
||||||
@@ -459,15 +604,24 @@ class DeerFlowClient:
|
|||||||
**kwargs: Override client defaults (same as stream()).
|
**kwargs: Override client defaults (same as stream()).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The last AI message text, or empty string if no response.
|
The accumulated text of the last AI message, or empty string
|
||||||
|
if no AI text was produced.
|
||||||
"""
|
"""
|
||||||
last_text = ""
|
# Accumulator keyed by message id. Token-level streaming yields
|
||||||
|
# multiple ``messages-tuple`` events sharing the same id, each
|
||||||
|
# carrying a delta that must be concatenated. Non-streaming mock
|
||||||
|
# sources that emit a single event per id are a degenerate case
|
||||||
|
# of the same logic.
|
||||||
|
buffers: dict[str, str] = {}
|
||||||
|
last_id: str = ""
|
||||||
for event in self.stream(message, thread_id=thread_id, **kwargs):
|
for event in self.stream(message, thread_id=thread_id, **kwargs):
|
||||||
if event.type == "messages-tuple" and event.data.get("type") == "ai":
|
if event.type == "messages-tuple" and event.data.get("type") == "ai":
|
||||||
content = event.data.get("content", "")
|
msg_id = event.data.get("id") or ""
|
||||||
if content:
|
delta = event.data.get("content", "")
|
||||||
last_text = content
|
if delta:
|
||||||
return last_text
|
buffers[msg_id] = buffers.get(msg_id, "") + delta
|
||||||
|
last_id = msg_id
|
||||||
|
return buffers.get(last_id, "")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Public API — configuration queries
|
# Public API — configuration queries
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage # noqa: F401
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, SystemMessage, ToolMessage # noqa: F401
|
||||||
|
|
||||||
from app.gateway.routers.mcp import McpConfigResponse
|
from app.gateway.routers.mcp import McpConfigResponse
|
||||||
from app.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
|
from app.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
|
||||||
@@ -225,7 +225,9 @@ class TestStream:
|
|||||||
|
|
||||||
agent.stream.assert_called_once()
|
agent.stream.assert_called_once()
|
||||||
call_kwargs = agent.stream.call_args.kwargs
|
call_kwargs = agent.stream.call_args.kwargs
|
||||||
assert call_kwargs["stream_mode"] == ["values", "custom"]
|
# ``messages`` enables token-level streaming of AI text deltas;
|
||||||
|
# see DeerFlowClient.stream() docstring and GitHub issue #1969.
|
||||||
|
assert call_kwargs["stream_mode"] == ["values", "messages", "custom"]
|
||||||
|
|
||||||
assert events[0].type == "custom"
|
assert events[0].type == "custom"
|
||||||
assert events[0].data == {"type": "task_started", "task_id": "task-1"}
|
assert events[0].data == {"type": "task_started", "task_id": "task-1"}
|
||||||
@@ -351,6 +353,123 @@ class TestStream:
|
|||||||
# Should not raise; end event proves it completed
|
# Should not raise; end event proves it completed
|
||||||
assert events[-1].type == "end"
|
assert events[-1].type == "end"
|
||||||
|
|
||||||
|
def test_messages_mode_emits_token_deltas(self, client):
|
||||||
|
"""stream() forwards LangGraph ``messages`` mode chunks as delta events.
|
||||||
|
|
||||||
|
Regression for bytedance/deer-flow#1969 — before the fix the client
|
||||||
|
only subscribed to ``values`` mode, so LLM output was delivered as
|
||||||
|
a single cumulative dump after each graph node finished instead of
|
||||||
|
token-by-token deltas as the model generated them.
|
||||||
|
"""
|
||||||
|
# Three AI chunks sharing the same id, followed by a terminal
|
||||||
|
# values snapshot with the fully assembled message — this matches
|
||||||
|
# the shape LangGraph emits when ``stream_mode`` includes both
|
||||||
|
# ``messages`` and ``values``.
|
||||||
|
assembled = AIMessage(content="Hel lo world!", id="ai-1", usage_metadata={"input_tokens": 3, "output_tokens": 4, "total_tokens": 7})
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.stream.return_value = iter(
|
||||||
|
[
|
||||||
|
("messages", (AIMessageChunk(content="Hel", id="ai-1"), {})),
|
||||||
|
("messages", (AIMessageChunk(content=" lo ", id="ai-1"), {})),
|
||||||
|
(
|
||||||
|
"messages",
|
||||||
|
(
|
||||||
|
AIMessageChunk(
|
||||||
|
content="world!",
|
||||||
|
id="ai-1",
|
||||||
|
usage_metadata={"input_tokens": 3, "output_tokens": 4, "total_tokens": 7},
|
||||||
|
),
|
||||||
|
{},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_ensure_agent"),
|
||||||
|
patch.object(client, "_agent", agent),
|
||||||
|
):
|
||||||
|
events = list(client.stream("hi", thread_id="t-stream"))
|
||||||
|
|
||||||
|
# Three delta messages-tuple events, all with the same id, each
|
||||||
|
# carrying only its own delta (not cumulative).
|
||||||
|
ai_text_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "ai" and e.data.get("content")]
|
||||||
|
assert [e.data["content"] for e in ai_text_events] == ["Hel", " lo ", "world!"]
|
||||||
|
assert all(e.data["id"] == "ai-1" for e in ai_text_events)
|
||||||
|
|
||||||
|
# The values snapshot MUST NOT re-synthesize an AI text event for
|
||||||
|
# the already-streamed id (otherwise consumers see duplicated text).
|
||||||
|
assert len(ai_text_events) == 3
|
||||||
|
|
||||||
|
# Usage metadata attached only to the chunk that actually carried
|
||||||
|
# it, and counted into cumulative usage exactly once (the values
|
||||||
|
# snapshot's duplicate usage on the assembled AIMessage must not
|
||||||
|
# be double-counted).
|
||||||
|
events_with_usage = [e for e in ai_text_events if "usage_metadata" in e.data]
|
||||||
|
assert len(events_with_usage) == 1
|
||||||
|
assert events_with_usage[0].data["usage_metadata"] == {"input_tokens": 3, "output_tokens": 4, "total_tokens": 7}
|
||||||
|
end_event = events[-1]
|
||||||
|
assert end_event.type == "end"
|
||||||
|
assert end_event.data["usage"] == {"input_tokens": 3, "output_tokens": 4, "total_tokens": 7}
|
||||||
|
|
||||||
|
# The values snapshot itself is still emitted.
|
||||||
|
assert any(e.type == "values" for e in events)
|
||||||
|
|
||||||
|
# stream_mode includes ``messages`` — the whole point of this fix.
|
||||||
|
call_kwargs = agent.stream.call_args.kwargs
|
||||||
|
assert "messages" in call_kwargs["stream_mode"]
|
||||||
|
|
||||||
|
def test_chat_accumulates_streamed_deltas(self, client):
|
||||||
|
"""chat() concatenates per-id deltas from messages mode."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.stream.return_value = iter(
|
||||||
|
[
|
||||||
|
("messages", (AIMessageChunk(content="Hel", id="ai-1"), {})),
|
||||||
|
("messages", (AIMessageChunk(content="lo ", id="ai-1"), {})),
|
||||||
|
("messages", (AIMessageChunk(content="world!", id="ai-1"), {})),
|
||||||
|
("values", {"messages": [HumanMessage(content="hi", id="h-1"), AIMessage(content="Hello world!", id="ai-1")]}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_ensure_agent"),
|
||||||
|
patch.object(client, "_agent", agent),
|
||||||
|
):
|
||||||
|
result = client.chat("hi", thread_id="t-chat-stream")
|
||||||
|
|
||||||
|
assert result == "Hello world!"
|
||||||
|
|
||||||
|
def test_messages_mode_tool_message(self, client):
|
||||||
|
"""stream() forwards ToolMessage chunks from messages mode."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.stream.return_value = iter(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"messages",
|
||||||
|
(
|
||||||
|
ToolMessage(content="file.txt", id="tm-1", tool_call_id="tc-1", name="bash"),
|
||||||
|
{},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
("values", {"messages": [HumanMessage(content="ls", id="h-1"), ToolMessage(content="file.txt", id="tm-1", tool_call_id="tc-1", name="bash")]}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_ensure_agent"),
|
||||||
|
patch.object(client, "_agent", agent),
|
||||||
|
):
|
||||||
|
events = list(client.stream("ls", thread_id="t-tool-stream"))
|
||||||
|
|
||||||
|
tool_events = [e for e in events if e.type == "messages-tuple" and e.data.get("type") == "tool"]
|
||||||
|
# The tool result must be delivered exactly once (from messages
|
||||||
|
# mode), not duplicated by the values-snapshot synthesis path.
|
||||||
|
assert len(tool_events) == 1
|
||||||
|
assert tool_events[0].data["content"] == "file.txt"
|
||||||
|
assert tool_events[0].data["name"] == "bash"
|
||||||
|
assert tool_events[0].data["tool_call_id"] == "tc-1"
|
||||||
|
|
||||||
def test_list_content_blocks(self, client):
|
def test_list_content_blocks(self, client):
|
||||||
"""stream() handles AIMessage with list-of-blocks content."""
|
"""stream() handles AIMessage with list-of-blocks content."""
|
||||||
ai = AIMessage(
|
ai = AIMessage(
|
||||||
|
|||||||
Reference in New Issue
Block a user