mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
fix(harness): wrap all async-only tools for sync clients (#2935)
This commit is contained in:
@@ -3,9 +3,13 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,10 +19,49 @@ _SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thre
|
||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||
|
||||
|
||||
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine."""
|
||||
def _get_runnable_config_param(func: Callable[..., Any]) -> str | None:
|
||||
"""Return the coroutine parameter that expects LangChain RunnableConfig."""
|
||||
if isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
type_hints = get_type_hints(func)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for name, type_ in type_hints.items():
|
||||
if type_ is RunnableConfig:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
||||
|
||||
Args:
|
||||
coro: Async callable backing a LangChain tool.
|
||||
tool_name: Tool name used in error logs.
|
||||
|
||||
Returns:
|
||||
A sync callable suitable for ``BaseTool.func``.
|
||||
|
||||
Notes:
|
||||
If ``coro`` declares a ``RunnableConfig`` parameter, this wrapper
|
||||
exposes ``config: RunnableConfig`` so LangChain can inject runtime
|
||||
config and then forwards it to the coroutine's detected config
|
||||
parameter. This covers DeerFlow's current config-sensitive tools, such
|
||||
as ``invoke_acp_agent``.
|
||||
|
||||
This wrapper intentionally does not synthesize a dynamic function
|
||||
signature. A future async tool with a normal user-facing argument named
|
||||
``config`` and a separate ``RunnableConfig`` parameter named something
|
||||
else, such as ``run_config``, may collide with LangChain's injected
|
||||
``config`` argument. Rename that user-facing field or extend this
|
||||
helper before using that signature.
|
||||
"""
|
||||
config_param = _get_runnable_config_param(coro)
|
||||
|
||||
def run_coroutine(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
@@ -26,11 +69,24 @@ def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable
|
||||
|
||||
try:
|
||||
if loop is not None and loop.is_running():
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||
context = contextvars.copy_context()
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(context.run, lambda: asyncio.run(coro(*args, **kwargs)))
|
||||
return future.result()
|
||||
return asyncio.run(coro(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
|
||||
raise
|
||||
|
||||
if config_param:
|
||||
|
||||
def sync_wrapper(*args: Any, config: RunnableConfig = None, **kwargs: Any) -> Any:
|
||||
if config is not None or config_param not in kwargs:
|
||||
kwargs[config_param] = config
|
||||
return run_coroutine(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
return run_coroutine(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
@@ -205,7 +205,7 @@ def get_available_tools(
|
||||
# Deduplicate by tool name — config-loaded tools take priority, followed by
|
||||
# built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to
|
||||
# receive ambiguous or concatenated function schemas (issue #1803).
|
||||
all_tools = loaded_tools + builtin_tools + mcp_tools + acp_tools
|
||||
all_tools = [_ensure_sync_invocable_tool(t) for t in loaded_tools + builtin_tools + mcp_tools + acp_tools]
|
||||
seen_names: set[str] = set()
|
||||
unique_tools: list[BaseTool] = []
|
||||
for t in all_tools:
|
||||
|
||||
@@ -699,6 +699,92 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
|
||||
load_acp_config_from_dict({})
|
||||
|
||||
|
||||
def test_get_available_tools_sync_invoke_acp_agent_preserves_thread_workspace(monkeypatch, tmp_path):
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
|
||||
)
|
||||
monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyClient:
|
||||
@property
|
||||
def collected_text(self) -> str:
|
||||
return "ok"
|
||||
|
||||
async def session_update(self, session_id, update, **kwargs):
|
||||
pass
|
||||
|
||||
async def request_permission(self, options, session_id, tool_call, **kwargs):
|
||||
raise AssertionError("should not be called")
|
||||
|
||||
class DummyConn:
|
||||
async def initialize(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def new_session(self, **kwargs):
|
||||
return SimpleNamespace(session_id="s1")
|
||||
|
||||
async def prompt(self, **kwargs):
|
||||
pass
|
||||
|
||||
class DummyProcessContext:
|
||||
def __init__(self, client, cmd, *args, env=None, cwd):
|
||||
captured["cwd"] = cwd
|
||||
|
||||
async def __aenter__(self):
|
||||
return DummyConn(), object()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"acp",
|
||||
SimpleNamespace(
|
||||
PROTOCOL_VERSION="2026-03-24",
|
||||
Client=DummyClient,
|
||||
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
|
||||
text_block=lambda text: {"type": "text", "text": text},
|
||||
),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"acp.schema",
|
||||
SimpleNamespace(
|
||||
ClientCapabilities=lambda: {},
|
||||
Implementation=lambda **kwargs: kwargs,
|
||||
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
|
||||
),
|
||||
)
|
||||
|
||||
explicit_config = SimpleNamespace(
|
||||
tools=[],
|
||||
models=[],
|
||||
tool_search=SimpleNamespace(enabled=False),
|
||||
skill_evolution=SimpleNamespace(enabled=False),
|
||||
sandbox=SimpleNamespace(),
|
||||
get_model_config=lambda name: None,
|
||||
acp_agents={"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")},
|
||||
)
|
||||
tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config)
|
||||
tool = next(tool for tool in tools if tool.name == "invoke_acp_agent")
|
||||
|
||||
thread_id = "thread-sync-123"
|
||||
tool.invoke(
|
||||
{"agent": "codex", "prompt": "Do something"},
|
||||
config={"configurable": {"thread_id": thread_id}},
|
||||
)
|
||||
|
||||
assert captured["cwd"] == str(tmp_path / "threads" / thread_id / "acp-workspace")
|
||||
|
||||
|
||||
def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch):
|
||||
explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}
|
||||
explicit_config = SimpleNamespace(
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -69,6 +71,58 @@ def test_mcp_tool_sync_wrapper_in_running_loop():
|
||||
assert result == "async_result: 100"
|
||||
|
||||
|
||||
def test_sync_wrapper_preserves_contextvars_in_running_loop():
|
||||
"""The executor branch preserves LangGraph-style contextvars."""
|
||||
current_value: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_value", default=None)
|
||||
|
||||
async def mock_coro() -> str | None:
|
||||
return current_value.get()
|
||||
|
||||
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
|
||||
|
||||
async def run_in_loop() -> str | None:
|
||||
token = current_value.set("from-parent-context")
|
||||
try:
|
||||
return sync_func()
|
||||
finally:
|
||||
current_value.reset(token)
|
||||
|
||||
assert asyncio.run(run_in_loop()) == "from-parent-context"
|
||||
|
||||
|
||||
def test_sync_wrapper_preserves_runnable_config_injection():
|
||||
"""LangChain can still inject RunnableConfig after an async tool is wrapped."""
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def mock_coro(x: int, config: RunnableConfig = None):
|
||||
captured["thread_id"] = ((config or {}).get("configurable") or {}).get("thread_id")
|
||||
return f"result: {x}"
|
||||
|
||||
mock_tool = StructuredTool(
|
||||
name="test_tool",
|
||||
description="test description",
|
||||
args_schema=MockArgs,
|
||||
func=make_sync_tool_wrapper(mock_coro, "test_tool"),
|
||||
coroutine=mock_coro,
|
||||
)
|
||||
|
||||
result = mock_tool.invoke({"x": 42}, config={"configurable": {"thread_id": "thread-123"}})
|
||||
|
||||
assert result == "result: 42"
|
||||
assert captured["thread_id"] == "thread-123"
|
||||
|
||||
|
||||
def test_sync_wrapper_preserves_regular_config_argument():
|
||||
"""Only RunnableConfig-annotated coroutine params get special config injection."""
|
||||
|
||||
async def mock_coro(config: str):
|
||||
return config
|
||||
|
||||
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
|
||||
|
||||
assert sync_func(config="user-config") == "user-config"
|
||||
|
||||
|
||||
def test_mcp_tool_sync_wrapper_exception_logging():
|
||||
"""Test the shared sync wrapper's error logging."""
|
||||
|
||||
|
||||
@@ -95,6 +95,64 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
||||
assert async_tool.invoke({"x": 42}) == "result: 42"
|
||||
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
def test_subagent_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
||||
"""Async-only tools added through the subagent path can be invoked by sync clients."""
|
||||
|
||||
async def async_tool_impl(x: int) -> str:
|
||||
return f"subagent: {x}"
|
||||
|
||||
async_tool = StructuredTool(
|
||||
name="async_subagent_tool",
|
||||
description="Async-only subagent test tool.",
|
||||
args_schema=AsyncToolArgs,
|
||||
func=None,
|
||||
coroutine=async_tool_impl,
|
||||
)
|
||||
mock_cfg.return_value = _make_minimal_config([])
|
||||
|
||||
with (
|
||||
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
|
||||
patch("deerflow.tools.tools.SUBAGENT_TOOLS", [async_tool]),
|
||||
):
|
||||
result = get_available_tools(include_mcp=False, subagent_enabled=True, app_config=mock_cfg.return_value)
|
||||
|
||||
assert async_tool in result
|
||||
assert async_tool.func is not None
|
||||
assert async_tool.invoke({"x": 7}) == "subagent: 7"
|
||||
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
def test_acp_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
||||
"""Async-only ACP tools can be invoked by sync clients."""
|
||||
|
||||
async def async_tool_impl(x: int) -> str:
|
||||
return f"acp: {x}"
|
||||
|
||||
async_tool = StructuredTool(
|
||||
name="invoke_acp_agent",
|
||||
description="Async-only ACP test tool.",
|
||||
args_schema=AsyncToolArgs,
|
||||
func=None,
|
||||
coroutine=async_tool_impl,
|
||||
)
|
||||
config = _make_minimal_config([])
|
||||
config.acp_agents = {"codex": object()}
|
||||
mock_cfg.return_value = config
|
||||
|
||||
with (
|
||||
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
|
||||
patch("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", return_value=async_tool),
|
||||
):
|
||||
result = get_available_tools(include_mcp=False, app_config=config)
|
||||
|
||||
assert async_tool in result
|
||||
assert async_tool.func is not None
|
||||
assert async_tool.invoke({"x": 9}) == "acp: 9"
|
||||
|
||||
|
||||
@patch("deerflow.tools.tools.get_app_config")
|
||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||
def test_no_duplicates_returned(mock_bash, mock_cfg):
|
||||
|
||||
Reference in New Issue
Block a user