Merge branch 'main' into rayhpeng/persistence-scaffold

# Conflicts:
#	backend/tests/test_model_factory.py
This commit is contained in:
rayhpeng
2026-04-06 17:11:49 +08:00
24 changed files with 995 additions and 259 deletions
@@ -1,22 +1,19 @@
"""Middleware for injecting image details into conversation before LLM call."""
import logging
from typing import NotRequired, override
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ViewedImageData
from deerflow.agents.thread_state import ThreadState
logger = logging.getLogger(__name__)
class ViewImageMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
viewed_images: NotRequired[dict[str, ViewedImageData] | None]
class ViewImageMiddlewareState(ThreadState):
"""Reuse the thread state so reducer-backed keys keep their annotations."""
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
@@ -74,5 +74,10 @@ class SandboxConfig(BaseModel):
ge=0,
description="Maximum characters to keep from read_file tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
ls_output_max_chars: int = Field(
default=20000,
ge=0,
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
model_config = ConfigDict(extra="allow")
@@ -9,6 +9,27 @@ from deerflow.tracing import build_tracing_callbacks
logger = logging.getLogger(__name__)
def _deep_merge_dicts(base: dict | None, override: dict) -> dict:
"""Recursively merge two dictionaries without mutating the inputs."""
merged = dict(base or {})
for key, value in override.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = _deep_merge_dicts(merged[key], value)
else:
merged[key] = value
return merged
def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
"""Build the disable payload for vLLM/Qwen chat template kwargs."""
disable_kwargs: dict[str, bool] = {}
if "thinking" in chat_template_kwargs:
disable_kwargs["thinking"] = False
if "enable_thinking" in chat_template_kwargs:
disable_kwargs["enable_thinking"] = False
return disable_kwargs
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config.
@@ -54,13 +75,23 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
if not thinking_enabled and has_thinking_settings:
if effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
# OpenAI-compatible gateway: thinking is nested under extra_body
kwargs.update({"extra_body": {"thinking": {"type": "disabled"}}})
kwargs.update({"reasoning_effort": "minimal"})
model_settings_from_config["extra_body"] = _deep_merge_dicts(
model_settings_from_config.get("extra_body"),
{"thinking": {"type": "disabled"}},
)
model_settings_from_config["reasoning_effort"] = "minimal"
elif disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {}):
# vLLM uses chat template kwargs to switch thinking on/off.
model_settings_from_config["extra_body"] = _deep_merge_dicts(
model_settings_from_config.get("extra_body"),
{"chat_template_kwargs": disable_chat_template_kwargs},
)
elif effective_wte.get("thinking", {}).get("type"):
# Native langchain_anthropic: thinking is a direct constructor parameter
kwargs.update({"thinking": {"type": "disabled"}})
if not model_config.supports_reasoning_effort and "reasoning_effort" in kwargs:
del kwargs["reasoning_effort"]
model_settings_from_config["thinking"] = {"type": "disabled"}
if not model_config.supports_reasoning_effort:
kwargs.pop("reasoning_effort", None)
model_settings_from_config.pop("reasoning_effort", None)
# For Codex Responses API models: map thinking mode to reasoning_effort
from deerflow.models.openai_codex_provider import CodexChatModel
@@ -0,0 +1,258 @@
"""Custom vLLM provider built on top of LangChain ChatOpenAI.
vLLM 0.19.0 exposes reasoning models through an OpenAI-compatible API, but
LangChain's default OpenAI adapter drops the non-standard ``reasoning`` field
from assistant messages and streaming deltas. That breaks interleaved
thinking/tool-call flows because vLLM expects the assistant's prior reasoning to
be echoed back on subsequent turns.
This provider preserves ``reasoning`` on:
- non-streaming responses
- streaming deltas
- multi-turn request payloads
"""
from __future__ import annotations
import json
from collections.abc import Mapping
from typing import Any, cast
import openai
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _create_usage_metadata
def _normalize_vllm_chat_template_kwargs(payload: dict[str, Any]) -> None:
"""Map DeerFlow's legacy ``thinking`` toggle to vLLM/Qwen's ``enable_thinking``.
DeerFlow originally documented ``extra_body.chat_template_kwargs.thinking``
for vLLM, but vLLM 0.19.0's Qwen reasoning parser reads
``chat_template_kwargs.enable_thinking``. Normalize the payload just before
it is sent so existing configs keep working and flash mode can truly
disable reasoning.
"""
extra_body = payload.get("extra_body")
if not isinstance(extra_body, dict):
return
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if not isinstance(chat_template_kwargs, dict):
return
if "thinking" not in chat_template_kwargs:
return
normalized_chat_template_kwargs = dict(chat_template_kwargs)
normalized_chat_template_kwargs.setdefault("enable_thinking", normalized_chat_template_kwargs["thinking"])
normalized_chat_template_kwargs.pop("thinking", None)
extra_body["chat_template_kwargs"] = normalized_chat_template_kwargs
def _reasoning_to_text(reasoning: Any) -> str:
"""Best-effort extraction of readable reasoning text from vLLM payloads."""
if isinstance(reasoning, str):
return reasoning
if isinstance(reasoning, list):
parts = [_reasoning_to_text(item) for item in reasoning]
return "".join(part for part in parts if part)
if isinstance(reasoning, dict):
for key in ("text", "content", "reasoning"):
value = reasoning.get(key)
if isinstance(value, str):
return value
if value is not None:
text = _reasoning_to_text(value)
if text:
return text
try:
return json.dumps(reasoning, ensure_ascii=False)
except TypeError:
return str(reasoning)
try:
return json.dumps(reasoning, ensure_ascii=False)
except TypeError:
return str(reasoning)
def _convert_delta_to_message_chunk_with_reasoning(_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]) -> BaseMessageChunk:
    """Turn one streaming delta into a message chunk, keeping its ``reasoning``.

    Args:
        _dict: The ``delta`` mapping taken from a streamed choice.
        default_class: Chunk class to fall back on when the delta carries no
            ``role`` of its own.

    Returns:
        A message chunk matching the delta's role (or ``default_class``). For
        assistant chunks, a non-``None`` ``reasoning`` value is stored under
        ``additional_kwargs["reasoning"]`` and, when it yields non-empty text,
        also under ``additional_kwargs["reasoning_content"]``.
    """
    chunk_id = _dict.get("id")
    role = cast(str, _dict.get("role"))
    text = cast(str, _dict.get("content") or "")

    extras: dict[str, Any] = {}

    if _dict.get("function_call"):
        legacy_call = dict(_dict["function_call"])
        if "name" in legacy_call and legacy_call["name"] is None:
            # Continuation deltas may carry an explicit null name.
            legacy_call["name"] = ""
        extras["function_call"] = legacy_call

    raw_reasoning = _dict.get("reasoning")
    if raw_reasoning is not None:
        extras["reasoning"] = raw_reasoning
        flattened = _reasoning_to_text(raw_reasoning)
        if flattened:
            extras["reasoning_content"] = flattened

    parsed_tool_calls = []
    raw_tool_calls = _dict.get("tool_calls")
    if raw_tool_calls:
        try:
            parsed_tool_calls = [
                tool_call_chunk(
                    name=raw_call["function"].get("name"),
                    args=raw_call["function"].get("arguments"),
                    id=raw_call.get("id"),
                    index=raw_call["index"],
                )
                for raw_call in raw_tool_calls
            ]
        except KeyError:
            # A malformed entry discards the whole batch; the chunk is still emitted.
            pass

    # The order of the checks below is significant: an explicit role wins for
    # the first branch it matches, otherwise ``default_class`` decides.
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=text, id=chunk_id)

    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(
            content=text,
            additional_kwargs=extras,
            id=chunk_id,
            tool_call_chunks=parsed_tool_calls,  # type: ignore[arg-type]
        )

    if role in ("system", "developer") or default_class == SystemMessageChunk:
        system_extras = {"__openai_role__": "developer"} if role == "developer" else {}
        return SystemMessageChunk(content=text, id=chunk_id, additional_kwargs=system_extras)

    if role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=text, name=_dict["name"], id=chunk_id)

    if role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=text, tool_call_id=_dict["tool_call_id"], id=chunk_id)

    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=text, role=role, id=chunk_id)  # type: ignore[arg-type]

    return default_class(content=text, id=chunk_id)  # type: ignore[call-arg]
def _restore_reasoning_field(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
"""Re-inject vLLM reasoning onto outgoing assistant messages."""
reasoning = orig_msg.additional_kwargs.get("reasoning")
if reasoning is None:
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning is not None:
payload_msg["reasoning"] = reasoning
class VllmChatModel(ChatOpenAI):
    """ChatOpenAI subclass that keeps vLLM ``reasoning`` data intact across turns.

    Three hooks are overridden: request payload construction (reasoning is
    re-attached to assistant messages), non-streaming result creation and
    streaming chunk conversion (reasoning is copied into ``additional_kwargs``).
    """

    model_config = {"arbitrary_types_allowed": True}

    @property
    def _llm_type(self) -> str:
        """Identifier string for this chat model implementation."""
        return "vllm-openai-compatible"

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the request payload and re-attach reasoning to assistant messages.

        The parent payload is normalized for vLLM chat-template kwargs, then
        every assistant entry is paired with its originating ``AIMessage`` so
        the stored reasoning can be restored.
        """
        source_messages = self._convert_input(input_).to_messages()
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
        _normalize_vllm_chat_template_kwargs(payload)

        outgoing = payload.get("messages", [])
        if len(outgoing) == len(source_messages):
            # Same length: pair messages positionally and keep assistant/AI pairs.
            pairs = [
                (out_msg, src_msg)
                for out_msg, src_msg in zip(outgoing, source_messages)
                if out_msg.get("role") == "assistant" and isinstance(src_msg, AIMessage)
            ]
        else:
            # Lengths differ: pair the n-th assistant payload with the n-th AIMessage.
            assistant_entries = [out_msg for out_msg in outgoing if out_msg.get("role") == "assistant"]
            ai_sources = [src_msg for src_msg in source_messages if isinstance(src_msg, AIMessage)]
            pairs = list(zip(assistant_entries, ai_sources))

        for out_msg, src_msg in pairs:
            _restore_reasoning_field(out_msg, src_msg)
        return payload

    def _create_chat_result(self, response: dict | openai.BaseModel, generation_info: dict | None = None) -> ChatResult:
        """Create the chat result and copy each choice's ``reasoning`` onto its message."""
        result = super()._create_chat_result(response, generation_info=generation_info)
        raw = response if isinstance(response, dict) else response.model_dump()

        for generation, choice in zip(result.generations, raw.get("choices", [])):
            if not isinstance(generation, ChatGeneration):
                continue
            ai_message = generation.message
            if not isinstance(ai_message, AIMessage):
                continue

            raw_reasoning = choice.get("message", {}).get("reasoning")
            if raw_reasoning is not None:
                ai_message.additional_kwargs["reasoning"] = raw_reasoning
                flattened = _reasoning_to_text(raw_reasoning)
                if flattened:
                    ai_message.additional_kwargs["reasoning_content"] = flattened
        return result

    def _convert_chunk_to_generation_chunk(
        self,
        chunk: dict,
        default_chunk_class: type,
        base_generation_info: dict | None,
    ) -> ChatGenerationChunk | None:
        """Convert one raw stream chunk into a generation chunk, keeping reasoning.

        Returns ``None`` for ``content.delta`` events and for choices whose
        ``delta`` is ``None``.
        """
        if chunk.get("type") == "content.delta":
            return None

        token_usage = chunk.get("usage")
        choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
        usage_metadata = None
        if token_usage:
            usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier"))

        if len(choices) == 0:
            # No choices: emit an empty chunk that only carries usage metadata.
            empty_message = default_chunk_class(content="", usage_metadata=usage_metadata)
            empty_chunk = ChatGenerationChunk(message=empty_message, generation_info=base_generation_info)
            if self.output_version == "v1":
                empty_chunk.message.content = []
                empty_chunk.message.response_metadata["output_version"] = "v1"
            return empty_chunk

        first_choice = choices[0]
        delta = first_choice["delta"]
        if delta is None:
            return None

        message_chunk = _convert_delta_to_message_chunk_with_reasoning(delta, default_chunk_class)

        info = dict(base_generation_info) if base_generation_info else {}
        if finish_reason := first_choice.get("finish_reason"):
            info["finish_reason"] = finish_reason
        # Copy truthy chunk-level fields under their generation_info names.
        for chunk_key, info_key in (
            ("model", "model_name"),
            ("system_fingerprint", "system_fingerprint"),
            ("service_tier", "service_tier"),
        ):
            if field_value := chunk.get(chunk_key):
                info[info_key] = field_value
        if logprobs := first_choice.get("logprobs"):
            info["logprobs"] = logprobs

        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
            message_chunk.usage_metadata = usage_metadata

        message_chunk.response_metadata["model_provider"] = "openai"
        return ChatGenerationChunk(message=message_chunk, generation_info=info or None)
@@ -1,4 +1,4 @@
"""In-memory stream bridge backed by :class:`asyncio.Queue`."""
"""In-memory stream bridge backed by an in-process event log."""
from __future__ import annotations
@@ -6,35 +6,41 @@ import asyncio
import logging
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any
from .base import END_SENTINEL, HEARTBEAT_SENTINEL, StreamBridge, StreamEvent
logger = logging.getLogger(__name__)
_PUBLISH_TIMEOUT = 30.0 # seconds to wait when queue is full
@dataclass
class _RunStream:
events: list[StreamEvent] = field(default_factory=list)
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
ended: bool = False
start_offset: int = 0
class MemoryStreamBridge(StreamBridge):
"""Per-run ``asyncio.Queue`` implementation.
"""Per-run in-memory event log implementation.
Each *run_id* gets its own queue on first :meth:`publish` call.
Events are retained for a bounded time window per run so late subscribers
and reconnecting clients can replay buffered events from ``Last-Event-ID``.
"""
def __init__(self, *, queue_maxsize: int = 256) -> None:
self._maxsize = queue_maxsize
self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
self._streams: dict[str, _RunStream] = {}
self._counters: dict[str, int] = {}
self._dropped_counts: dict[str, int] = {}
# -- helpers ---------------------------------------------------------------
def _get_or_create_queue(self, run_id: str) -> asyncio.Queue[StreamEvent]:
if run_id not in self._queues:
self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
def _get_or_create_stream(self, run_id: str) -> _RunStream:
if run_id not in self._streams:
self._streams[run_id] = _RunStream()
self._counters[run_id] = 0
self._dropped_counts[run_id] = 0
return self._queues[run_id]
return self._streams[run_id]
def _next_id(self, run_id: str) -> str:
self._counters[run_id] = self._counters.get(run_id, 0) + 1
@@ -42,49 +48,39 @@ class MemoryStreamBridge(StreamBridge):
seq = self._counters[run_id] - 1
return f"{ts}-{seq}"
def _resolve_start_offset(self, stream: _RunStream, last_event_id: str | None) -> int:
if last_event_id is None:
return stream.start_offset
for index, entry in enumerate(stream.events):
if entry.id == last_event_id:
return stream.start_offset + index + 1
if stream.events:
logger.warning(
"last_event_id=%s not found in retained buffer; replaying from earliest retained event",
last_event_id,
)
return stream.start_offset
# -- StreamBridge API ------------------------------------------------------
async def publish(self, run_id: str, event: str, data: Any) -> None:
queue = self._get_or_create_queue(run_id)
stream = self._get_or_create_stream(run_id)
entry = StreamEvent(id=self._next_id(run_id), event=event, data=data)
try:
await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
except TimeoutError:
self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
logger.warning(
"Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
run_id,
event,
self._dropped_counts[run_id],
)
async with stream.condition:
stream.events.append(entry)
if len(stream.events) > self._maxsize:
overflow = len(stream.events) - self._maxsize
del stream.events[:overflow]
stream.start_offset += overflow
stream.condition.notify_all()
async def publish_end(self, run_id: str) -> None:
queue = self._get_or_create_queue(run_id)
# END sentinel is critical — it is the only signal that allows
# subscribers to terminate. If the queue is full we evict the
# oldest *regular* events to make room rather than dropping END,
# which would cause the SSE connection to hang forever and leak
# the queue/counter resources for this run_id.
if queue.full():
evicted = 0
while queue.full():
try:
queue.get_nowait()
evicted += 1
except asyncio.QueueEmpty:
break # pragma: no cover defensive
if evicted:
logger.warning(
"Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
run_id,
evicted,
)
# After eviction the queue is guaranteed to have space, so a
# simple non-blocking put is safe. We still use put() (which
# blocks until space is available) as a defensive measure.
await queue.put(END_SENTINEL)
stream = self._get_or_create_stream(run_id)
async with stream.condition:
stream.ended = True
stream.condition.notify_all()
async def subscribe(
self,
@@ -93,16 +89,34 @@ class MemoryStreamBridge(StreamBridge):
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[StreamEvent]:
if last_event_id is not None:
logger.debug("last_event_id=%s accepted but ignored (memory bridge has no replay)", last_event_id)
stream = self._get_or_create_stream(run_id)
async with stream.condition:
next_offset = self._resolve_start_offset(stream, last_event_id)
queue = self._get_or_create_queue(run_id)
while True:
try:
entry = await asyncio.wait_for(queue.get(), timeout=heartbeat_interval)
except TimeoutError:
yield HEARTBEAT_SENTINEL
continue
async with stream.condition:
if next_offset < stream.start_offset:
logger.warning(
"subscriber for run %s fell behind retained buffer; resuming from offset %s",
run_id,
stream.start_offset,
)
next_offset = stream.start_offset
local_index = next_offset - stream.start_offset
if 0 <= local_index < len(stream.events):
entry = stream.events[local_index]
next_offset += 1
elif stream.ended:
entry = END_SENTINEL
else:
try:
await asyncio.wait_for(stream.condition.wait(), timeout=heartbeat_interval)
except TimeoutError:
entry = HEARTBEAT_SENTINEL
else:
continue
if entry is END_SENTINEL:
yield END_SENTINEL
return
@@ -111,20 +125,9 @@ class MemoryStreamBridge(StreamBridge):
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
if delay > 0:
await asyncio.sleep(delay)
self._queues.pop(run_id, None)
self._streams.pop(run_id, None)
self._counters.pop(run_id, None)
self._dropped_counts.pop(run_id, None)
async def close(self) -> None:
self._queues.clear()
self._streams.clear()
self._counters.clear()
self._dropped_counts.clear()
def dropped_count(self, run_id: str) -> int:
"""Return the number of events dropped for *run_id*."""
return self._dropped_counts.get(run_id, 0)
@property
def dropped_total(self) -> int:
"""Return the total number of events dropped across all runs."""
return sum(self._dropped_counts.values())
@@ -963,6 +963,29 @@ def _truncate_read_file_output(output: str, max_chars: int) -> str:
return f"{output[:kept]}{marker}"
def _truncate_ls_output(output: str, max_chars: int) -> str:
"""Head-truncate ls output, preserving the beginning of the listing.
Directory listings are read top-to-bottom; the head shows the most
relevant structure.
The returned string (including the truncation marker) is guaranteed to be
no longer than max_chars characters. Pass max_chars=0 to disable truncation
and return the full output unchanged.
"""
if max_chars == 0:
return output
if len(output) <= max_chars:
return output
total = len(output)
marker_max_len = len(f"\n... [truncated: showing first {total} of {total} chars. Use a more specific path to see fewer results] ...")
kept = max(0, max_chars - marker_max_len)
if kept == 0:
return output[:max_chars]
marker = f"\n... [truncated: showing first {kept} of {total} chars. Use a more specific path to see fewer results] ..."
return f"{output[:kept]}{marker}"
@tool("bash", parse_docstring=True)
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
"""Execute a bash command in a Linux environment.
@@ -1037,7 +1060,15 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
children = sandbox.list_dir(path)
if not children:
return "(empty)"
return "\n".join(children)
output = "\n".join(children)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_ls_output(output, max_chars)
except SandboxError as e:
return f"Error: {e}"
except FileNotFoundError: