mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
Merge branch 'main' into fix-2804
This commit is contained in:
+57
-26
@@ -36,42 +36,73 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
@staticmethod
|
||||
def _message_tool_calls(msg) -> list[dict]:
|
||||
"""Return normalized tool calls from structured fields or raw provider payloads."""
|
||||
"""Return normalized tool calls from structured fields or raw provider payloads.
|
||||
|
||||
LangChain stores malformed provider function calls in ``invalid_tool_calls``.
|
||||
They do not execute, but provider adapters may still serialize enough of
|
||||
the call id/name back into the next request that strict OpenAI-compatible
|
||||
validators expect a matching ToolMessage. Treat them as dangling calls so
|
||||
the next model request stays well-formed and the model sees a recoverable
|
||||
tool error instead of another provider 400.
|
||||
"""
|
||||
normalized: list[dict] = []
|
||||
|
||||
tool_calls = getattr(msg, "tool_calls", None) or []
|
||||
if tool_calls:
|
||||
return list(tool_calls)
|
||||
normalized.extend(list(tool_calls))
|
||||
|
||||
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
|
||||
normalized: list[dict] = []
|
||||
for raw_tc in raw_tool_calls:
|
||||
if not isinstance(raw_tc, dict):
|
||||
if not tool_calls:
|
||||
for raw_tc in raw_tool_calls:
|
||||
if not isinstance(raw_tc, dict):
|
||||
continue
|
||||
|
||||
function = raw_tc.get("function")
|
||||
name = raw_tc.get("name")
|
||||
if not name and isinstance(function, dict):
|
||||
name = function.get("name")
|
||||
|
||||
args = raw_tc.get("args", {})
|
||||
if not args and isinstance(function, dict):
|
||||
raw_args = function.get("arguments")
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(raw_args)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
parsed_args = {}
|
||||
args = parsed_args if isinstance(parsed_args, dict) else {}
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
"id": raw_tc.get("id"),
|
||||
"name": name or "unknown",
|
||||
"args": args if isinstance(args, dict) else {},
|
||||
}
|
||||
)
|
||||
|
||||
for invalid_tc in getattr(msg, "invalid_tool_calls", None) or []:
|
||||
if not isinstance(invalid_tc, dict):
|
||||
continue
|
||||
|
||||
function = raw_tc.get("function")
|
||||
name = raw_tc.get("name")
|
||||
if not name and isinstance(function, dict):
|
||||
name = function.get("name")
|
||||
|
||||
args = raw_tc.get("args", {})
|
||||
if not args and isinstance(function, dict):
|
||||
raw_args = function.get("arguments")
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(raw_args)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
parsed_args = {}
|
||||
args = parsed_args if isinstance(parsed_args, dict) else {}
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
"id": raw_tc.get("id"),
|
||||
"name": name or "unknown",
|
||||
"args": args if isinstance(args, dict) else {},
|
||||
"id": invalid_tc.get("id"),
|
||||
"name": invalid_tc.get("name") or "unknown",
|
||||
"args": {},
|
||||
"invalid": True,
|
||||
"error": invalid_tc.get("error"),
|
||||
}
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _synthetic_tool_message_content(tool_call: dict) -> str:
|
||||
if tool_call.get("invalid"):
|
||||
error = tool_call.get("error")
|
||||
if isinstance(error, str) and error:
|
||||
return f"[Tool call could not be executed because its arguments were invalid: {error}]"
|
||||
return "[Tool call could not be executed because its arguments were invalid.]"
|
||||
return "[Tool call was interrupted and did not return a result.]"
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
|
||||
@@ -114,7 +145,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
content="[Tool call was interrupted and did not return a result.]",
|
||||
content=self._synthetic_tool_message_content(tc),
|
||||
tool_call_id=tc_id,
|
||||
name=tc.get("name", "unknown"),
|
||||
status="error",
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters."""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
@@ -13,46 +8,10 @@ from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.mcp.client import build_servers_config
|
||||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global thread pool for sync tool invocation in async environments
|
||||
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
|
||||
|
||||
# Register shutdown hook for the global executor
|
||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||
|
||||
|
||||
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
||||
|
||||
Args:
|
||||
coro: The tool's asynchronous coroutine.
|
||||
tool_name: Name of the tool (for logging).
|
||||
|
||||
Returns:
|
||||
A synchronous function that correctly handles nested event loops.
|
||||
"""
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
try:
|
||||
if loop is not None and loop.is_running():
|
||||
# Use global executor to avoid nested loop issues and improve performance
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||
return future.result()
|
||||
else:
|
||||
return asyncio.run(coro(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
@@ -126,7 +85,7 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
for tool in tools:
|
||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
@@ -23,6 +23,18 @@ class RunRepository(RunStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_name(model_name: str | None) -> str | None:
|
||||
"""Normalize model_name for storage: strip whitespace, truncate to 128 chars."""
|
||||
if model_name is None:
|
||||
return None
|
||||
if not isinstance(model_name, str):
|
||||
model_name = str(model_name)
|
||||
normalized = model_name.strip()
|
||||
if len(normalized) > 128:
|
||||
normalized = normalized[:128]
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _safe_json(obj: Any) -> Any:
|
||||
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
|
||||
@@ -70,6 +82,7 @@ class RunRepository(RunStore):
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
model_name: str | None = None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -85,6 +98,7 @@ class RunRepository(RunStore):
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=resolved_user_id,
|
||||
model_name=self._normalize_model_name(model_name),
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
|
||||
@@ -20,12 +20,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -71,6 +72,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
|
||||
self._counted_llm_run_ids: set[str] = set()
|
||||
self._counted_external_source_ids: set[str] = set()
|
||||
self._counted_message_llm_run_ids: set[str] = set()
|
||||
|
||||
# Convenience fields
|
||||
self._last_ai_msg: str | None = None
|
||||
@@ -86,6 +88,50 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
@staticmethod
|
||||
def _message_text(message: BaseMessage) -> str:
|
||||
"""Extract displayable text from a message's mixed content shape."""
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, Mapping):
|
||||
text = block.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
else:
|
||||
nested = block.get("content")
|
||||
if isinstance(nested, str):
|
||||
parts.append(nested)
|
||||
return "".join(parts)
|
||||
if isinstance(content, Mapping):
|
||||
for key in ("text", "content"):
|
||||
value = content.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
text = getattr(message, "text", None)
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return ""
|
||||
|
||||
def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None:
|
||||
"""Update run-level convenience fields for persisted run rows."""
|
||||
self._msg_count += 1
|
||||
|
||||
# ``last_ai_message`` should represent the lead agent's user-facing
|
||||
# answer. Middleware/subagent model calls and empty tool-call-only
|
||||
# AI messages must not overwrite the last useful assistant text.
|
||||
is_ai_message = isinstance(message, AIMessage) or getattr(message, "type", None) == "ai"
|
||||
if is_ai_message and (caller is None or caller == "lead_agent"):
|
||||
text = self._message_text(message).strip()
|
||||
if text:
|
||||
self._last_ai_msg = text[:2000]
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
@@ -164,6 +210,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
self._record_message_summary(m, caller=caller)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
@@ -222,6 +269,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
if rid not in self._counted_message_llm_run_ids:
|
||||
self._record_message_summary(message, caller=caller)
|
||||
|
||||
# Token accumulation (dedup by langchain run_id to avoid double-counting
|
||||
# when the callback fires more than once for the same response)
|
||||
@@ -245,6 +294,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
if messages:
|
||||
self._counted_message_llm_run_ids.add(str(run_id))
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
@@ -260,12 +312,14 @@ class RunJournal(BaseCallbackHandler):
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
self._record_message_summary(msg)
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
self._record_message_summary(message)
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
|
||||
@@ -36,6 +36,7 @@ class RunRecord:
|
||||
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
|
||||
abort_action: str = "interrupt"
|
||||
error: str | None = None
|
||||
model_name: str | None = None
|
||||
|
||||
|
||||
class RunManager:
|
||||
@@ -65,6 +66,7 @@ class RunManager:
|
||||
metadata=record.metadata or {},
|
||||
kwargs=record.kwargs or {},
|
||||
created_at=record.created_at,
|
||||
model_name=record.model_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||
@@ -137,6 +139,18 @@ class RunManager:
|
||||
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
|
||||
logger.info("Run %s -> %s", run_id, status.value)
|
||||
|
||||
async def update_model_name(self, run_id: str, model_name: str | None) -> None:
|
||||
"""Update the model name for a run."""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is None:
|
||||
logger.warning("update_model_name called for unknown run %s", run_id)
|
||||
return
|
||||
record.model_name = model_name
|
||||
record.updated_at = _now_iso()
|
||||
await self._persist_to_store(record)
|
||||
logger.info("Run %s model_name=%s", run_id, model_name)
|
||||
|
||||
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
|
||||
"""Request cancellation of a run.
|
||||
|
||||
@@ -171,6 +185,7 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
model_name: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
@@ -221,6 +236,7 @@ class RunManager:
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
model_name=model_name,
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ class RunStore(abc.ABC):
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
status: str = "pending",
|
||||
multitask_strategy: str = "reject",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
|
||||
@@ -22,6 +22,7 @@ class MemoryRunStore(RunStore):
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
user_id=None,
|
||||
model_name=None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -35,6 +36,7 @@ class MemoryRunStore(RunStore):
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"user_id": user_id,
|
||||
"model_name": model_name,
|
||||
"status": status,
|
||||
"multitask_strategy": multitask_strategy,
|
||||
"metadata": metadata or {},
|
||||
|
||||
@@ -230,6 +230,17 @@ async def run_agent(
|
||||
else:
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
# Capture the effective (resolved) model name from the agent's metadata.
|
||||
# _resolve_model_name in agent.py may return the default model if the
|
||||
# requested name is not in the allowlist — this update ensures the
|
||||
# persisted model_name reflects the actual model used.
|
||||
if record.model_name is not None:
|
||||
resolved = getattr(agent, "metadata", {}) or {}
|
||||
if isinstance(resolved, dict):
|
||||
effective = resolved.get("model_name")
|
||||
if effective and effective != record.model_name:
|
||||
await run_manager.update_model_name(record.run_id, effective)
|
||||
|
||||
# 4. Attach checkpointer and store
|
||||
if checkpointer is not None:
|
||||
agent.checkpointer = checkpointer
|
||||
|
||||
@@ -26,7 +26,7 @@ class SubagentConfig:
|
||||
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
system_prompt: str | None = None
|
||||
tools: list[str] | None = None
|
||||
disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"])
|
||||
skills: list[str] | None = None
|
||||
|
||||
@@ -286,11 +286,13 @@ class SubagentExecutor:
|
||||
# Reuse shared middleware composition with lead agent.
|
||||
middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True)
|
||||
|
||||
# system_prompt is included in initial state messages (see _build_initial_state)
|
||||
# to avoid multiple SystemMessages which some LLM APIs don't support.
|
||||
return create_agent(
|
||||
model=model,
|
||||
tools=tools if tools is not None else self.tools,
|
||||
middleware=middlewares,
|
||||
system_prompt=self.config.system_prompt,
|
||||
system_prompt=None,
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
@@ -365,14 +367,25 @@ class SubagentExecutor:
|
||||
Returns:
|
||||
Initial state dictionary and tools filtered by loaded skill metadata.
|
||||
"""
|
||||
|
||||
# Load skills as conversation items (Codex pattern)
|
||||
skills = await self._load_skills()
|
||||
filtered_tools = self._apply_skill_allowed_tools(skills)
|
||||
skill_messages = await self._load_skill_messages(skills)
|
||||
|
||||
# Combine system_prompt and skills into a single SystemMessage.
|
||||
# Some LLM APIs reject multiple SystemMessages with
|
||||
# "System message must be at the beginning."
|
||||
system_parts: list[str] = []
|
||||
if self.config.system_prompt:
|
||||
system_parts.append(self.config.system_prompt)
|
||||
for skill_msg in skill_messages:
|
||||
system_parts.append(skill_msg.content)
|
||||
|
||||
messages: list[Any] = []
|
||||
# Skill content injected as developer/system messages before the task
|
||||
messages.extend(skill_messages)
|
||||
if system_parts:
|
||||
messages.append(SystemMessage(content="\n\n".join(system_parts)))
|
||||
|
||||
# Then the actual task
|
||||
messages.append(HumanMessage(content=task))
|
||||
|
||||
|
||||
@@ -10,11 +10,11 @@ from weakref import WeakValueDictionary
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.mcp.tools import _make_sync_tool_wrapper
|
||||
from deerflow.skills.security_scanner import scan_skill_content
|
||||
from deerflow.skills.storage import get_or_new_skill_storage
|
||||
from deerflow.skills.storage.skill_storage import SkillStorage
|
||||
from deerflow.skills.types import SKILL_MD_FILE
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
from deerflow.tools.types import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -235,4 +235,4 @@ async def skill_manage_tool(
|
||||
)
|
||||
|
||||
|
||||
skill_manage_tool.func = _make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
|
||||
skill_manage_tool.func = make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Utilities for invoking async tools from synchronous agent paths."""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared thread pool for sync tool invocation in async environments.
|
||||
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="tool-sync")
|
||||
|
||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||
|
||||
|
||||
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine."""
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
try:
|
||||
if loop is not None and loop.is_running():
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||
return future.result()
|
||||
return asyncio.run(coro(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
|
||||
raise
|
||||
|
||||
return sync_wrapper
|
||||
@@ -8,6 +8,7 @@ from deerflow.reflection import resolve_variable
|
||||
from deerflow.sandbox.security import is_host_bash_allowed
|
||||
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
||||
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,6 +34,13 @@ def _is_host_bash_tool(tool: object) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_sync_invocable_tool(tool: BaseTool) -> BaseTool:
|
||||
"""Attach a sync wrapper to async-only tools used by sync agent callers."""
|
||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
return tool
|
||||
|
||||
|
||||
def get_available_tools(
|
||||
groups: list[str] | None = None,
|
||||
include_mcp: bool = True,
|
||||
@@ -77,7 +85,7 @@ def get_available_tools(
|
||||
cfg.use,
|
||||
)
|
||||
|
||||
loaded_tools = [t for _, t in loaded_tools_raw]
|
||||
loaded_tools = [_ensure_sync_invocable_tool(t) for _, t in loaded_tools_raw]
|
||||
|
||||
# Conditionally add tools based on config
|
||||
builtin_tools = BUILTIN_TOOLS.copy()
|
||||
|
||||
Reference in New Issue
Block a user