mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
refactor(provider): share assistant payload replay matching (#3307)
* Share assistant payload replay matching * fix(provider): recover assistant field when ordinal AI index is taken The mismatch-length fallback in `_match_ai_message` only tried the exact `fallback_ordinal` AI index. When serialization drops or reorders an assistant message, a unique signature match can consume a non-ordinal index, leaving a later ambiguous payload's ordinal already used — so its provider field (e.g. `reasoning_content`) was silently dropped. Scan forward from the ordinal for the next unused `AIMessage` (wrapping to earlier indices) to preserve the positional bias while still recovering the field. Forward scanning avoids a naive min-unused pick that could restore the wrong field after a leading message is dropped. Add a regression test for the dropped-leading-message case. * fix(provider): avoid earlier assistant fallback replay
This commit is contained in:
@@ -17,6 +17,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
@@ -45,12 +47,6 @@ def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str)
|
||||
return message.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
|
||||
|
||||
def _restore_reasoning_content(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning is not None:
|
||||
payload_msg["reasoning_content"] = reasoning
|
||||
|
||||
|
||||
def _get_typed_choice_message(response: Any, index: int) -> Any:
|
||||
choices = getattr(response, "choices", None)
|
||||
if choices is None:
|
||||
@@ -81,17 +77,11 @@ class PatchedChatMiMo(ChatOpenAI):
|
||||
) -> dict:
|
||||
original_messages = self._convert_input(input_).to_messages()
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_reasoning_content(payload_msg, orig_msg)
|
||||
else:
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
|
||||
for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_reasoning_content(payload_msg, ai_msg)
|
||||
restore_assistant_payloads(
|
||||
payload.get("messages", []),
|
||||
original_messages,
|
||||
restore_reasoning_content,
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
Reference in New Issue
Block a user