mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
fix(harness): wrap all async-only tools for sync clients (#2935)
This commit is contained in:
@@ -3,9 +3,13 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,10 +19,49 @@ _SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thre
|
||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||
|
||||
|
||||
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine."""
|
||||
def _get_runnable_config_param(func: Callable[..., Any]) -> str | None:
|
||||
"""Return the coroutine parameter that expects LangChain RunnableConfig."""
|
||||
if isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
type_hints = get_type_hints(func)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for name, type_ in type_hints.items():
|
||||
if type_ is RunnableConfig:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
||||
|
||||
Args:
|
||||
coro: Async callable backing a LangChain tool.
|
||||
tool_name: Tool name used in error logs.
|
||||
|
||||
Returns:
|
||||
A sync callable suitable for ``BaseTool.func``.
|
||||
|
||||
Notes:
|
||||
If ``coro`` declares a ``RunnableConfig`` parameter, this wrapper
|
||||
exposes ``config: RunnableConfig`` so LangChain can inject runtime
|
||||
config and then forwards it to the coroutine's detected config
|
||||
parameter. This covers DeerFlow's current config-sensitive tools, such
|
||||
as ``invoke_acp_agent``.
|
||||
|
||||
This wrapper intentionally does not synthesize a dynamic function
|
||||
signature. A future async tool with a normal user-facing argument named
|
||||
``config`` and a separate ``RunnableConfig`` parameter named something
|
||||
else, such as ``run_config``, may collide with LangChain's injected
|
||||
``config`` argument. Rename that user-facing field or extend this
|
||||
helper before using that signature.
|
||||
"""
|
||||
config_param = _get_runnable_config_param(coro)
|
||||
|
||||
def run_coroutine(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
@@ -26,11 +69,24 @@ def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable
|
||||
|
||||
try:
|
||||
if loop is not None and loop.is_running():
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||
context = contextvars.copy_context()
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(context.run, lambda: asyncio.run(coro(*args, **kwargs)))
|
||||
return future.result()
|
||||
return asyncio.run(coro(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True)
|
||||
raise
|
||||
|
||||
if config_param:
|
||||
|
||||
def sync_wrapper(*args: Any, config: RunnableConfig = None, **kwargs: Any) -> Any:
|
||||
if config is not None or config_param not in kwargs:
|
||||
kwargs[config_param] = config
|
||||
return run_coroutine(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
return run_coroutine(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
Reference in New Issue
Block a user