mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
fix(harness): wrap all async-only tools for sync clients (#2935)
This commit is contained in:
@@ -3,9 +3,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import atexit
|
import atexit
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import contextvars
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any, get_type_hints
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,10 +19,49 @@ _SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thre
|
|||||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||||
|
|
||||||
|
|
||||||
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
def _get_runnable_config_param(func: Callable[..., Any]) -> str | None:
|
||||||
"""Build a synchronous wrapper for an asynchronous tool coroutine."""
|
"""Return the coroutine parameter that expects LangChain RunnableConfig."""
|
||||||
|
if isinstance(func, functools.partial):
|
||||||
|
func = func.func
|
||||||
|
|
||||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
try:
|
||||||
|
type_hints = get_type_hints(func)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for name, type_ in type_hints.items():
|
||||||
|
if type_ is RunnableConfig:
|
||||||
|
return name
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||||
|
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coro: Async callable backing a LangChain tool.
|
||||||
|
tool_name: Tool name used in error logs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sync callable suitable for ``BaseTool.func``.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
If ``coro`` declares a ``RunnableConfig`` parameter, this wrapper
|
||||||
|
exposes ``config: RunnableConfig`` so LangChain can inject runtime
|
||||||
|
config and then forwards it to the coroutine's detected config
|
||||||
|
parameter. This covers DeerFlow's current config-sensitive tools, such
|
||||||
|
as ``invoke_acp_agent``.
|
||||||
|
|
||||||
|
This wrapper intentionally does not synthesize a dynamic function
|
||||||
|
signature. A future async tool with a normal user-facing argument named
|
||||||
|
``config`` and a separate ``RunnableConfig`` parameter named something
|
||||||
|
else, such as ``run_config``, may collide with LangChain's injected
|
||||||
|
``config`` argument. Rename that user-facing field or extend this
|
||||||
|
helper before using that signature.
|
||||||
|
"""
|
||||||
|
config_param = _get_runnable_config_param(coro)
|
||||||
|
|
||||||
|
def run_coroutine(*args: Any, **kwargs: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@@ -26,11 +69,24 @@ def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if loop is not None and loop.is_running():
|
if loop is not None and loop.is_running():
|
||||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
context = contextvars.copy_context()
|
||||||
|
future = _SYNC_TOOL_EXECUTOR.submit(context.run, lambda: asyncio.run(coro(*args, **kwargs)))
|
||||||
return future.result()
|
return future.result()
|
||||||
return asyncio.run(coro(*args, **kwargs))
|
return asyncio.run(coro(*args, **kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
|
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if config_param:
|
||||||
|
|
||||||
|
def sync_wrapper(*args: Any, config: RunnableConfig = None, **kwargs: Any) -> Any:
|
||||||
|
if config is not None or config_param not in kwargs:
|
||||||
|
kwargs[config_param] = config
|
||||||
|
return run_coroutine(*args, **kwargs)
|
||||||
|
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
return run_coroutine(*args, **kwargs)
|
||||||
|
|
||||||
return sync_wrapper
|
return sync_wrapper
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ def get_available_tools(
|
|||||||
# Deduplicate by tool name — config-loaded tools take priority, followed by
|
# Deduplicate by tool name — config-loaded tools take priority, followed by
|
||||||
# built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to
|
# built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to
|
||||||
# receive ambiguous or concatenated function schemas (issue #1803).
|
# receive ambiguous or concatenated function schemas (issue #1803).
|
||||||
all_tools = loaded_tools + builtin_tools + mcp_tools + acp_tools
|
all_tools = [_ensure_sync_invocable_tool(t) for t in loaded_tools + builtin_tools + mcp_tools + acp_tools]
|
||||||
seen_names: set[str] = set()
|
seen_names: set[str] = set()
|
||||||
unique_tools: list[BaseTool] = []
|
unique_tools: list[BaseTool] = []
|
||||||
for t in all_tools:
|
for t in all_tools:
|
||||||
|
|||||||
@@ -699,6 +699,92 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo
|
|||||||
load_acp_config_from_dict({})
|
load_acp_config_from_dict({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_available_tools_sync_invoke_acp_agent_preserves_thread_workspace(monkeypatch, tmp_path):
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.runtime import user_context as uc_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||||
|
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||||
|
classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True)
|
||||||
|
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
class DummyClient:
|
||||||
|
@property
|
||||||
|
def collected_text(self) -> str:
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
async def session_update(self, session_id, update, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def request_permission(self, options, session_id, tool_call, **kwargs):
|
||||||
|
raise AssertionError("should not be called")
|
||||||
|
|
||||||
|
class DummyConn:
|
||||||
|
async def initialize(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def new_session(self, **kwargs):
|
||||||
|
return SimpleNamespace(session_id="s1")
|
||||||
|
|
||||||
|
async def prompt(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DummyProcessContext:
|
||||||
|
def __init__(self, client, cmd, *args, env=None, cwd):
|
||||||
|
captured["cwd"] = cwd
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return DummyConn(), object()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"acp",
|
||||||
|
SimpleNamespace(
|
||||||
|
PROTOCOL_VERSION="2026-03-24",
|
||||||
|
Client=DummyClient,
|
||||||
|
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
|
||||||
|
text_block=lambda text: {"type": "text", "text": text},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"acp.schema",
|
||||||
|
SimpleNamespace(
|
||||||
|
ClientCapabilities=lambda: {},
|
||||||
|
Implementation=lambda **kwargs: kwargs,
|
||||||
|
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
explicit_config = SimpleNamespace(
|
||||||
|
tools=[],
|
||||||
|
models=[],
|
||||||
|
tool_search=SimpleNamespace(enabled=False),
|
||||||
|
skill_evolution=SimpleNamespace(enabled=False),
|
||||||
|
sandbox=SimpleNamespace(),
|
||||||
|
get_model_config=lambda name: None,
|
||||||
|
acp_agents={"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")},
|
||||||
|
)
|
||||||
|
tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config)
|
||||||
|
tool = next(tool for tool in tools if tool.name == "invoke_acp_agent")
|
||||||
|
|
||||||
|
thread_id = "thread-sync-123"
|
||||||
|
tool.invoke(
|
||||||
|
{"agent": "codex", "prompt": "Do something"},
|
||||||
|
config={"configurable": {"thread_id": thread_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured["cwd"] == str(tmp_path / "threads" / thread_id / "acp-workspace")
|
||||||
|
|
||||||
|
|
||||||
def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch):
|
def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch):
|
||||||
explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}
|
explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}
|
||||||
explicit_config = SimpleNamespace(
|
explicit_config = SimpleNamespace(
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import contextvars
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langchain_core.tools import StructuredTool
|
from langchain_core.tools import StructuredTool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -69,6 +71,58 @@ def test_mcp_tool_sync_wrapper_in_running_loop():
|
|||||||
assert result == "async_result: 100"
|
assert result == "async_result: 100"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_wrapper_preserves_contextvars_in_running_loop():
|
||||||
|
"""The executor branch preserves LangGraph-style contextvars."""
|
||||||
|
current_value: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_value", default=None)
|
||||||
|
|
||||||
|
async def mock_coro() -> str | None:
|
||||||
|
return current_value.get()
|
||||||
|
|
||||||
|
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
|
||||||
|
|
||||||
|
async def run_in_loop() -> str | None:
|
||||||
|
token = current_value.set("from-parent-context")
|
||||||
|
try:
|
||||||
|
return sync_func()
|
||||||
|
finally:
|
||||||
|
current_value.reset(token)
|
||||||
|
|
||||||
|
assert asyncio.run(run_in_loop()) == "from-parent-context"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_wrapper_preserves_runnable_config_injection():
|
||||||
|
"""LangChain can still inject RunnableConfig after an async tool is wrapped."""
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
async def mock_coro(x: int, config: RunnableConfig = None):
|
||||||
|
captured["thread_id"] = ((config or {}).get("configurable") or {}).get("thread_id")
|
||||||
|
return f"result: {x}"
|
||||||
|
|
||||||
|
mock_tool = StructuredTool(
|
||||||
|
name="test_tool",
|
||||||
|
description="test description",
|
||||||
|
args_schema=MockArgs,
|
||||||
|
func=make_sync_tool_wrapper(mock_coro, "test_tool"),
|
||||||
|
coroutine=mock_coro,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = mock_tool.invoke({"x": 42}, config={"configurable": {"thread_id": "thread-123"}})
|
||||||
|
|
||||||
|
assert result == "result: 42"
|
||||||
|
assert captured["thread_id"] == "thread-123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_wrapper_preserves_regular_config_argument():
|
||||||
|
"""Only RunnableConfig-annotated coroutine params get special config injection."""
|
||||||
|
|
||||||
|
async def mock_coro(config: str):
|
||||||
|
return config
|
||||||
|
|
||||||
|
sync_func = make_sync_tool_wrapper(mock_coro, "test_tool")
|
||||||
|
|
||||||
|
assert sync_func(config="user-config") == "user-config"
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_tool_sync_wrapper_exception_logging():
|
def test_mcp_tool_sync_wrapper_exception_logging():
|
||||||
"""Test the shared sync wrapper's error logging."""
|
"""Test the shared sync wrapper's error logging."""
|
||||||
|
|
||||||
|
|||||||
@@ -95,6 +95,64 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
|||||||
assert async_tool.invoke({"x": 42}) == "result: 42"
|
assert async_tool.invoke({"x": 42}) == "result: 42"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
|
def test_subagent_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
||||||
|
"""Async-only tools added through the subagent path can be invoked by sync clients."""
|
||||||
|
|
||||||
|
async def async_tool_impl(x: int) -> str:
|
||||||
|
return f"subagent: {x}"
|
||||||
|
|
||||||
|
async_tool = StructuredTool(
|
||||||
|
name="async_subagent_tool",
|
||||||
|
description="Async-only subagent test tool.",
|
||||||
|
args_schema=AsyncToolArgs,
|
||||||
|
func=None,
|
||||||
|
coroutine=async_tool_impl,
|
||||||
|
)
|
||||||
|
mock_cfg.return_value = _make_minimal_config([])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
|
||||||
|
patch("deerflow.tools.tools.SUBAGENT_TOOLS", [async_tool]),
|
||||||
|
):
|
||||||
|
result = get_available_tools(include_mcp=False, subagent_enabled=True, app_config=mock_cfg.return_value)
|
||||||
|
|
||||||
|
assert async_tool in result
|
||||||
|
assert async_tool.func is not None
|
||||||
|
assert async_tool.invoke({"x": 7}) == "subagent: 7"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
|
def test_acp_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
|
||||||
|
"""Async-only ACP tools can be invoked by sync clients."""
|
||||||
|
|
||||||
|
async def async_tool_impl(x: int) -> str:
|
||||||
|
return f"acp: {x}"
|
||||||
|
|
||||||
|
async_tool = StructuredTool(
|
||||||
|
name="invoke_acp_agent",
|
||||||
|
description="Async-only ACP test tool.",
|
||||||
|
args_schema=AsyncToolArgs,
|
||||||
|
func=None,
|
||||||
|
coroutine=async_tool_impl,
|
||||||
|
)
|
||||||
|
config = _make_minimal_config([])
|
||||||
|
config.acp_agents = {"codex": object()}
|
||||||
|
mock_cfg.return_value = config
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("deerflow.tools.tools.BUILTIN_TOOLS", []),
|
||||||
|
patch("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", return_value=async_tool),
|
||||||
|
):
|
||||||
|
result = get_available_tools(include_mcp=False, app_config=config)
|
||||||
|
|
||||||
|
assert async_tool in result
|
||||||
|
assert async_tool.func is not None
|
||||||
|
assert async_tool.invoke({"x": 9}) == "acp: 9"
|
||||||
|
|
||||||
|
|
||||||
@patch("deerflow.tools.tools.get_app_config")
|
@patch("deerflow.tools.tools.get_app_config")
|
||||||
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
|
||||||
def test_no_duplicates_returned(mock_bash, mock_cfg):
|
def test_no_duplicates_returned(mock_bash, mock_cfg):
|
||||||
|
|||||||
Reference in New Issue
Block a user