Merge branch 'main' into release/2.0-rc

This commit is contained in:
Willem Jiang
2026-04-28 15:44:02 +08:00
committed by GitHub
20 changed files with 1531 additions and 98 deletions
@@ -1,4 +1,5 @@
import ast
import html
import json
import re
import uuid
@@ -36,8 +37,8 @@ def _fix_messages(messages: list) -> list:
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", []):
xml_parts = []
for tool in msg.tool_calls:
args_xml = " ".join(f"<parameter={k}>{json.dumps(v, ensure_ascii=False)}</parameter>" for k, v in tool.get("args", {}).items())
xml_parts.append(f"<tool_call> <function={tool['name']}> {args_xml} </function> </tool_call>")
args_xml = " ".join(f"<parameter={html.escape(str(k), quote=False)}>{html.escape(v if isinstance(v, str) else json.dumps(v, ensure_ascii=False), quote=False)}</parameter>" for k, v in tool.get("args", {}).items())
xml_parts.append(f"<tool_call> <function={html.escape(str(tool['name']), quote=False)}> {args_xml} </function> </tool_call>")
full_text = f"{text}\n" + "\n".join(xml_parts) if text else "\n".join(xml_parts)
fixed.append(AIMessage(content=full_text.strip() or " "))
continue
@@ -80,13 +81,24 @@ def _parse_xml_tool_call_to_dict(content: str) -> tuple[str, list[dict]]:
func_match = re.search(r"<function=([^>]+)>", inner_content)
if not func_match:
continue
function_name = func_match.group(1).strip()
function_name = html.unescape(func_match.group(1).strip())
# Ignore nested tool blocks when extracting parameters for this call.
# Nested `<tool_call>` sections represent separate invocations and
# their `<parameter>` tags must not leak into the current call args.
param_source_parts: list[str] = []
nested_cursor = 0
for nested_start, nested_end, _ in _iter_tool_call_blocks(inner_content):
param_source_parts.append(inner_content[nested_cursor:nested_start])
nested_cursor = nested_end
param_source_parts.append(inner_content[nested_cursor:])
param_source = "".join(param_source_parts)
args = {}
param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
for param_match in param_pattern.finditer(inner_content):
key = param_match.group(1).strip()
raw_value = param_match.group(2).strip()
for param_match in param_pattern.finditer(param_source):
key = html.unescape(param_match.group(1).strip())
raw_value = html.unescape(param_match.group(2).strip())
# Attempt to deserialize string values into native Python types
# to satisfy downstream Pydantic validation.