diff --git a/backend/packages/harness/deerflow/tools/sync.py b/backend/packages/harness/deerflow/tools/sync.py index c2b80781a..7521dd7b3 100644 --- a/backend/packages/harness/deerflow/tools/sync.py +++ b/backend/packages/harness/deerflow/tools/sync.py @@ -3,9 +3,13 @@ import asyncio import atexit import concurrent.futures +import contextvars +import functools import logging from collections.abc import Callable -from typing import Any +from typing import Any, get_type_hints + +from langchain_core.runnables import RunnableConfig logger = logging.getLogger(__name__) @@ -15,10 +19,49 @@ _SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thre atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False)) -def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: - """Build a synchronous wrapper for an asynchronous tool coroutine.""" +def _get_runnable_config_param(func: Callable[..., Any]) -> str | None: + """Return the coroutine parameter that expects LangChain RunnableConfig.""" + if isinstance(func, functools.partial): + func = func.func - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + type_hints = get_type_hints(func) + except Exception: + return None + + for name, type_ in type_hints.items(): + if type_ is RunnableConfig: + return name + return None + + +def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: + """Build a synchronous wrapper for an asynchronous tool coroutine. + + Args: + coro: Async callable backing a LangChain tool. + tool_name: Tool name used in error logs. + + Returns: + A sync callable suitable for ``BaseTool.func``. + + Notes: + If ``coro`` declares a ``RunnableConfig`` parameter, this wrapper + exposes ``config: RunnableConfig`` so LangChain can inject runtime + config and then forwards it to the coroutine's detected config + parameter. This covers DeerFlow's current config-sensitive tools, such + as ``invoke_acp_agent``. + + This wrapper intentionally does not synthesize a dynamic function + signature. A future async tool with a normal user-facing argument named + ``config`` and a separate ``RunnableConfig`` parameter named something + else, such as ``run_config``, may collide with LangChain's injected + ``config`` argument. Rename that user-facing field or extend this + helper before using that signature. + """ + config_param = _get_runnable_config_param(coro) + + def run_coroutine(*args: Any, **kwargs: Any) -> Any: try: loop = asyncio.get_running_loop() except RuntimeError: @@ -26,11 +69,24 @@ def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable try: if loop is not None and loop.is_running(): - future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs)) + context = contextvars.copy_context() + future = _SYNC_TOOL_EXECUTOR.submit(context.run, lambda: asyncio.run(coro(*args, **kwargs))) return future.result() return asyncio.run(coro(*args, **kwargs)) except Exception as e: logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True) raise + if config_param: + + def sync_wrapper(*args: Any, config: RunnableConfig = None, **kwargs: Any) -> Any: + if config is not None or config_param not in kwargs: + kwargs[config_param] = config + return run_coroutine(*args, **kwargs) + + return sync_wrapper + + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + return run_coroutine(*args, **kwargs) + return sync_wrapper diff --git a/backend/packages/harness/deerflow/tools/tools.py b/backend/packages/harness/deerflow/tools/tools.py index 5c97962fc..bc2caed43 100644 --- a/backend/packages/harness/deerflow/tools/tools.py +++ b/backend/packages/harness/deerflow/tools/tools.py @@ -205,7 +205,7 @@ def get_available_tools( # Deduplicate by tool name — config-loaded tools take priority, followed by # built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to # receive ambiguous or concatenated function schemas (issue #1803). - all_tools = loaded_tools + builtin_tools + mcp_tools + acp_tools + all_tools = [_ensure_sync_invocable_tool(t) for t in loaded_tools + builtin_tools + mcp_tools + acp_tools] seen_names: set[str] = set() unique_tools: list[BaseTool] = [] for t in all_tools: diff --git a/backend/tests/test_invoke_acp_agent_tool.py b/backend/tests/test_invoke_acp_agent_tool.py index 8c44403b8..deace5b4e 100644 --- a/backend/tests/test_invoke_acp_agent_tool.py +++ b/backend/tests/test_invoke_acp_agent_tool.py @@ -699,6 +699,92 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo load_acp_config_from_dict({}) +def test_get_available_tools_sync_invoke_acp_agent_preserves_thread_workspace(monkeypatch, tmp_path): + from deerflow.config import paths as paths_module + from deerflow.runtime import user_context as uc_module + + monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path)) + monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})), + ) + monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True) + + captured: dict[str, object] = {} + + class DummyClient: + @property + def collected_text(self) -> str: + return "ok" + + async def session_update(self, session_id, update, **kwargs): + pass + + async def request_permission(self, options, session_id, tool_call, **kwargs): + raise AssertionError("should not be called") + + class DummyConn: + async def initialize(self, **kwargs): + pass + + async def new_session(self, **kwargs): + return SimpleNamespace(session_id="s1") + + async def prompt(self, **kwargs): + pass + + class DummyProcessContext: + def __init__(self, client, cmd, *args, env=None, cwd): + captured["cwd"] = cwd + + async def __aenter__(self): + return DummyConn(), object() + + async def __aexit__(self, exc_type, exc, tb): + return False + + monkeypatch.setitem( + sys.modules, + "acp", + SimpleNamespace( + PROTOCOL_VERSION="2026-03-24", + Client=DummyClient, + spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd), + text_block=lambda text: {"type": "text", "text": text}, + ), + ) + monkeypatch.setitem( + sys.modules, + "acp.schema", + SimpleNamespace( + ClientCapabilities=lambda: {}, + Implementation=lambda **kwargs: kwargs, + TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}), + ), + ) + + explicit_config = SimpleNamespace( + tools=[], + models=[], + tool_search=SimpleNamespace(enabled=False), + skill_evolution=SimpleNamespace(enabled=False), + sandbox=SimpleNamespace(), + get_model_config=lambda name: None, + acp_agents={"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}, + ) + tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config) + tool = next(tool for tool in tools if tool.name == "invoke_acp_agent") + + thread_id = "thread-sync-123" + tool.invoke( + {"agent": "codex", "prompt": "Do something"}, + config={"configurable": {"thread_id": thread_id}}, + ) + + assert captured["cwd"] == str(tmp_path / "threads" / thread_id / "acp-workspace") + + def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch): explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")} explicit_config = SimpleNamespace( diff --git a/backend/tests/test_mcp_sync_wrapper.py b/backend/tests/test_mcp_sync_wrapper.py index 285200781..c66662bb5 100644 --- a/backend/tests/test_mcp_sync_wrapper.py +++ b/backend/tests/test_mcp_sync_wrapper.py @@ -1,7 +1,9 @@ import asyncio +import contextvars from unittest.mock import AsyncMock, MagicMock, patch import pytest +from langchain_core.runnables import RunnableConfig from langchain_core.tools import StructuredTool from pydantic import BaseModel, Field @@ -69,6 +71,58 @@ def test_mcp_tool_sync_wrapper_in_running_loop(): assert result == "async_result: 100" +def test_sync_wrapper_preserves_contextvars_in_running_loop(): + """The executor branch preserves LangGraph-style contextvars.""" + current_value: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_value", default=None) + + async def mock_coro() -> str | None: + return current_value.get() + + sync_func = make_sync_tool_wrapper(mock_coro, "test_tool") + + async def run_in_loop() -> str | None: + token = current_value.set("from-parent-context") + try: + return sync_func() + finally: + current_value.reset(token) + + assert asyncio.run(run_in_loop()) == "from-parent-context" + + +def test_sync_wrapper_preserves_runnable_config_injection(): + """LangChain can still inject RunnableConfig after an async tool is wrapped.""" + captured: dict[str, object] = {} + + async def mock_coro(x: int, config: RunnableConfig = None): + captured["thread_id"] = ((config or {}).get("configurable") or {}).get("thread_id") + return f"result: {x}" + + mock_tool = StructuredTool( + name="test_tool", + description="test description", + args_schema=MockArgs, + func=make_sync_tool_wrapper(mock_coro, "test_tool"), + coroutine=mock_coro, + ) + + result = mock_tool.invoke({"x": 42}, config={"configurable": {"thread_id": "thread-123"}}) + + assert result == "result: 42" + assert captured["thread_id"] == "thread-123" + + +def test_sync_wrapper_preserves_regular_config_argument(): + """Only RunnableConfig-annotated coroutine params get special config injection.""" + + async def mock_coro(config: str): + return config + + sync_func = make_sync_tool_wrapper(mock_coro, "test_tool") + + assert sync_func(config="user-config") == "user-config" + + def test_mcp_tool_sync_wrapper_exception_logging(): """Test the shared sync wrapper's error logging.""" diff --git a/backend/tests/test_tool_deduplication.py b/backend/tests/test_tool_deduplication.py index f018fc57d..b8a7a3127 100644 --- a/backend/tests/test_tool_deduplication.py +++ b/backend/tests/test_tool_deduplication.py @@ -95,6 +95,64 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): assert async_tool.invoke({"x": 42}) == "result: 42" +@patch("deerflow.tools.tools.get_app_config") +@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) +def test_subagent_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): + """Async-only tools added through the subagent path can be invoked by sync clients.""" + + async def async_tool_impl(x: int) -> str: + return f"subagent: {x}" + + async_tool = StructuredTool( + name="async_subagent_tool", + description="Async-only subagent test tool.", + args_schema=AsyncToolArgs, + func=None, + coroutine=async_tool_impl, + ) + mock_cfg.return_value = _make_minimal_config([]) + + with ( + patch("deerflow.tools.tools.BUILTIN_TOOLS", []), + patch("deerflow.tools.tools.SUBAGENT_TOOLS", [async_tool]), + ): + result = get_available_tools(include_mcp=False, subagent_enabled=True, app_config=mock_cfg.return_value) + + assert async_tool in result + assert async_tool.func is not None + assert async_tool.invoke({"x": 7}) == "subagent: 7" + + +@patch("deerflow.tools.tools.get_app_config") +@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) +def test_acp_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): + """Async-only ACP tools can be invoked by sync clients.""" + + async def async_tool_impl(x: int) -> str: + return f"acp: {x}" + + async_tool = StructuredTool( + name="invoke_acp_agent", + description="Async-only ACP test tool.", + args_schema=AsyncToolArgs, + func=None, + coroutine=async_tool_impl, + ) + config = _make_minimal_config([]) + config.acp_agents = {"codex": object()} + mock_cfg.return_value = config + + with ( + patch("deerflow.tools.tools.BUILTIN_TOOLS", []), + patch("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", return_value=async_tool), + ): + result = get_available_tools(include_mcp=False, app_config=config) + + assert async_tool in result + assert async_tool.func is not None + assert async_tool.invoke({"x": 9}) == "acp: 9" + + @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) def test_no_duplicates_returned(mock_bash, mock_cfg):