"""Helpers for keeping AIMessage tool-call metadata consistent.""" from __future__ import annotations from typing import Any from langchain_core.messages import AIMessage def _raw_tool_call_id(raw_tool_call: Any) -> str | None: if not isinstance(raw_tool_call, dict): return None raw_id = raw_tool_call.get("id") return raw_id if isinstance(raw_id, str) and raw_id else None def clone_ai_message_with_tool_calls( message: AIMessage, tool_calls: list[dict[str, Any]], *, content: Any | None = None, ) -> AIMessage: """Clone an AIMessage while keeping raw provider tool-call metadata in sync.""" kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]} update: dict[str, Any] = {"tool_calls": tool_calls} if content is not None: update["content"] = content additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {}) raw_tool_calls = additional_kwargs.get("tool_calls") if isinstance(raw_tool_calls, list): synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids] if synced_raw_tool_calls: additional_kwargs["tool_calls"] = synced_raw_tool_calls else: additional_kwargs.pop("tool_calls", None) if not tool_calls: additional_kwargs.pop("function_call", None) update["additional_kwargs"] = additional_kwargs response_metadata = dict(getattr(message, "response_metadata", {}) or {}) if not tool_calls and response_metadata.get("finish_reason") == "tool_calls": response_metadata["finish_reason"] = "stop" update["response_metadata"] = response_metadata return message.model_copy(update=update)