refactor(provider): share assistant payload replay matching (#3307)

* Share assistant payload replay matching

* fix(provider): recover assistant field when ordinal AI index is taken

The mismatch-length fallback in `_match_ai_message` only tried the exact
`fallback_ordinal` AI index. When serialization drops or reorders an
assistant message, a unique signature match can consume a non-ordinal
index, leaving a later ambiguous payload's ordinal already used — so its
provider field (e.g. `reasoning_content`) was silently dropped.

Scan forward from the ordinal for the next unused `AIMessage` (wrapping to
earlier indices) to preserve the positional bias while still recovering
the field. Forward scanning avoids a naive min-unused pick that could
restore the wrong field after a leading message is dropped.

Add a regression test for the dropped-leading-message case.

* fix(provider): avoid earlier assistant fallback replay
This commit is contained in:
AochenShen99
2026-05-29 23:05:59 +08:00
committed by GitHub
parent 052b1e2102
commit 4093c83383
5 changed files with 307 additions and 50 deletions
@@ -17,6 +17,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
_MISSING = object()
@@ -45,12 +47,6 @@ def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str)
return message.model_copy(update={"additional_kwargs": additional_kwargs})
def _restore_reasoning_content(payload_msg: dict, orig_msg: AIMessage) -> None:
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning is not None:
payload_msg["reasoning_content"] = reasoning
def _get_typed_choice_message(response: Any, index: int) -> Any:
choices = getattr(response, "choices", None)
if choices is None:
@@ -81,17 +77,11 @@ class PatchedChatMiMo(ChatOpenAI):
) -> dict:
original_messages = self._convert_input(input_).to_messages()
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
payload_messages = payload.get("messages", [])
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
_restore_reasoning_content(payload_msg, orig_msg)
else:
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
_restore_reasoning_content(payload_msg, ai_msg)
restore_assistant_payloads(
payload.get("messages", []),
original_messages,
restore_reasoning_content,
)
return payload