mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
4093c83383
* Share assistant payload replay matching * fix(provider): recover assistant field when ordinal AI index is taken The mismatch-length fallback in `_match_ai_message` only tried the exact `fallback_ordinal` AI index. When serialization drops or reorders an assistant message, a unique signature match can consume a non-ordinal index, leaving a later ambiguous payload's ordinal already used — so its provider field (e.g. `reasoning_content`) was silently dropped. Scan forward from the ordinal for the next unused `AIMessage` (wrapping to earlier indices) to preserve the positional bias while still recovering the field. Forward scanning avoids a naive min-unused pick that could restore the wrong field after a leading message is dropped. Add a regression test for the dropped-leading-message case. * fix(provider): avoid earlier assistant fallback replay
125 lines
4.6 KiB
Python
125 lines
4.6 KiB
Python
"""Helpers for replaying provider-specific assistant message fields.
|
|
|
|
Several provider adapters need to preserve fields that LangChain stores on the
|
|
original ``AIMessage`` but drops when serializing request payloads. This module
|
|
keeps the assistant-message matching logic shared while letting each provider
|
|
decide which fields to restore.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from collections.abc import Callable, Sequence
|
|
from typing import Any
|
|
|
|
from langchain_core.messages import AIMessage, BaseMessage
|
|
|
|
# Type of the per-provider callback passed to ``restore_assistant_payloads``.
# It receives a serialized assistant payload dict and the ``AIMessage`` that
# payload was matched to, and returns nothing. The restorers defined in this
# module write the recovered field directly into the payload dict.
AssistantPayloadRestorer = Callable[[dict[str, Any], AIMessage], None]
|
|
|
|
|
|
def restore_assistant_payloads(
    payload_messages: Sequence[dict[str, Any]],
    original_messages: Sequence[BaseMessage],
    restore: AssistantPayloadRestorer,
) -> None:
    """Copy provider-only assistant fields back onto serialized payloads.

    Two pairing strategies are used:

    * When ``payload_messages`` and ``original_messages`` have the same length,
      they are paired by position. ``restore`` runs only where the payload has
      ``role == "assistant"`` and the original is an ``AIMessage``.
    * Otherwise every assistant payload is paired with an ``AIMessage`` via
      ``_match_ai_message``. Each ``AIMessage`` is matched at most once.

    Args:
        payload_messages: Serialized request messages. ``restore`` may mutate
            the assistant entries in place.
        original_messages: The LangChain messages the payloads were built from.
        restore: Callback that copies the desired fields from an ``AIMessage``
            onto its payload dict.
    """
    if len(payload_messages) == len(original_messages):
        # Same length: treat the two lists as aligned and pair them by index.
        for payload, source in zip(payload_messages, original_messages):
            is_assistant = payload.get("role") == "assistant"
            if is_assistant and isinstance(source, AIMessage):
                restore(payload, source)
        return

    # Lengths differ (serialization dropped or merged messages), so positions
    # no longer line up. Match assistant payloads to AI messages individually.
    ai_candidates = [message for message in original_messages if isinstance(message, AIMessage)]
    assistant_entries = [payload for payload in payload_messages if payload.get("role") == "assistant"]
    consumed: set[int] = set()

    for position, payload in enumerate(assistant_entries):
        matched = _match_ai_message(payload, ai_candidates, consumed, position)
        if matched is None:
            continue
        restore(payload, matched)
|
|
|
|
|
|
def restore_additional_kwargs_field(payload_msg: dict[str, Any], orig_msg: AIMessage, field_name: str) -> None:
    """Mirror ``orig_msg.additional_kwargs[field_name]`` onto ``payload_msg``.

    Missing keys and explicit ``None`` values are skipped, so the payload is
    left untouched rather than given a ``None`` entry.
    """
    recovered = orig_msg.additional_kwargs.get(field_name)
    if recovered is None:
        return
    payload_msg[field_name] = recovered
|
|
|
|
|
|
def restore_reasoning_content(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
    """Restore ``additional_kwargs["reasoning_content"]`` onto an assistant payload.

    Matches the ``AssistantPayloadRestorer`` callback signature. A missing or
    ``None`` reasoning value leaves the payload unchanged.
    """
    restore_additional_kwargs_field(payload_msg, orig_msg, field_name="reasoning_content")
|
|
|
|
|
|
def _match_ai_message(
    payload_msg: dict[str, Any],
    ai_messages: Sequence[AIMessage],
    used_ai_indexes: set[int],
    fallback_ordinal: int,
) -> AIMessage | None:
    """Pick the unused ``AIMessage`` that most likely produced ``payload_msg``.

    Candidates are tried in order of confidence:

    1. Exactly one unused AI message has the same content/tool-call signature
       as the payload: use it, wherever it sits.
    2. Several unused AI messages share that signature: use the first of them
       at or after ``fallback_ordinal``. This keeps the positional bias without
       handing the payload the fields of a message whose content differs.
    3. Otherwise, use the first unused AI message at or after
       ``fallback_ordinal``, regardless of signature.

    Args:
        payload_msg: Serialized assistant payload dict.
        ai_messages: All ``AIMessage`` objects from the original conversation,
            in order.
        used_ai_indexes: Indexes into ``ai_messages`` already claimed by earlier
            payloads. The chosen index is added to this set.
        fallback_ordinal: Position of ``payload_msg`` among the assistant
            payloads.

    Returns:
        The matched ``AIMessage``, or ``None`` if no candidate qualifies.
    """
    payload_key = _assistant_signature(payload_msg)
    if payload_key is not None:
        matches = [
            index
            for index, ai_msg in enumerate(ai_messages)
            if index not in used_ai_indexes and _ai_signature(ai_msg) == payload_key
        ]
        if len(matches) == 1:
            chosen: int | None = matches[0]
        else:
            # ``matches`` is ascending, so this is the nearest signature match
            # at or after the payload's ordinal. Without this step, an ambiguous
            # payload falls through to the positional pick below. That pick can
            # land on a message with different content, e.g. after a leading
            # assistant message was dropped during serialization.
            chosen = next((index for index in matches if index >= fallback_ordinal), None)
        if chosen is not None:
            used_ai_indexes.add(chosen)
            return ai_messages[chosen]

    fallback_index = _next_unused_index_at_or_after(len(ai_messages), used_ai_indexes, fallback_ordinal)
    if fallback_index is None:
        return None
    used_ai_indexes.add(fallback_index)
    return ai_messages[fallback_index]
|
|
|
|
|
|
def _next_unused_index_at_or_after(count: int, used_ai_indexes: set[int], start: int) -> int | None:
|
|
"""Return the next unused AI index at or after ``start``.
|
|
|
|
Scanning forward from the payload's ordinal preserves the positional bias of
|
|
the previous behaviour while still recovering when serialization drops or
|
|
reorders messages so the exact ordinal index is already taken. It does not
|
|
wrap to earlier indexes because those messages may be represented by payload
|
|
entries that were already dropped.
|
|
"""
|
|
if count == 0 or start >= count:
|
|
return None
|
|
for index in range(start, count):
|
|
if index not in used_ai_indexes:
|
|
return index
|
|
return None
|
|
|
|
|
|
def _assistant_signature(payload_msg: dict[str, Any]) -> tuple[str, str] | None:
    """Build the matching key for a serialized assistant payload dict.

    The key combines the payload's ``content`` with the ids of its
    ``tool_calls`` (an absent or falsy ``tool_calls`` counts as none).
    """
    content = payload_msg.get("content")
    raw_tool_calls = payload_msg.get("tool_calls") or []
    return _signature(content, _tool_call_ids(raw_tool_calls))
|
|
|
|
|
|
def _ai_signature(message: AIMessage) -> tuple[str, str] | None:
    """Build the matching key for an original ``AIMessage``.

    Parsed ``message.tool_calls`` are used when non-empty. Otherwise the raw
    ``additional_kwargs["tool_calls"]`` entries (or nothing) are used.
    """
    calls = message.tool_calls
    if not calls:
        calls = message.additional_kwargs.get("tool_calls") or []
    return _signature(message.content, _tool_call_ids(calls))
|
|
|
|
|
|
def _signature(content: Any, tool_call_ids: tuple[str, ...]) -> tuple[str, str] | None:
    """Combine message content and tool-call ids into a comparable key.

    Returns ``None`` when the content is ``None`` or ``""`` and there are no
    tool-call ids, because such messages carry nothing to match on. Otherwise
    returns ``(stable content string, "|"-joined tool-call ids)``.
    """
    has_content = content not in (None, "")
    if has_content or tool_call_ids:
        return (_stable_repr(content), "|".join(tool_call_ids))
    return None
|
|
|
|
|
|
def _stable_repr(value: Any) -> str:
|
|
try:
|
|
return json.dumps(value, sort_keys=True, ensure_ascii=False)
|
|
except TypeError:
|
|
return repr(value)
|
|
|
|
|
|
def _tool_call_ids(tool_calls: Sequence[Any]) -> tuple[str, ...]:
|
|
ids: list[str] = []
|
|
for tool_call in tool_calls:
|
|
if isinstance(tool_call, dict):
|
|
call_id = tool_call.get("id")
|
|
if isinstance(call_id, str) and call_id:
|
|
ids.append(call_id)
|
|
return tuple(ids)
|