Files
deer-flow/backend/packages/harness/deerflow/models/patched_stepfun.py
T
hataa 37337b77f9 feat(models): add StepFun reasoning model adapter (#3461)
Add PatchedChatStepFun adapter for StepFun reasoning models (step-3.7-flash,
step-3.5-flash). Captures reasoning from both streaming and non-streaming
responses and replays it on historical assistant messages for multi-turn
tool-call conversations.

- New: PatchedChatStepFun adapter with streaming/non-streaming reasoning capture
- Support both reasoning and reasoning_content field names
- 17 unit tests covering all response paths
- Updated: config.example.yaml with StepFun configuration example
2026-06-09 18:01:43 +08:00

176 lines
6.5 KiB
Python

"""Patched ChatOpenAI adapter for StepFun reasoning models.
StepFun returns its reasoning text in a ``reasoning`` field (or in
``reasoning_content`` for deepseek-style responses) in both streaming deltas
and non-streaming responses. Standard ``ChatOpenAI`` ignores these
non-standard fields, so the reasoning content would otherwise be silently
dropped.
This adapter captures reasoning from all response paths and replays it on
historical assistant messages for multi-turn tool-call conversations.
"""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from deerflow.models.assistant_payload_replay import (
restore_assistant_payloads,
restore_reasoning_content,
)
# Sentinel signalling "no reasoning field found"; compared by identity.
_MISSING = object()


def _extract_reasoning(value: Any) -> str | object:
    """Look up the reasoning payload carried by *value*.

    ``reasoning_content`` (deepseek-style) is consulted before ``reasoning``
    (StepFun's default name). Fields that are absent or ``None`` are skipped.

    Args:
        value: A mapping, or an object (such as a Pydantic / SDK model) that
            exposes the fields as attributes or inside ``model_extra``.

    Returns:
        The first non-``None`` field value found, otherwise ``_MISSING``.
    """
    field_names = ("reasoning_content", "reasoning")

    def first_set(source: Mapping) -> Any:
        # Walk the fields lazily in priority order; None counts as "not set".
        present = (source[name] for name in field_names if name in source)
        return next((item for item in present if item is not None), _MISSING)

    if isinstance(value, Mapping):
        # Mappings are resolved by key only; attributes are never probed.
        return first_set(value)

    # Attribute access covers Pydantic / SDK objects.
    for name in field_names:
        candidate = getattr(value, name, _MISSING)
        if not (candidate is _MISSING or candidate is None):
            return candidate

    # Some SDK versions keep unknown fields in ``model_extra`` instead.
    extras = getattr(value, "model_extra", None)
    return first_set(extras) if isinstance(extras, Mapping) else _MISSING
def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str) -> AIMessage | AIMessageChunk:
    """Build a copy of *message* whose ``additional_kwargs`` carry *reasoning*.

    Args:
        message: The message to copy; it is not modified.
        reasoning: Value to store under the ``reasoning_content`` key.

    Returns:
        The result of ``message.model_copy`` with the patched
        ``additional_kwargs`` dictionary.
    """
    # Work on a shallow copy so the caller's dictionary stays untouched.
    patched_kwargs = {**message.additional_kwargs}
    stored = patched_kwargs.get("reasoning_content")
    if stored != reasoning:
        patched_kwargs.update(reasoning_content=reasoning)
    return message.model_copy(update={"additional_kwargs": patched_kwargs})
def _get_typed_choice_message(response: Any, index: int) -> Any:
    """Fetch ``response.choices[index].message`` from an SDK-typed response.

    Args:
        response: Object expected to expose a ``choices`` sequence.
        index: Position of the wanted choice.

    Returns:
        The choice's ``message`` attribute, or ``None`` when ``choices`` is
        missing or ``None``, the index is out of range, or the lookup fails
        with ``AttributeError`` / ``TypeError``.
    """
    choice_list = getattr(response, "choices", None)
    if choice_list is not None:
        try:
            picked = choice_list[index]
            return picked.message
        except (AttributeError, IndexError, TypeError):
            # Fall through: an unusable shape is reported as "no message".
            pass
    return None
class PatchedChatStepFun(ChatOpenAI):
    """``ChatOpenAI`` subclass that keeps StepFun reasoning output.

    The ``reasoning`` / ``reasoning_content`` field is copied from streaming
    deltas and from non-streaming responses into the resulting message's
    ``additional_kwargs["reasoning_content"]``. When a request is built,
    ``restore_assistant_payloads`` is applied to the outgoing messages so that
    historical assistant messages carry their reasoning again in multi-turn
    tool-call conversations.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report this class as LangChain-serializable."""
        return True

    @property
    def lc_secrets(self) -> dict[str, str]:
        """Map both API-key argument spellings to ``STEPFUN_API_KEY``."""
        secret_id = "STEPFUN_API_KEY"
        return dict.fromkeys(("api_key", "openai_api_key"), secret_id)

    # --- Request payload replay ---

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Build the request payload and restore reasoning on past messages.

        Args:
            input_: The model input, converted to messages for the replay.
            stop: Optional stop sequences, forwarded to the parent class.
            **kwargs: Extra options, forwarded to the parent class.

        Returns:
            The parent's payload after ``restore_assistant_payloads`` has
            processed its ``messages`` entry.
        """
        # The source messages are resolved before the parent builds the payload.
        source_messages = self._convert_input(input_).to_messages()
        request = super()._get_request_payload(input_, stop=stop, **kwargs)
        outgoing = request.get("messages", [])
        restore_assistant_payloads(outgoing, source_messages, restore_reasoning_content)
        return request

    # --- Streaming reasoning capture ---

    def _convert_chunk_to_generation_chunk(
        self,
        chunk: dict,
        default_chunk_class: type,
        base_generation_info: dict | None,
    ) -> ChatGenerationChunk | None:
        """Attach reasoning found in a streaming delta to the generated chunk.

        Args:
            chunk: Raw streaming chunk; only ``choices[0]["delta"]`` is read.
            default_chunk_class: Forwarded to the parent implementation.
            base_generation_info: Forwarded to the parent implementation.

        Returns:
            ``None`` when the parent yields nothing; the parent's chunk when no
            reasoning is found or its message is not an ``AIMessageChunk``;
            otherwise a new chunk whose message carries the reasoning.
        """
        produced = super()._convert_chunk_to_generation_chunk(
            chunk,
            default_chunk_class,
            base_generation_info,
        )
        if produced is None:
            return None

        stream_choices = chunk.get("choices", [])
        if not stream_choices:
            return produced

        # A missing or null delta is treated as an empty mapping.
        delta = stream_choices[0].get("delta") or {}
        reasoning = _extract_reasoning(delta)
        if reasoning is _MISSING or not isinstance(produced.message, AIMessageChunk):
            return produced

        return ChatGenerationChunk(
            message=_with_reasoning_content(produced.message, reasoning),
            generation_info=produced.generation_info,
        )

    # --- Non-streaming reasoning capture ---

    def _create_chat_result(
        self,
        response: dict | Any,
        generation_info: dict | None = None,
    ) -> ChatResult:
        """Attach reasoning from a non-streaming response to its generations.

        Args:
            response: Either a plain ``dict`` or an object offering
                ``model_dump()``.
            generation_info: Forwarded to the parent implementation.

        Returns:
            A ``ChatResult`` with the parent's ``llm_output``. Generations for
            which reasoning was found are replaced by patched copies; if none
            were patched, the parent's generations are passed through.
        """
        base = super()._create_chat_result(response, generation_info)
        is_raw_dict = isinstance(response, dict)
        payload = response if is_raw_dict else response.model_dump()
        raw_choices = payload.get("choices", [])

        # Position -> patched generation, for entries that gained reasoning.
        replacements: dict[int, ChatGeneration] = {}
        for position, generation in enumerate(base.generations):
            raw_choice = raw_choices[position] if position < len(raw_choices) else {}
            if isinstance(raw_choice, Mapping):
                raw_message = raw_choice.get("message", {})
            else:
                raw_message = {}

            reasoning = _extract_reasoning(raw_message)
            if reasoning is _MISSING and not is_raw_dict:
                # Fall back to the typed SDK object when the dump had nothing.
                reasoning = _extract_reasoning(_get_typed_choice_message(response, position))
            if reasoning is _MISSING:
                continue

            message = generation.message
            if not isinstance(message, AIMessage):
                continue

            replacements[position] = ChatGeneration(
                message=_with_reasoning_content(message, reasoning),
                generation_info=generation.generation_info,
            )

        if replacements:
            final_generations = [
                replacements.get(position, generation)
                for position, generation in enumerate(base.generations)
            ]
        else:
            final_generations = base.generations

        return ChatResult(
            generations=final_generations,
            llm_output=base.llm_output,
        )