mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
Merge branch 'main' into rayhpeng/persistence-scaffold
This commit is contained in:
@@ -5,6 +5,7 @@ import concurrent.futures
|
||||
import json
|
||||
import tempfile
|
||||
import zipfile
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -205,6 +206,33 @@ class TestStream:
|
||||
msg_events = _ai_events(events)
|
||||
assert msg_events[0].data["content"] == "Hello!"
|
||||
|
||||
def test_custom_events_are_forwarded(self, client):
|
||||
"""stream() forwards custom stream events alongside normal values output."""
|
||||
ai = AIMessage(content="Hello!", id="ai-1")
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
("custom", {"type": "task_started", "task_id": "task-1"}),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), ai]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-custom"))
|
||||
|
||||
agent.stream.assert_called_once()
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert call_kwargs["stream_mode"] == ["values", "custom"]
|
||||
|
||||
assert events[0].type == "custom"
|
||||
assert events[0].data == {"type": "task_started", "task_id": "task-1"}
|
||||
assert any(event.type == "messages-tuple" and event.data["content"] == "Hello!" for event in events)
|
||||
assert any(event.type == "values" for event in events)
|
||||
assert events[-1].type == "end"
|
||||
|
||||
def test_context_propagation(self, client):
|
||||
"""stream() passes agent_name to the context."""
|
||||
agent = _make_agent_mock([{"messages": [AIMessage(content="ok", id="ai-1")]}])
|
||||
@@ -222,6 +250,33 @@ class TestStream:
|
||||
assert call_kwargs["context"]["thread_id"] == "t1"
|
||||
assert call_kwargs["context"]["agent_name"] == "test-agent-1"
|
||||
|
||||
def test_custom_mode_is_normalized_to_string(self, client):
|
||||
"""stream() forwards custom events even when the mode is not a plain string."""
|
||||
|
||||
class StreamMode(Enum):
|
||||
CUSTOM = "custom"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
agent = _make_agent_mock(
|
||||
[
|
||||
(StreamMode.CUSTOM, {"type": "task_started", "task_id": "task-1"}),
|
||||
{"messages": [AIMessage(content="Hello!", id="ai-1")]},
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-custom-enum"))
|
||||
|
||||
assert events[0].type == "custom"
|
||||
assert events[0].data == {"type": "task_started", "task_id": "task-1"}
|
||||
assert any(event.type == "messages-tuple" and event.data["content"] == "Hello!" for event in events)
|
||||
assert events[-1].type == "end"
|
||||
|
||||
def test_tool_call_and_result(self, client):
|
||||
"""stream() emits messages-tuple events for tool calls and results."""
|
||||
ai = AIMessage(content="", id="ai-1", tool_calls=[{"name": "bash", "args": {"cmd": "ls"}, "id": "tc-1"}])
|
||||
|
||||
Reference in New Issue
Block a user