mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
Merge branch 'main' into rayhpeng/persistence-scaffold
# Conflicts: # backend/tests/test_model_factory.py
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
"""Tests for create_deerflow_agent SDK entry point."""
|
||||
|
||||
from typing import get_type_hints
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.agents.factory import create_deerflow_agent
|
||||
from deerflow.agents.features import Next, Prev, RuntimeFeatures
|
||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
|
||||
def _make_mock_model():
|
||||
@@ -127,6 +130,13 @@ def test_vision_injects_view_image_tool(mock_create_agent):
|
||||
assert "view_image" in tool_names
|
||||
|
||||
|
||||
def test_view_image_middleware_preserves_viewed_images_reducer():
    """The middleware state schema must annotate ``viewed_images`` like ``ThreadState``.

    Both sets of hints are resolved with ``include_extras=True`` so that any
    ``Annotated`` metadata attached to the field takes part in the comparison.
    """
    schema_annotations = get_type_hints(ViewImageMiddleware.state_schema, include_extras=True)
    thread_annotations = get_type_hints(ThreadState, include_extras=True)

    actual = schema_annotations["viewed_images"]
    expected = thread_annotations["viewed_images"]
    assert actual == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Subagent feature auto-injects task_tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -604,6 +604,63 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
|
||||
assert "max_tokens" not in FakeChatModel.captured_kwargs
|
||||
|
||||
|
||||
def test_thinking_disabled_vllm_chat_template_format(monkeypatch):
    """With thinking disabled, the ``thinking`` chat-template flag is passed as False.

    The model config turns thinking on through
    ``extra_body.chat_template_kwargs.thinking`` and also carries an unrelated
    ``top_k`` entry in ``extra_body``. The kwargs handed to the model class
    must keep ``top_k``, carry the flag as False, and have no
    ``reasoning_effort``.
    """
    thinking_overrides = {"extra_body": {"chat_template_kwargs": {"thinking": True}}}
    model_config = _make_model(
        "vllm-qwen",
        use="deerflow.models.vllm_provider:VllmChatModel",
        supports_thinking=True,
        when_thinking_enabled=thinking_overrides,
    )
    model_config.extra_body = {"top_k": 20}
    app_config = _make_app_config([model_config])
    _patch_factory(monkeypatch, app_config)

    # Constructor kwargs are recorded here so they can be inspected afterwards.
    seen_kwargs: dict = {}

    class CapturingModel(FakeChatModel):
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    def _always_capturing_model(path, base):
        return CapturingModel

    monkeypatch.setattr(factory_module, "resolve_class", _always_capturing_model)

    factory_module.create_chat_model(name="vllm-qwen", thinking_enabled=False)

    expected_extra_body = {"top_k": 20, "chat_template_kwargs": {"thinking": False}}
    assert seen_kwargs.get("extra_body") == expected_extra_body
    assert seen_kwargs.get("reasoning_effort") is None
|
||||
|
||||
|
||||
def test_thinking_disabled_vllm_enable_thinking_format(monkeypatch):
    """With thinking disabled, the ``enable_thinking`` chat-template flag is passed as False.

    Same scenario as the ``thinking``-keyed variant, but the model config uses
    the ``enable_thinking`` key. The unrelated ``top_k`` entry in
    ``extra_body`` must survive and no ``reasoning_effort`` may be passed.
    """
    thinking_overrides = {"extra_body": {"chat_template_kwargs": {"enable_thinking": True}}}
    model_config = _make_model(
        "vllm-qwen-enable",
        use="deerflow.models.vllm_provider:VllmChatModel",
        supports_thinking=True,
        when_thinking_enabled=thinking_overrides,
    )
    model_config.extra_body = {"top_k": 20}
    app_config = _make_app_config([model_config])
    _patch_factory(monkeypatch, app_config)

    # Constructor kwargs are recorded here so they can be inspected afterwards.
    seen_kwargs: dict = {}

    class CapturingModel(FakeChatModel):
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    def _always_capturing_model(path, base):
        return CapturingModel

    monkeypatch.setattr(factory_module, "resolve_class", _always_capturing_model)

    factory_module.create_chat_model(name="vllm-qwen-enable", thinking_enabled=False)

    expected_extra_body = {
        "top_k": 20,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    assert seen_kwargs.get("extra_body") == expected_extra_body
    assert seen_kwargs.get("reasoning_effort") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stream_usage injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
+119
-122
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge
|
||||
@@ -44,7 +45,7 @@ async def test_publish_subscribe(bridge: MemoryStreamBridge):
|
||||
async def test_heartbeat(bridge: MemoryStreamBridge):
|
||||
"""When no events arrive within the heartbeat interval, yield a heartbeat."""
|
||||
run_id = "run-heartbeat"
|
||||
bridge._get_or_create_queue(run_id) # ensure queue exists
|
||||
bridge._get_or_create_stream(run_id) # ensure stream exists
|
||||
|
||||
received = []
|
||||
|
||||
@@ -61,37 +62,35 @@ async def test_heartbeat(bridge: MemoryStreamBridge):
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup(bridge: MemoryStreamBridge):
|
||||
"""After cleanup, the run's queue is removed."""
|
||||
"""After cleanup, the run's stream/event log is removed."""
|
||||
run_id = "run-cleanup"
|
||||
await bridge.publish(run_id, "test", {})
|
||||
assert run_id in bridge._queues
|
||||
assert run_id in bridge._streams
|
||||
|
||||
await bridge.cleanup(run_id)
|
||||
assert run_id not in bridge._queues
|
||||
assert run_id not in bridge._streams
|
||||
assert run_id not in bridge._counters
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_backpressure():
|
||||
"""With maxsize=1, publish should not block forever."""
|
||||
async def test_history_is_bounded():
|
||||
"""Retained history should be bounded by queue_maxsize."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-bp"
|
||||
|
||||
await bridge.publish(run_id, "first", {})
|
||||
await bridge.publish(run_id, "second", {})
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# Second publish should either succeed after queue drains or warn+drop
|
||||
# It should not hang indefinitely
|
||||
async def publish_second():
|
||||
await bridge.publish(run_id, "second", {})
|
||||
received = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0):
|
||||
received.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
# Give it a generous timeout — the publish timeout is 30s but we don't
|
||||
# want to wait that long in tests. Instead, drain the queue first.
|
||||
async def drain():
|
||||
await asyncio.sleep(0.05)
|
||||
bridge._queues[run_id].get_nowait()
|
||||
|
||||
await asyncio.gather(publish_second(), drain())
|
||||
assert bridge._queues[run_id].qsize() == 1
|
||||
assert len(received) == 2
|
||||
assert received[0].event == "second"
|
||||
assert received[1] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -140,54 +139,116 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
|
||||
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_subscribe_replays_after_last_event_id(bridge: MemoryStreamBridge):
    """Resubscribing with a Last-Event-ID replays only the events buffered after it."""
    run_id = "run-replay"
    published = (
        ("metadata", {"run_id": run_id}),
        ("values", {"step": 1}),
        ("updates", {"step": 2}),
    )
    for event_name, payload in published:
        await bridge.publish(run_id, event_name, payload)
    await bridge.publish_end(run_id)

    async def collect_until_end(**subscribe_options):
        """Gather entries from a subscription, stopping once END_SENTINEL arrives."""
        collected = []
        async for item in bridge.subscribe(run_id, **subscribe_options):
            collected.append(item)
            if item is END_SENTINEL:
                break
        return collected

    # The first pass reads the whole stream so the id of the first event is known.
    initial = await collect_until_end(heartbeat_interval=1.0)

    # The second pass resumes after that first ("metadata") event.
    replayed = await collect_until_end(last_event_id=initial[0].id, heartbeat_interval=1.0)

    replayed_names = [item.event for item in replayed[:-1]]
    assert replayed_names == ["values", "updates"]
    assert replayed[-1] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_slow_subscriber_does_not_skip_after_buffer_trim():
    """Resuming after the buffer has been trimmed must not skip retained events.

    With ``queue_maxsize=2`` the third publish drops ``e1`` from the retained
    history. A subscriber resuming from ``e1`` should still see ``e2`` and
    ``e3``, and one resuming from ``e2`` should see ``e3`` followed by the
    end sentinel.
    """
    bridge = MemoryStreamBridge(queue_maxsize=2)
    run_id = "run-slow-subscriber"
    await bridge.publish(run_id, "e1", {"step": 1})
    await bridge.publish(run_id, "e2", {"step": 2})

    history = bridge._streams[run_id]
    first_event_id = history.events[0].id
    assert history.start_offset == 0

    # A third publish exceeds the capacity of two, so e1 is trimmed.
    await bridge.publish(run_id, "e3", {"step": 3})
    assert history.start_offset == 1
    retained_names = [item.event for item in history.events]
    assert retained_names == ["e2", "e3"]

    # Resume from the trimmed event; stop after two entries because the run
    # has not been ended yet.
    after_first = []
    async for item in bridge.subscribe(run_id, last_event_id=first_event_id, heartbeat_interval=1.0):
        after_first.append(item)
        if len(after_first) == 2:
            break

    assert [item.event for item in after_first] == ["e2", "e3"]
    second_event_id = after_first[0].id

    await bridge.publish_end(run_id)

    # Resume from e2 now that the run has ended.
    after_second = []
    async for item in bridge.subscribe(run_id, last_event_id=second_event_id, heartbeat_interval=1.0):
        after_second.append(item)
        if item is END_SENTINEL:
            break

    assert [item.event for item in after_second[:-1]] == ["e3"]
    assert after_second[-1] is END_SENTINEL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# END sentinel guarantee tests
|
||||
# Stream termination tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_delivered_when_queue_full():
|
||||
"""END sentinel must always be delivered, even when the queue is completely full.
|
||||
|
||||
This is the critical regression test for the bug where publish_end()
|
||||
would silently drop the END sentinel when the queue was full, causing
|
||||
subscribe() to hang forever and leaking resources.
|
||||
"""
|
||||
async def test_publish_end_terminates_even_when_history_is_full():
|
||||
"""publish_end() should terminate subscribers without mutating retained history."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=2)
|
||||
run_id = "run-end-full"
|
||||
run_id = "run-end-history-full"
|
||||
|
||||
# Fill the queue to capacity
|
||||
await bridge.publish(run_id, "event-1", {"n": 1})
|
||||
await bridge.publish(run_id, "event-2", {"n": 2})
|
||||
assert bridge._queues[run_id].full()
|
||||
stream = bridge._streams[run_id]
|
||||
assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
|
||||
|
||||
# publish_end should succeed by evicting old events
|
||||
await bridge.publish_end(run_id)
|
||||
assert [entry.event for entry in stream.events] == ["event-1", "event-2"]
|
||||
|
||||
# Subscriber must receive END_SENTINEL
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
|
||||
assert [entry.event for entry in events[:-1]] == ["event-1", "event-2"]
|
||||
assert events[-1] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_evicts_oldest_events():
|
||||
"""When queue is full, publish_end evicts the oldest events to make room."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-evict"
|
||||
|
||||
# Fill queue with one event
|
||||
await bridge.publish(run_id, "will-be-evicted", {})
|
||||
assert bridge._queues[run_id].full()
|
||||
|
||||
# publish_end must succeed
|
||||
async def test_publish_end_without_history_yields_end_immediately():
|
||||
"""Subscribers should still receive END when a run completes without events."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=2)
|
||||
run_id = "run-end-empty"
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# The only event we should get is END_SENTINEL (the regular event was evicted)
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
@@ -199,8 +260,8 @@ async def test_end_sentinel_evicts_oldest_events():
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_no_eviction_when_space_available():
|
||||
"""When queue has space, publish_end should not evict anything."""
|
||||
async def test_publish_end_preserves_history_when_space_available():
|
||||
"""When history has spare capacity, publish_end should preserve prior events."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=10)
|
||||
run_id = "run-no-evict"
|
||||
|
||||
@@ -244,87 +305,23 @@ async def test_concurrent_tasks_end_sentinel():
|
||||
return events
|
||||
return events # pragma: no cover
|
||||
|
||||
# Run producers and consumers concurrently
|
||||
run_ids = [f"concurrent-{i}" for i in range(num_runs)]
|
||||
producers = [producer(rid) for rid in run_ids]
|
||||
consumers = [consumer(rid) for rid in run_ids]
|
||||
results: dict[str, list] = {}
|
||||
|
||||
# Start consumers first, then producers
|
||||
consumer_tasks = [asyncio.create_task(c) for c in consumers]
|
||||
await asyncio.gather(*producers)
|
||||
async def consume_into(run_id: str) -> None:
|
||||
results[run_id] = await consumer(run_id)
|
||||
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*consumer_tasks),
|
||||
timeout=10.0,
|
||||
)
|
||||
with anyio.fail_after(10):
|
||||
async with anyio.create_task_group() as task_group:
|
||||
for run_id in run_ids:
|
||||
task_group.start_soon(consume_into, run_id)
|
||||
await anyio.sleep(0)
|
||||
for run_id in run_ids:
|
||||
task_group.start_soon(producer, run_id)
|
||||
|
||||
for i, events in enumerate(results):
|
||||
assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drop counter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dropped_count_tracking():
|
||||
"""Dropped events should be tracked per run_id."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-drop-count"
|
||||
|
||||
# Fill the queue
|
||||
await bridge.publish(run_id, "first", {})
|
||||
|
||||
# This publish will time out and be dropped (we patch timeout to be instant)
|
||||
# Instead, we verify the counter after publish_end eviction
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# dropped_count tracks publish() drops, not publish_end evictions
|
||||
assert bridge.dropped_count(run_id) == 0
|
||||
|
||||
# cleanup should also clear the counter
|
||||
await bridge.cleanup(run_id)
|
||||
assert bridge.dropped_count(run_id) == 0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dropped_total():
|
||||
"""dropped_total should sum across all runs."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
|
||||
# No drops yet
|
||||
assert bridge.dropped_total == 0
|
||||
|
||||
# Manually set some counts to verify the property
|
||||
bridge._dropped_counts["run-a"] = 3
|
||||
bridge._dropped_counts["run-b"] = 7
|
||||
assert bridge.dropped_total == 10
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup_clears_dropped_counts():
|
||||
"""cleanup() should clear the dropped counter for the run."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
run_id = "run-cleanup-drops"
|
||||
|
||||
bridge._get_or_create_queue(run_id)
|
||||
bridge._dropped_counts[run_id] = 5
|
||||
|
||||
await bridge.cleanup(run_id)
|
||||
assert run_id not in bridge._dropped_counts
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_close_clears_dropped_counts():
|
||||
"""close() should clear all dropped counters."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
bridge._dropped_counts["run-x"] = 10
|
||||
bridge._dropped_counts["run-y"] = 20
|
||||
|
||||
await bridge.close()
|
||||
assert bridge.dropped_total == 0
|
||||
assert len(bridge._dropped_counts) == 0
|
||||
for run_id in run_ids:
|
||||
events = results[run_id]
|
||||
assert events[-1] is END_SENTINEL, f"Run {run_id} did not receive END sentinel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
These functions truncate long tool outputs to prevent context window overflow.
|
||||
- _truncate_bash_output: middle-truncation (head + tail), for bash tool
|
||||
- _truncate_read_file_output: head-truncation, for read_file tool
|
||||
- _truncate_ls_output: head-truncation, for ls tool
|
||||
"""
|
||||
|
||||
from deerflow.sandbox.tools import _truncate_bash_output, _truncate_read_file_output
|
||||
from deerflow.sandbox.tools import _truncate_bash_output, _truncate_ls_output, _truncate_read_file_output
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_bash_output
|
||||
@@ -159,3 +160,71 @@ class TestTruncateReadFileOutput:
|
||||
for max_chars in [100, 1000, 5000, 20000, 49999]:
|
||||
result = _truncate_read_file_output(output, max_chars)
|
||||
assert len(result) <= max_chars, f"failed for max_chars={max_chars}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_ls_output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTruncateLsOutput:
    """Checks for ``_truncate_ls_output``, the head-truncation helper for the ls tool."""

    def test_short_output_returned_unchanged(self):
        """Output well under the limit comes back untouched."""
        listing = "dir1\ndir2\nfile1.txt"
        truncated = _truncate_ls_output(listing, 20000)
        assert truncated == listing

    def test_output_equal_to_limit_returned_unchanged(self):
        """Output whose length equals the limit is not truncated."""
        listing = "X" * 20000
        truncated = _truncate_ls_output(listing, 20000)
        assert truncated == listing

    def test_long_output_is_truncated(self):
        """Output over the limit comes back shorter than the input."""
        listing = "\n".join([f"file_{idx}.txt" for idx in range(5000)])
        truncated = _truncate_ls_output(listing, 20000)
        assert len(truncated) < len(listing)

    def test_result_never_exceeds_max_chars(self):
        """The truncated result fits within the limit."""
        limit = 20000
        listing = "\n".join([f"subdir/file_{idx}.txt" for idx in range(5000)])
        truncated = _truncate_ls_output(listing, limit)
        assert len(truncated) <= limit

    def test_head_is_preserved(self):
        """The beginning of the listing is kept verbatim."""
        prefix = "first_dir\nsecond_dir\n"
        listing = prefix + "\n".join([f"file_{idx}" for idx in range(5000)])
        truncated = _truncate_ls_output(listing, 20000)
        assert truncated.startswith(prefix)

    def test_truncation_marker_present(self):
        """A truncated result contains the truncation marker text."""
        listing = "\n".join([f"file_{idx}.txt" for idx in range(5000)])
        truncated = _truncate_ls_output(listing, 20000)
        for fragment in ("[truncated:", "showing first"):
            assert fragment in truncated

    def test_total_chars_reported_correctly(self):
        """The marker reports the original character count."""
        listing = "X" * 30000
        truncated = _truncate_ls_output(listing, 20000)
        assert "of 30000 chars" in truncated

    def test_hint_suggests_specific_path(self):
        """The marker includes the hint to use a more specific path."""
        listing = "X" * 30000
        truncated = _truncate_ls_output(listing, 20000)
        assert "Use a more specific path" in truncated

    def test_max_chars_zero_disables_truncation(self):
        """A limit of zero returns the input unchanged."""
        listing = "\n".join([f"file_{idx}.txt" for idx in range(10000)])
        truncated = _truncate_ls_output(listing, 0)
        assert truncated == listing

    def test_tail_is_not_preserved(self):
        """Text past the limit does not appear in the result."""
        listing = "H" * 20000 + "TAIL_SHOULD_NOT_APPEAR"
        truncated = _truncate_ls_output(listing, 20000)
        assert "TAIL_SHOULD_NOT_APPEAR" not in truncated

    def test_small_max_chars_does_not_crash(self):
        """A very small limit is still respected."""
        listing = "\n".join([f"file_{idx}.txt" for idx in range(100)])
        truncated = _truncate_ls_output(listing, 10)
        assert len(truncated) <= 10

    def test_result_never_exceeds_max_chars_various_sizes(self):
        """The length bound holds across a range of limits."""
        listing = "\n".join([f"file_{idx}.txt" for idx in range(5000)])
        limits = [100, 1000, 5000, 20000, len(listing) - 1]
        for max_chars in limits:
            truncated = _truncate_ls_output(listing, max_chars)
            assert len(truncated) <= max_chars, f"failed for max_chars={max_chars}"
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||
|
||||
from deerflow.models.vllm_provider import VllmChatModel
|
||||
|
||||
|
||||
def _make_model() -> VllmChatModel:
    """Build a ``VllmChatModel`` with a dummy API key and a localhost base URL."""
    settings = {
        "model": "Qwen/QwQ-32B",
        "api_key": "dummy",
        "base_url": "http://localhost:8000/v1",
    }
    return VllmChatModel(**settings)
|
||||
|
||||
|
||||
def test_vllm_provider_restores_reasoning_in_request_payload():
    """An assistant message's stored ``reasoning`` is written back into the request payload.

    The conversation holds an assistant turn with one tool call and a
    ``reasoning`` entry in ``additional_kwargs``, followed by a human turn.
    The first payload message must carry the assistant role, the reasoning
    text and the tool call's function name.
    """
    model = _make_model()
    bash_call = {"name": "bash", "args": {"cmd": "pwd"}, "id": "tool-1", "type": "tool_call"}
    assistant_turn = AIMessage(
        content="",
        tool_calls=[bash_call],
        additional_kwargs={"reasoning": "Need to inspect the workspace first."},
    )
    conversation = [assistant_turn, HumanMessage(content="Continue")]

    payload = model._get_request_payload(conversation)

    first_message = payload["messages"][0]
    assert first_message["role"] == "assistant"
    assert first_message["reasoning"] == "Need to inspect the workspace first."
    assert first_message["tool_calls"][0]["function"]["name"] == "bash"
|
||||
|
||||
|
||||
def test_vllm_provider_normalizes_legacy_thinking_kwarg_to_enable_thinking():
    """A ``thinking`` chat-template kwarg appears in the payload as ``enable_thinking``."""
    legacy_extra_body = {"chat_template_kwargs": {"thinking": True}}
    model = VllmChatModel(
        model="qwen3",
        api_key="dummy",
        base_url="http://localhost:8000/v1",
        extra_body=legacy_extra_body,
    )

    prompt = [HumanMessage(content="Hello")]
    payload = model._get_request_payload(prompt)

    template_kwargs = payload["extra_body"]["chat_template_kwargs"]
    assert template_kwargs == {"enable_thinking": True}
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_explicit_enable_thinking_kwarg():
    """An explicit ``enable_thinking`` value and sibling template kwargs pass through unchanged."""
    explicit_extra_body = {"chat_template_kwargs": {"enable_thinking": False, "foo": "bar"}}
    model = VllmChatModel(
        model="qwen3",
        api_key="dummy",
        base_url="http://localhost:8000/v1",
        extra_body=explicit_extra_body,
    )

    prompt = [HumanMessage(content="Hello")]
    payload = model._get_request_payload(prompt)

    template_kwargs = payload["extra_body"]["chat_template_kwargs"]
    expected_kwargs = {
        "enable_thinking": False,
        "foo": "bar",
    }
    assert template_kwargs == expected_kwargs
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_reasoning_in_chat_result():
    """``reasoning`` from a response message is kept under both ``reasoning`` and ``reasoning_content``."""
    model = _make_model()
    reasoning_text = "I compared the two numbers directly."
    assistant_choice = {
        "message": {
            "role": "assistant",
            "content": "42",
            "reasoning": "I compared the two numbers directly.",
        },
        "finish_reason": "stop",
    }
    raw_response = {
        "model": "Qwen/QwQ-32B",
        "choices": [assistant_choice],
        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    }

    result = model._create_chat_result(raw_response)

    extras = result.generations[0].message.additional_kwargs
    assert extras["reasoning"] == reasoning_text
    assert extras["reasoning_content"] == reasoning_text
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_reasoning_in_streaming_chunks():
    """A streamed delta's ``reasoning`` is kept on the chunk alongside its content."""
    model = _make_model()
    delta = {
        "role": "assistant",
        "reasoning": "First, call the weather tool.",
        "content": "Calling tool...",
    }
    raw_chunk = {
        "model": "Qwen/QwQ-32B",
        "choices": [{"delta": delta, "finish_reason": None}],
    }

    chunk = model._convert_chunk_to_generation_chunk(raw_chunk, AIMessageChunk, {})

    assert chunk is not None
    streamed = chunk.message
    assert streamed.additional_kwargs["reasoning"] == "First, call the weather tool."
    assert streamed.additional_kwargs["reasoning_content"] == "First, call the weather tool."
    assert streamed.content == "Calling tool..."
|
||||
|
||||
|
||||
def test_vllm_provider_preserves_empty_reasoning_values_in_streaming_chunks():
    """An empty-string ``reasoning`` delta is kept as-is and adds no ``reasoning_content``."""
    model = _make_model()
    delta = {
        "role": "assistant",
        "reasoning": "",
        "content": "Still replying...",
    }
    raw_chunk = {
        "model": "Qwen/QwQ-32B",
        "choices": [{"delta": delta, "finish_reason": None}],
    }

    chunk = model._convert_chunk_to_generation_chunk(raw_chunk, AIMessageChunk, {})

    assert chunk is not None
    streamed = chunk.message
    assert "reasoning" in streamed.additional_kwargs
    assert streamed.additional_kwargs["reasoning"] == ""
    assert "reasoning_content" not in streamed.additional_kwargs
    assert streamed.content == "Still replying..."
|
||||
Reference in New Issue
Block a user