51 lines
1.7 KiB
Python
51 lines
1.7 KiB
Python
"""Helpers for keeping AIMessage tool-call metadata consistent."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from langchain_core.messages import AIMessage
|
|
|
|
|
|
def _raw_tool_call_id(raw_tool_call: Any) -> str | None:
|
|
if not isinstance(raw_tool_call, dict):
|
|
return None
|
|
|
|
raw_id = raw_tool_call.get("id")
|
|
return raw_id if isinstance(raw_id, str) and raw_id else None
|
|
|
|
|
|
def clone_ai_message_with_tool_calls(
|
|
message: AIMessage,
|
|
tool_calls: list[dict[str, Any]],
|
|
*,
|
|
content: Any | None = None,
|
|
) -> AIMessage:
|
|
"""Clone an AIMessage while keeping raw provider tool-call metadata in sync."""
|
|
kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]}
|
|
|
|
update: dict[str, Any] = {"tool_calls": tool_calls}
|
|
if content is not None:
|
|
update["content"] = content
|
|
|
|
additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {})
|
|
raw_tool_calls = additional_kwargs.get("tool_calls")
|
|
if isinstance(raw_tool_calls, list):
|
|
synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids]
|
|
if synced_raw_tool_calls:
|
|
additional_kwargs["tool_calls"] = synced_raw_tool_calls
|
|
else:
|
|
additional_kwargs.pop("tool_calls", None)
|
|
|
|
if not tool_calls:
|
|
additional_kwargs.pop("function_call", None)
|
|
|
|
update["additional_kwargs"] = additional_kwargs
|
|
|
|
response_metadata = dict(getattr(message, "response_metadata", {}) or {})
|
|
if not tool_calls and response_metadata.get("finish_reason") == "tool_calls":
|
|
response_metadata["finish_reason"] = "stop"
|
|
update["response_metadata"] = response_metadata
|
|
|
|
return message.model_copy(update=update)
|