fix(harness): wrap async-only config tools for sync client execution (#2878)
* fix(harness): wrap async-only config tools for sync clients * refactor(tools): share async tool sync wrapper
This commit is contained in:
@@ -1,11 +1,6 @@
|
|||||||
"""Load MCP tools using langchain-mcp-adapters."""
|
"""Load MCP tools using langchain-mcp-adapters."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import atexit
|
|
||||||
import concurrent.futures
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
@@ -13,46 +8,10 @@ from deerflow.config.extensions_config import ExtensionsConfig
|
|||||||
from deerflow.mcp.client import build_servers_config
|
from deerflow.mcp.client import build_servers_config
|
||||||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||||
from deerflow.reflection import resolve_variable
|
from deerflow.reflection import resolve_variable
|
||||||
|
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Global thread pool for sync tool invocation in async environments
|
|
||||||
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
|
|
||||||
|
|
||||||
# Register shutdown hook for the global executor
|
|
||||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
|
||||||
|
|
||||||
|
|
||||||
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
|
||||||
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
coro: The tool's asynchronous coroutine.
|
|
||||||
tool_name: Name of the tool (for logging).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A synchronous function that correctly handles nested event loops.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
loop = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
if loop is not None and loop.is_running():
|
|
||||||
# Use global executor to avoid nested loop issues and improve performance
|
|
||||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
|
||||||
return future.result()
|
|
||||||
else:
|
|
||||||
return asyncio.run(coro(*args, **kwargs))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return sync_wrapper
|
|
||||||
|
|
||||||
|
|
||||||
async def get_mcp_tools() -> list[BaseTool]:
|
async def get_mcp_tools() -> list[BaseTool]:
|
||||||
"""Get all tools from enabled MCP servers.
|
"""Get all tools from enabled MCP servers.
|
||||||
@@ -126,7 +85,7 @@ async def get_mcp_tools() -> list[BaseTool]:
|
|||||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||||
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
|
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ from weakref import WeakValueDictionary
|
|||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
|
|
||||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||||
from deerflow.mcp.tools import _make_sync_tool_wrapper
|
|
||||||
from deerflow.skills.security_scanner import scan_skill_content
|
from deerflow.skills.security_scanner import scan_skill_content
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.storage import get_or_new_skill_storage
|
||||||
from deerflow.skills.storage.skill_storage import SkillStorage
|
from deerflow.skills.storage.skill_storage import SkillStorage
|
||||||
from deerflow.skills.types import SKILL_MD_FILE
|
from deerflow.skills.types import SKILL_MD_FILE
|
||||||
|
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||||
from deerflow.tools.types import Runtime
|
from deerflow.tools.types import Runtime
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -235,4 +235,4 @@ async def skill_manage_tool(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
skill_manage_tool.func = _make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
|
skill_manage_tool.func = make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
"""Utilities for invoking async tools from synchronous agent paths."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
|
import concurrent.futures
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Shared thread pool for sync tool invocation in async environments.
|
||||||
|
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="tool-sync")
|
||||||
|
|
||||||
|
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||||
|
|
||||||
|
|
||||||
|
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||||
|
"""Build a synchronous wrapper for an asynchronous tool coroutine."""
|
||||||
|
|
||||||
|
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if loop is not None and loop.is_running():
|
||||||
|
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||||
|
return future.result()
|
||||||
|
return asyncio.run(coro(*args, **kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return sync_wrapper
|
||||||
@@ -8,6 +8,7 @@ from deerflow.reflection import resolve_variable
|
|||||||
from deerflow.sandbox.security import is_host_bash_allowed
|
from deerflow.sandbox.security import is_host_bash_allowed
|
||||||
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
||||||
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||||
|
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,6 +34,13 @@ def _is_host_bash_tool(tool: object) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_sync_invocable_tool(tool: BaseTool) -> BaseTool:
|
||||||
|
"""Attach a sync wrapper to async-only tools used by sync agent callers."""
|
||||||
|
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||||
|
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||||
|
return tool
|
||||||
|
|
||||||
|
|
||||||
def get_available_tools(
|
def get_available_tools(
|
||||||
groups: list[str] | None = None,
|
groups: list[str] | None = None,
|
||||||
include_mcp: bool = True,
|
include_mcp: bool = True,
|
||||||
@@ -77,7 +85,7 @@ def get_available_tools(
|
|||||||
cfg.use,
|
cfg.use,
|
||||||
)
|
)
|
||||||
|
|
||||||
loaded_tools = [t for _, t in loaded_tools_raw]
|
loaded_tools = [_ensure_sync_invocable_tool(t) for _, t in loaded_tools_raw]
|
||||||
|
|
||||||
# Conditionally add tools based on config
|
# Conditionally add tools based on config
|
||||||
builtin_tools = BUILTIN_TOOLS.copy()
|
builtin_tools = BUILTIN_TOOLS.copy()
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ import pytest
|
|||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from deerflow.mcp.tools import _make_sync_tool_wrapper, get_mcp_tools
|
from deerflow.mcp.tools import get_mcp_tools
|
||||||
|
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||||
|
|
||||||
|
|
||||||
class MockArgs(BaseModel):
|
class MockArgs(BaseModel):
|
||||||
@@ -51,14 +52,13 @@ def test_mcp_tool_sync_wrapper_generation():
|
|||||||
|
|
||||||
|
|
||||||
def test_mcp_tool_sync_wrapper_in_running_loop():
|
def test_mcp_tool_sync_wrapper_in_running_loop():
|
||||||
"""Test the actual helper function from production code (Fix for Comment 1 & 3)."""
|
"""Test the shared sync wrapper from production code."""
|
||||||
|
|
||||||
async def mock_coro(x: int):
|
async def mock_coro(x: int):
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
return f"async_result: {x}"
|
return f"async_result: {x}"
|
||||||
|
|
||||||
# Test the real helper function exported from deerflow.mcp.tools
|
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
|
||||||
sync_func = _make_sync_tool_wrapper(mock_coro, "test_tool")
|
|
||||||
|
|
||||||
async def run_in_loop():
|
async def run_in_loop():
|
||||||
# This call should succeed due to ThreadPoolExecutor in the real helper
|
# This call should succeed due to ThreadPoolExecutor in the real helper
|
||||||
@@ -70,16 +70,16 @@ def test_mcp_tool_sync_wrapper_in_running_loop():
|
|||||||
|
|
||||||
|
|
||||||
def test_mcp_tool_sync_wrapper_exception_logging():
|
def test_mcp_tool_sync_wrapper_exception_logging():
|
||||||
"""Test the actual helper's error logging (Fix for Comment 3)."""
|
"""Test the shared sync wrapper's error logging."""
|
||||||
|
|
||||||
async def error_coro():
|
async def error_coro():
|
||||||
raise ValueError("Tool failure")
|
raise ValueError("Tool failure")
|
||||||
|
|
||||||
sync_func = _make_sync_tool_wrapper(error_coro, "error_tool")
|
sync_func = make_sync_tool_wrapper(error_coro, "error_tool")
|
||||||
|
|
||||||
with patch("deerflow.mcp.tools.logger.error") as mock_log_error:
|
with patch("deerflow.tools.sync.logger.error") as mock_log_error:
|
||||||
with pytest.raises(ValueError, match="Tool failure"):
|
with pytest.raises(ValueError, match="Tool failure"):
|
||||||
sync_func()
|
sync_func()
|
||||||
mock_log_error.assert_called_once()
|
mock_log_error.assert_called_once()
|
||||||
# Verify the tool name is in the log message
|
# Verify the tool name is in the log message
|
||||||
assert "error_tool" in mock_log_error.call_args[0][0]
|
assert mock_log_error.call_args[0][1] == "error_tool"
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from langchain_core.tools import BaseTool, tool
|
from langchain_core.tools import BaseTool, StructuredTool, tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from deerflow.tools.tools import get_available_tools
|
from deerflow.tools.tools import get_available_tools
|
||||||
|
|
||||||
@@ -19,6 +20,10 @@ from deerflow.tools.tools import get_available_tools
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncToolArgs(BaseModel):
|
||||||
|
x: int = Field(..., description="test input")
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def _tool_alpha(x: str) -> str:
|
def _tool_alpha(x: str) -> str:
|
||||||
"""Alpha tool."""
|
"""Alpha tool."""
|
||||||
@@ -52,10 +57,45 @@ def _make_minimal_config(tools):
|
|||||||
config.tools = tools
|
config.tools = tools
|
||||||
config.models = []
|
config.models = []
|
||||||
config.tool_search.enabled = False
|
config.tool_search.enabled = False
|
||||||
|
config.skill_evolution.enabled = False
|
||||||
config.sandbox = MagicMock()
|
config.sandbox = MagicMock()
|
||||||
|
config.acp_agents = {}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
|
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||||
|
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
|
||||||
|
"""Config-loaded async-only tools can still be invoked by sync clients."""
|
||||||
|
|
||||||
|
async def async_tool_impl(x: int) -> str:
|
||||||
|
return f"result: {x}"
|
||||||
|
|
||||||
|
async_tool = StructuredTool(
|
||||||
|
name="async_tool",
|
||||||
|
description="Async-only test tool.",
|
||||||
|
args_schema=AsyncToolArgs,
|
||||||
|
func=None,
|
||||||
|
coroutine=async_tool_impl,
|
||||||
|
)
|
||||||
|
tool_cfg = MagicMock()
|
||||||
|
tool_cfg.name = "async_tool"
|
||||||
|
tool_cfg.group = "test"
|
||||||
|
tool_cfg.use = "tests.fake:async_tool"
|
||||||
|
mock_cfg.return_value = _make_minimal_config([tool_cfg])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.tools.tools.resolve_variable", return_value=async_tool),
|
||||||
|
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
|
||||||
|
):
|
||||||
|
result = get_available_tools(include_mcp=False, app_config=mock_cfg.return_value)
|
||||||
|
|
||||||
|
assert async_tool in result
|
||||||
|
assert async_tool.func is not None
|
||||||
|
assert async_tool.invoke({"x": 42}) == "result: 42"
|
||||||
|
|
||||||
|
|
||||||
@patch("deerflow.tools.tools.get_app_config")
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
@patch("deerflow.tools.tools.reset_deferred_registry")
|
@patch("deerflow.tools.tools.reset_deferred_registry")
|
||||||
|
|||||||
Reference in New Issue
Block a user