fix(harness): wrap all async-only tools for sync clients (#2935)

This commit is contained in:
AochenShen99
2026-05-19 22:11:46 +08:00
committed by GitHub
parent c810e9f809
commit 3599b570a9
5 changed files with 260 additions and 6 deletions
@@ -3,9 +3,13 @@
import asyncio import asyncio
import atexit import atexit
import concurrent.futures import concurrent.futures
import contextvars
import functools
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any, get_type_hints
from langchain_core.runnables import RunnableConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,10 +19,49 @@ _SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thre
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False)) atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: def _get_runnable_config_param(func: Callable[..., Any]) -> str | None:
"""Build a synchronous wrapper for an asynchronous tool coroutine.""" """Return the coroutine parameter that expects LangChain RunnableConfig."""
if isinstance(func, functools.partial):
func = func.func
def sync_wrapper(*args: Any, **kwargs: Any) -> Any: try:
type_hints = get_type_hints(func)
except Exception:
return None
for name, type_ in type_hints.items():
if type_ is RunnableConfig:
return name
return None
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
"""Build a synchronous wrapper for an asynchronous tool coroutine.
Args:
coro: Async callable backing a LangChain tool.
tool_name: Tool name used in error logs.
Returns:
A sync callable suitable for ``BaseTool.func``.
Notes:
If ``coro`` declares a ``RunnableConfig`` parameter, this wrapper
exposes ``config: RunnableConfig`` so LangChain can inject runtime
config and then forwards it to the coroutine's detected config
parameter. This covers DeerFlow's current config-sensitive tools, such
as ``invoke_acp_agent``.
This wrapper intentionally does not synthesize a dynamic function
signature. A future async tool with a normal user-facing argument named
``config`` and a separate ``RunnableConfig`` parameter named something
else, such as ``run_config``, may collide with LangChain's injected
``config`` argument. Rename that user-facing field or extend this
helper before using that signature.
"""
config_param = _get_runnable_config_param(coro)
def run_coroutine(*args: Any, **kwargs: Any) -> Any:
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
@@ -26,11 +69,24 @@ def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable
try: try:
if loop is not None and loop.is_running(): if loop is not None and loop.is_running():
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs)) context = contextvars.copy_context()
future = _SYNC_TOOL_EXECUTOR.submit(context.run, lambda: asyncio.run(coro(*args, **kwargs)))
return future.result() return future.result()
return asyncio.run(coro(*args, **kwargs)) return asyncio.run(coro(*args, **kwargs))
except Exception as e: except Exception as e:
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True) logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
raise raise
if config_param:
def sync_wrapper(*args: Any, config: RunnableConfig = None, **kwargs: Any) -> Any:
if config is not None or config_param not in kwargs:
kwargs[config_param] = config
return run_coroutine(*args, **kwargs)
return sync_wrapper
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
return run_coroutine(*args, **kwargs)
return sync_wrapper return sync_wrapper
@@ -205,7 +205,7 @@ def get_available_tools(
# Deduplicate by tool name — config-loaded tools take priority, followed by # Deduplicate by tool name — config-loaded tools take priority, followed by
# built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to # built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to
# receive ambiguous or concatenated function schemas (issue #1803). # receive ambiguous or concatenated function schemas (issue #1803).
all_tools = loaded_tools + builtin_tools + mcp_tools + acp_tools all_tools = [_ensure_sync_invocable_tool(t) for t in loaded_tools + builtin_tools + mcp_tools + acp_tools]
seen_names: set[str] = set() seen_names: set[str] = set()
unique_tools: list[BaseTool] = [] unique_tools: list[BaseTool] = []
for t in all_tools: for t in all_tools:
@@ -699,6 +699,92 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
load_acp_config_from_dict({}) load_acp_config_from_dict({})
def test_get_available_tools_sync_invoke_acp_agent_preserves_thread_workspace(monkeypatch, tmp_path):
from deerflow.config import paths as paths_module
from deerflow.runtime import user_context as uc_module
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
)
monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True)
captured: dict[str, object] = {}
class DummyClient:
@property
def collected_text(self) -> str:
return "ok"
async def session_update(self, session_id, update, **kwargs):
pass
async def request_permission(self, options, session_id, tool_call, **kwargs):
raise AssertionError("should not be called")
class DummyConn:
async def initialize(self, **kwargs):
pass
async def new_session(self, **kwargs):
return SimpleNamespace(session_id="s1")
async def prompt(self, **kwargs):
pass
class DummyProcessContext:
def __init__(self, client, cmd, *args, env=None, cwd):
captured["cwd"] = cwd
async def __aenter__(self):
return DummyConn(), object()
async def __aexit__(self, exc_type, exc, tb):
return False
monkeypatch.setitem(
sys.modules,
"acp",
SimpleNamespace(
PROTOCOL_VERSION="2026-03-24",
Client=DummyClient,
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
text_block=lambda text: {"type": "text", "text": text},
),
)
monkeypatch.setitem(
sys.modules,
"acp.schema",
SimpleNamespace(
ClientCapabilities=lambda: {},
Implementation=lambda **kwargs: kwargs,
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
),
)
explicit_config = SimpleNamespace(
tools=[],
models=[],
tool_search=SimpleNamespace(enabled=False),
skill_evolution=SimpleNamespace(enabled=False),
sandbox=SimpleNamespace(),
get_model_config=lambda name: None,
acp_agents={"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")},
)
tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config)
tool = next(tool for tool in tools if tool.name == "invoke_acp_agent")
thread_id = "thread-sync-123"
tool.invoke(
{"agent": "codex", "prompt": "Do something"},
config={"configurable": {"thread_id": thread_id}},
)
assert captured["cwd"] == str(tmp_path / "threads" / thread_id / "acp-workspace")
def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch): def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch):
explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")} explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}
explicit_config = SimpleNamespace( explicit_config = SimpleNamespace(
+54
View File
@@ -1,7 +1,9 @@
import asyncio import asyncio
import contextvars
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import StructuredTool from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -69,6 +71,58 @@ def test_mcp_tool_sync_wrapper_in_running_loop():
assert result == "async_result: 100" assert result == "async_result: 100"
def test_sync_wrapper_preserves_contextvars_in_running_loop():
"""The executor branch preserves LangGraph-style contextvars."""
current_value: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_value", default=None)
async def mock_coro() -> str | None:
return current_value.get()
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
async def run_in_loop() -> str | None:
token = current_value.set("from-parent-context")
try:
return sync_func()
finally:
current_value.reset(token)
assert asyncio.run(run_in_loop()) == "from-parent-context"
def test_sync_wrapper_preserves_runnable_config_injection():
"""LangChain can still inject RunnableConfig after an async tool is wrapped."""
captured: dict[str, object] = {}
async def mock_coro(x: int, config: RunnableConfig = None):
captured["thread_id"] = ((config or {}).get("configurable") or {}).get("thread_id")
return f"result: {x}"
mock_tool = StructuredTool(
name="test_tool",
description="test description",
args_schema=MockArgs,
func=make_sync_tool_wrapper(mock_coro, "test_tool"),
coroutine=mock_coro,
)
result = mock_tool.invoke({"x": 42}, config={"configurable": {"thread_id": "thread-123"}})
assert result == "result: 42"
assert captured["thread_id"] == "thread-123"
def test_sync_wrapper_preserves_regular_config_argument():
"""Only RunnableConfig-annotated coroutine params get special config injection."""
async def mock_coro(config: str):
return config
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
assert sync_func(config="user-config") == "user-config"
def test_mcp_tool_sync_wrapper_exception_logging(): def test_mcp_tool_sync_wrapper_exception_logging():
"""Test the shared sync wrapper's error logging.""" """Test the shared sync wrapper's error logging."""
+58
View File
@@ -95,6 +95,64 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
assert async_tool.invoke({"x": 42}) == "result: 42" assert async_tool.invoke({"x": 42}) == "result: 42"
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_subagent_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Async-only tools added through the subagent path can be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
return f"subagent: {x}"
async_tool = StructuredTool(
name="async_subagent_tool",
description="Async-only subagent test tool.",
args_schema=AsyncToolArgs,
func=None,
coroutine=async_tool_impl,
)
mock_cfg.return_value = _make_minimal_config([])
with (
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
patch("deerflow.tools.tools.SUBAGENT_TOOLS", [async_tool]),
):
result = get_available_tools(include_mcp=False, subagent_enabled=True, app_config=mock_cfg.return_value)
assert async_tool in result
assert async_tool.func is not None
assert async_tool.invoke({"x": 7}) == "subagent: 7"
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_acp_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Async-only ACP tools can be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
return f"acp: {x}"
async_tool = StructuredTool(
name="invoke_acp_agent",
description="Async-only ACP test tool.",
args_schema=AsyncToolArgs,
func=None,
coroutine=async_tool_impl,
)
config = _make_minimal_config([])
config.acp_agents = {"codex": object()}
mock_cfg.return_value = config
with (
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
patch("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", return_value=async_tool),
):
result = get_available_tools(include_mcp=False, app_config=config)
assert async_tool in result
assert async_tool.func is not None
assert async_tool.invoke({"x": 9}) == "acp: 9"
@patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
def test_no_duplicates_returned(mock_bash, mock_cfg): def test_no_duplicates_returned(mock_bash, mock_cfg):