Merge branch 'main' into fix-2804

This commit is contained in:
Willem Jiang
2026-05-12 15:53:28 +08:00
committed by GitHub
38 changed files with 953 additions and 291 deletions
@@ -36,42 +36,73 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
@staticmethod
def _message_tool_calls(msg) -> list[dict]:
"""Return normalized tool calls from structured fields or raw provider payloads."""
"""Return normalized tool calls from structured fields or raw provider payloads.
LangChain stores malformed provider function calls in ``invalid_tool_calls``.
They do not execute, but provider adapters may still serialize enough of
the call id/name back into the next request that strict OpenAI-compatible
validators expect a matching ToolMessage. Treat them as dangling calls so
the next model request stays well-formed and the model sees a recoverable
tool error instead of another provider 400.
"""
normalized: list[dict] = []
tool_calls = getattr(msg, "tool_calls", None) or []
if tool_calls:
return list(tool_calls)
normalized.extend(list(tool_calls))
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
normalized: list[dict] = []
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
if not tool_calls:
for raw_tc in raw_tool_calls:
if not isinstance(raw_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
}
)
for invalid_tc in getattr(msg, "invalid_tool_calls", None) or []:
if not isinstance(invalid_tc, dict):
continue
function = raw_tc.get("function")
name = raw_tc.get("name")
if not name and isinstance(function, dict):
name = function.get("name")
args = raw_tc.get("args", {})
if not args and isinstance(function, dict):
raw_args = function.get("arguments")
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
parsed_args = {}
args = parsed_args if isinstance(parsed_args, dict) else {}
normalized.append(
{
"id": raw_tc.get("id"),
"name": name or "unknown",
"args": args if isinstance(args, dict) else {},
"id": invalid_tc.get("id"),
"name": invalid_tc.get("name") or "unknown",
"args": {},
"invalid": True,
"error": invalid_tc.get("error"),
}
)
return normalized
@staticmethod
def _synthetic_tool_message_content(tool_call: dict) -> str:
if tool_call.get("invalid"):
error = tool_call.get("error")
if isinstance(error, str) and error:
return f"[Tool call could not be executed because its arguments were invalid: {error}]"
return "[Tool call could not be executed because its arguments were invalid.]"
return "[Tool call was interrupted and did not return a result.]"
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
@@ -114,7 +145,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage(
content="[Tool call was interrupted and did not return a result.]",
content=self._synthetic_tool_message_content(tc),
tool_call_id=tc_id,
name=tc.get("name", "unknown"),
status="error",
+2 -43
View File
@@ -1,11 +1,6 @@
"""Load MCP tools using langchain-mcp-adapters."""
import asyncio
import atexit
import concurrent.futures
import logging
from collections.abc import Callable
from typing import Any
from langchain_core.tools import BaseTool
@@ -13,46 +8,10 @@ from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.reflection import resolve_variable
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
# Global thread pool for sync tool invocation in async environments
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
# Register shutdown hook for the global executor
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
"""Build a synchronous wrapper for an asynchronous tool coroutine.
Args:
coro: The tool's asynchronous coroutine.
tool_name: Name of the tool (for logging).
Returns:
A synchronous function that correctly handles nested event loops.
"""
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
try:
if loop is not None and loop.is_running():
# Use global executor to avoid nested loop issues and improve performance
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
return future.result()
else:
return asyncio.run(coro(*args, **kwargs))
except Exception as e:
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
raise
return sync_wrapper
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
@@ -126,7 +85,7 @@ async def get_mcp_tools() -> list[BaseTool]:
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
@@ -23,6 +23,18 @@ class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _normalize_model_name(model_name: str | None) -> str | None:
"""Normalize model_name for storage: strip whitespace, truncate to 128 chars."""
if model_name is None:
return None
if not isinstance(model_name, str):
model_name = str(model_name)
normalized = model_name.strip()
if len(normalized) > 128:
normalized = normalized[:128]
return normalized
@staticmethod
def _safe_json(obj: Any) -> Any:
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
@@ -70,6 +82,7 @@ class RunRepository(RunStore):
thread_id,
assistant_id=None,
user_id: str | None | _AutoSentinel = AUTO,
model_name: str | None = None,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -85,6 +98,7 @@ class RunRepository(RunStore):
thread_id=thread_id,
assistant_id=assistant_id,
user_id=resolved_user_id,
model_name=self._normalize_model_name(model_name),
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
@@ -20,12 +20,13 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Mapping
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.types import Command
if TYPE_CHECKING:
@@ -71,6 +72,7 @@ class RunJournal(BaseCallbackHandler):
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
self._counted_llm_run_ids: set[str] = set()
self._counted_external_source_ids: set[str] = set()
self._counted_message_llm_run_ids: set[str] = set()
# Convenience fields
self._last_ai_msg: str | None = None
@@ -86,6 +88,50 @@ class RunJournal(BaseCallbackHandler):
# -- Lifecycle callbacks --
@staticmethod
def _message_text(message: BaseMessage) -> str:
"""Extract displayable text from a message's mixed content shape."""
content = getattr(message, "content", None)
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, Mapping):
text = block.get("text")
if isinstance(text, str):
parts.append(text)
else:
nested = block.get("content")
if isinstance(nested, str):
parts.append(nested)
return "".join(parts)
if isinstance(content, Mapping):
for key in ("text", "content"):
value = content.get(key)
if isinstance(value, str):
return value
text = getattr(message, "text", None)
if isinstance(text, str):
return text
return ""
def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None:
"""Update run-level convenience fields for persisted run rows."""
self._msg_count += 1
# ``last_ai_message`` should represent the lead agent's user-facing
# answer. Middleware/subagent model calls and empty tool-call-only
# AI messages must not overwrite the last useful assistant text.
is_ai_message = isinstance(message, AIMessage) or getattr(message, "type", None) == "ai"
if is_ai_message and (caller is None or caller == "lead_agent"):
text = self._message_text(message).strip()
if text:
self._last_ai_msg = text[:2000]
def on_chain_start(
self,
serialized: dict[str, Any],
@@ -164,6 +210,7 @@ class RunJournal(BaseCallbackHandler):
content=m.model_dump(),
metadata={"caller": caller},
)
self._record_message_summary(m, caller=caller)
break
if self._first_human_msg:
break
@@ -222,6 +269,8 @@ class RunJournal(BaseCallbackHandler):
"llm_call_index": call_index,
},
)
if rid not in self._counted_message_llm_run_ids:
self._record_message_summary(message, caller=caller)
# Token accumulation (dedup by langchain run_id to avoid double-counting
# when the callback fires more than once for the same response)
@@ -245,6 +294,9 @@ class RunJournal(BaseCallbackHandler):
else:
self._lead_agent_tokens += total_tk
if messages:
self._counted_message_llm_run_ids.add(str(run_id))
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm.error", category="trace", content=str(error))
@@ -260,12 +312,14 @@ class RunJournal(BaseCallbackHandler):
if isinstance(output, ToolMessage):
msg = cast(ToolMessage, output)
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
self._record_message_summary(msg)
elif isinstance(output, Command):
cmd = cast(Command, output)
messages = cmd.update.get("messages", [])
for message in messages:
if isinstance(message, BaseMessage):
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
self._record_message_summary(message)
else:
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
else:
@@ -36,6 +36,7 @@ class RunRecord:
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
abort_action: str = "interrupt"
error: str | None = None
model_name: str | None = None
class RunManager:
@@ -65,6 +66,7 @@ class RunManager:
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
model_name=record.model_name,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -137,6 +139,18 @@ class RunManager:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.info("Run %s -> %s", run_id, status.value)
async def update_model_name(self, run_id: str, model_name: str | None) -> None:
"""Update the model name for a run."""
async with self._lock:
record = self._runs.get(run_id)
if record is None:
logger.warning("update_model_name called for unknown run %s", run_id)
return
record.model_name = model_name
record.updated_at = _now_iso()
await self._persist_to_store(record)
logger.info("Run %s model_name=%s", run_id, model_name)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
"""Request cancellation of a run.
@@ -171,6 +185,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
model_name: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -221,6 +236,7 @@ class RunManager:
kwargs=kwargs or {},
created_at=now,
updated_at=now,
model_name=model_name,
)
self._runs[run_id] = record
@@ -23,6 +23,7 @@ class RunStore(abc.ABC):
thread_id: str,
assistant_id: str | None = None,
user_id: str | None = None,
model_name: str | None = None,
status: str = "pending",
multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None,
@@ -22,6 +22,7 @@ class MemoryRunStore(RunStore):
thread_id,
assistant_id=None,
user_id=None,
model_name=None,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -35,6 +36,7 @@ class MemoryRunStore(RunStore):
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": user_id,
"model_name": model_name,
"status": status,
"multitask_strategy": multitask_strategy,
"metadata": metadata or {},
@@ -230,6 +230,17 @@ async def run_agent(
else:
agent = agent_factory(config=runnable_config)
# Capture the effective (resolved) model name from the agent's metadata.
# _resolve_model_name in agent.py may return the default model if the
# requested name is not in the allowlist — this update ensures the
# persisted model_name reflects the actual model used.
if record.model_name is not None:
resolved = getattr(agent, "metadata", {}) or {}
if isinstance(resolved, dict):
effective = resolved.get("model_name")
if effective and effective != record.model_name:
await run_manager.update_model_name(record.run_id, effective)
# 4. Attach checkpointer and store
if checkpointer is not None:
agent.checkpointer = checkpointer
@@ -26,7 +26,7 @@ class SubagentConfig:
name: str
description: str
system_prompt: str
system_prompt: str | None = None
tools: list[str] | None = None
disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"])
skills: list[str] | None = None
@@ -286,11 +286,13 @@ class SubagentExecutor:
# Reuse shared middleware composition with lead agent.
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True)
# system_prompt is included in initial state messages (see _build_initial_state)
# to avoid multiple SystemMessages which some LLM APIs don't support.
return create_agent(
model=model,
tools=tools if tools is not None else self.tools,
middleware=middlewares,
system_prompt=self.config.system_prompt,
system_prompt=None,
state_schema=ThreadState,
)
@@ -365,14 +367,25 @@ class SubagentExecutor:
Returns:
Initial state dictionary and tools filtered by loaded skill metadata.
"""
# Load skills as conversation items (Codex pattern)
skills = await self._load_skills()
filtered_tools = self._apply_skill_allowed_tools(skills)
skill_messages = await self._load_skill_messages(skills)
# Combine system_prompt and skills into a single SystemMessage.
# Some LLM APIs reject multiple SystemMessages with
# "System message must be at the beginning."
system_parts: list[str] = []
if self.config.system_prompt:
system_parts.append(self.config.system_prompt)
for skill_msg in skill_messages:
system_parts.append(skill_msg.content)
messages: list[Any] = []
# Skill content injected as developer/system messages before the task
messages.extend(skill_messages)
if system_parts:
messages.append(SystemMessage(content="\n\n".join(system_parts)))
# Then the actual task
messages.append(HumanMessage(content=task))
@@ -10,11 +10,11 @@ from weakref import WeakValueDictionary
from langchain.tools import tool
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.storage.skill_storage import SkillStorage
from deerflow.skills.types import SKILL_MD_FILE
from deerflow.tools.sync import make_sync_tool_wrapper
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
@@ -235,4 +235,4 @@ async def skill_manage_tool(
)
skill_manage_tool.func = _make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
skill_manage_tool.func = make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
@@ -0,0 +1,36 @@
"""Utilities for invoking async tools from synchronous agent paths."""
import asyncio
import atexit
import concurrent.futures
import logging
from collections.abc import Callable
from typing import Any
logger = logging.getLogger(__name__)
# Shared thread pool for sync tool invocation in async environments.
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="tool-sync")
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
"""Build a synchronous wrapper for an asynchronous tool coroutine."""
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
try:
if loop is not None and loop.is_running():
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
return future.result()
return asyncio.run(coro(*args, **kwargs))
except Exception as e:
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
raise
return sync_wrapper
@@ -8,6 +8,7 @@ from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.builtins.tool_search import reset_deferred_registry
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
@@ -33,6 +34,13 @@ def _is_host_bash_tool(tool: object) -> bool:
return False
def _ensure_sync_invocable_tool(tool: BaseTool) -> BaseTool:
"""Attach a sync wrapper to async-only tools used by sync agent callers."""
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tool
def get_available_tools(
groups: list[str] | None = None,
include_mcp: bool = True,
@@ -77,7 +85,7 @@ def get_available_tools(
cfg.use,
)
loaded_tools = [t for _, t in loaded_tools_raw]
loaded_tools = [_ensure_sync_invocable_tool(t) for _, t in loaded_tools_raw]
# Conditionally add tools based on config
builtin_tools = BUILTIN_TOOLS.copy()