mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
c881d95898
* fix(mcp): persist MCP sessions across tool calls for stateful servers MCP tools loaded via langchain-mcp-adapters created a new session on every call, causing stateful servers like Playwright to lose browser state (pages, forms) between consecutive tool invocations within the same thread. Add MCPSessionPool that maintains persistent sessions scoped by (server_name, thread_id). Tool calls within the same thread now reuse the same MCP session, preserving server-side state. Sessions are evicted in LRU order (max 256) and cleaned up on cache invalidation. Fixes #3054 * fix(sandbox): add group/other read permissions to uploaded files for Docker sandbox (#3127) When using AIO sandbox with LocalContainerBackend, uploaded files are created with 0o600 (owner-only) permissions by the gateway process running as root. The sandbox process inside the Docker container runs as a non-root user and cannot read these bind-mounted files, causing a "Permission denied" error on read_file. Add `needs_upload_permission_adjustment` attribute to SandboxProvider (default True) to indicate that uploaded files need chmod adjustment. LocalSandboxProvider opts out (same user). A new `_make_file_sandbox_readable` function adds S_IRGRP | S_IROTH bits after files are written, changing permissions from 0o600 to 0o644 so the sandbox can read the uploads. 
* fix(mcp): address review comments on session pool and tools - _extract_thread_id: return "default" instead of stringifying None when get_config() returns no thread_id - call_with_persistent_session: fix **arguments annotation from dict[str,Any] to Any - Replace private _convert_call_tool_result import with a local implementation that handles all MCP content block types - _make_session_pool_tool: accept tool_interceptors and apply the configured interceptor chain on every call (preserving OAuth and custom interceptors) - MCPSessionPool: replace asyncio.Lock with threading.Lock; restructure get/close methods to never await while holding the lock; add close_all_sync() that closes sessions on their owning event loops - reset_mcp_tools_cache: use pool.close_all_sync() instead of asyncio.run-in-thread to close sessions deterministically - test: add test_session_pool_tool_sync_wrapper_path_is_safe covering tool invocation via the sync wrapper (tool.func) path Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/9e7f9e7f-1d2b-464a-b3b7-7f1649b74122 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * fix(mcp): extract SESSION_CLOSE_TIMEOUT to class constant Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/9e7f9e7f-1d2b-464a-b3b7-7f1649b74122 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * Potential fix for pull request finding 'Empty except' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
277 lines
11 KiB
Python
277 lines
11 KiB
Python
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from langchain_core.tools import BaseTool, StructuredTool
|
||
from langgraph.config import get_config
|
||
|
||
from deerflow.config.extensions_config import ExtensionsConfig
|
||
from deerflow.mcp.client import build_servers_config
|
||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||
from deerflow.mcp.session_pool import get_session_pool
|
||
from deerflow.reflection import resolve_variable
|
||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||
from deerflow.tools.types import Runtime
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _extract_thread_id(runtime: Runtime | None) -> str:
    """Resolve the thread id used to scope a pooled MCP session.

    Sources are consulted in order and the first non-``None`` value wins:

    1. ``runtime.context["thread_id"]`` (skipped when the context is falsy),
    2. ``runtime.config["configurable"]["thread_id"]``,
    3. ``get_config()["configurable"]["thread_id"]``.

    Args:
        runtime: The runtime injected into the tool call, or ``None``.

    Returns:
        The thread id converted with ``str``, or ``"default"`` when no source
        provides one or ``get_config()`` raises ``RuntimeError``.
    """
    if runtime is not None:
        # 1) Explicit runtime context; an empty or missing context is skipped.
        context = runtime.context
        candidate = context.get("thread_id") if context else None
        if candidate is not None:
            return str(candidate)

        # 2) The "configurable" section of the config carried by the runtime.
        configurable = (runtime.config or {}).get("configurable", {})
        candidate = configurable.get("thread_id")
        if candidate is not None:
            return str(candidate)

    # 3) Ambient LangGraph config; a RuntimeError here maps to the fallback id.
    try:
        ambient = get_config().get("configurable", {}).get("thread_id")
        if ambient is None:
            return "default"
        return str(ambient)
    except RuntimeError:
        return "default"
|
||
|
||
|
||
def _convert_call_tool_result(call_tool_result: Any) -> Any:
    """Translate an MCP tool-call result into a ``(content, artifact)`` pair.

    The conversion is implemented locally so that the private
    ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol does not
    have to be imported.

    Args:
        call_tool_result: Either a ``ToolMessage`` / LangGraph ``Command``
            (returned unchanged), or an object exposing ``content``,
            ``isError`` and ``structuredContent``.

    Returns:
        ``(call_tool_result, None)`` for the pass-through types; otherwise a
        tuple of the list of LangChain content blocks and an artifact, which
        is ``{"structured_content": ...}`` when ``structuredContent`` is not
        ``None`` and ``None`` otherwise.

    Raises:
        ToolException: When ``isError`` is truthy. The message is the text
            blocks joined by newlines, or ``str`` of the whole block list when
            there are no text blocks.
    """
    from langchain_core.messages import ToolMessage
    from langchain_core.messages.content import create_file_block, create_image_block, create_text_block
    from langchain_core.tools import ToolException
    from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents

    # An interceptor may short-circuit with a ready-made ToolMessage.
    if isinstance(call_tool_result, ToolMessage):
        return call_tool_result, None

    # LangGraph is optional: only test for Command when it can be imported.
    try:
        from langgraph.types import Command
    except ImportError:
        pass
    else:
        if isinstance(call_tool_result, Command):
            return call_tool_result, None

    def _pick_factory(mime_type: Any) -> Any:
        """Choose the image factory for ``image/*`` MIME types, else the file factory."""
        if mime_type and mime_type.startswith("image/"):
            return create_image_block
        return create_file_block

    def _to_block(item: Any) -> Any:
        """Map one MCP content item to a LangChain content block."""
        if isinstance(item, TextContent):
            return create_text_block(text=item.text)

        if isinstance(item, ImageContent):
            return create_image_block(base64=item.data, mime_type=item.mimeType)

        if isinstance(item, ResourceLink):
            # An empty MIME string is normalised to None.
            link_mime = item.mimeType or None
            return _pick_factory(link_mime)(url=str(item.uri), mime_type=link_mime)

        if isinstance(item, EmbeddedResource):
            from mcp.types import BlobResourceContents

            resource = item.resource
            if isinstance(resource, TextResourceContents):
                return create_text_block(text=resource.text)
            if isinstance(resource, BlobResourceContents):
                blob_mime = resource.mimeType or None
                return _pick_factory(blob_mime)(base64=resource.blob, mime_type=blob_mime)
            # Unrecognised resource payload: fall back to its string form.
            return create_text_block(text=str(resource))

        # Any other content type is rendered as text via ``str``.
        return create_text_block(text=str(item))

    lc_content = [_to_block(item) for item in call_tool_result.content]

    if call_tool_result.isError:
        text_parts = [block["text"] for block in lc_content if isinstance(block, dict) and block.get("type") == "text"]
        message = "\n".join(text_parts) if text_parts else str(lc_content)
        raise ToolException(message)

    structured = call_tool_result.structuredContent
    artifact = None if structured is None else {"structured_content": structured}
    return lc_content, artifact
|
||
|
||
|
||
def _make_session_pool_tool(
    tool: BaseTool,
    server_name: str,
    connection: dict[str, Any],
    tool_interceptors: list[Any] | None = None,
) -> BaseTool:
    """Build a ``StructuredTool`` that executes *tool* through a pooled session.

    On every invocation a session is requested from the session pool with
    ``(server_name, thread_id, connection)``, so calls that resolve to the same
    thread id go through the pool rather than a per-call session.

    Args:
        tool: The discovered MCP tool; its name, description, argument schema
            and metadata are copied onto the returned tool.
        server_name: Name of the MCP server the tool belongs to. A leading
            ``"<server_name>_"`` is stripped from ``tool.name`` to obtain the
            name sent to the server.
        connection: Connection settings handed to the pool when it is asked
            for a session.
        tool_interceptors: Optional interceptors. When non-empty they are
            chained around the actual ``session.call_tool`` invocation, the
            first interceptor in the list being the outermost one.

    Returns:
        A ``StructuredTool`` using the ``content_and_artifact`` response format.
    """
    # Recover the un-prefixed tool name that the MCP server itself knows.
    name_prefix = f"{server_name}_"
    prefixed_name = tool.name
    if prefixed_name.startswith(name_prefix):
        mcp_tool_name = prefixed_name[len(name_prefix) :]
    else:
        mcp_tool_name = prefixed_name

    pool = get_session_pool()

    async def _call_via_interceptors(session: Any, runtime: Any, arguments: dict[str, Any]) -> Any:
        """Run the call through the interceptor chain, ending at the pooled session."""
        from langchain_mcp_adapters.interceptors import MCPToolCallRequest

        async def base_handler(request: MCPToolCallRequest) -> Any:
            # Only the request's name and args are forwarded to the session.
            return await session.call_tool(request.name, request.args)

        def _wrap(interceptor: Any, inner: Any) -> Any:
            # A factory gives each layer its own interceptor/inner binding.
            async def wrapped(req: Any) -> Any:
                return await interceptor(req, inner)

            return wrapped

        handler = base_handler
        # Wrap from the last interceptor inward so the first one ends up outermost.
        for interceptor in reversed(tool_interceptors):
            handler = _wrap(interceptor, handler)

        request = MCPToolCallRequest(
            name=mcp_tool_name,
            args=arguments,
            server_name=server_name,
            runtime=runtime,
        )
        return await handler(request)

    async def call_with_persistent_session(
        runtime: Runtime | None = None,
        **arguments: Any,
    ) -> Any:
        session = await pool.get_session(server_name, _extract_thread_id(runtime), connection)

        if tool_interceptors:
            raw_result = await _call_via_interceptors(session, runtime, arguments)
        else:
            raw_result = await session.call_tool(mcp_tool_name, arguments)

        return _convert_call_tool_result(raw_result)

    return StructuredTool(
        name=tool.name,
        description=tool.description,
        args_schema=tool.args_schema,
        coroutine=call_with_persistent_session,
        response_format="content_and_artifact",
        metadata=tool.metadata,
    )
|
||
|
||
|
||
def _find_tool_server(tool_name: str, servers_config: dict[str, Any]) -> str | None:
    """Return the configured server whose ``"<name>_"`` prefix matches *tool_name*.

    Several server names can match the same tool (for example ``"foo"`` and
    ``"foo_bar"`` both match ``"foo_bar_click"``). The longest matching name
    wins, so the result does not depend on the order of ``servers_config``.

    Returns:
        The matching server name, or ``None`` when no server prefix matches.
    """
    matches = [name for name in servers_config if tool_name.startswith(f"{name}_")]
    return max(matches, key=len, default=None)


def _load_custom_interceptors(extensions_config: ExtensionsConfig) -> list[Any]:
    """Build the custom MCP interceptors listed under ``mcpInterceptors``.

    Expected format: ``"mcpInterceptors": ["pkg.module:builder_func", ...]``.
    A single string is treated as a one-element list; any other non-list value
    is skipped with a warning. Each path is resolved to a builder which is
    called without arguments. Builders that fail or return a non-callable are
    logged and skipped, so one bad entry does not prevent the others loading.
    """
    raw_interceptor_paths = (extensions_config.model_extra or {}).get("mcpInterceptors")
    if isinstance(raw_interceptor_paths, str):
        raw_interceptor_paths = [raw_interceptor_paths]
    elif not isinstance(raw_interceptor_paths, list):
        if raw_interceptor_paths is not None:
            logger.warning(f"mcpInterceptors must be a list of strings, got {type(raw_interceptor_paths).__name__}; skipping")
        raw_interceptor_paths = []

    interceptors: list[Any] = []
    for interceptor_path in raw_interceptor_paths:
        try:
            builder = resolve_variable(interceptor_path)
            interceptor = builder()
            if callable(interceptor):
                interceptors.append(interceptor)
                logger.info(f"Loaded MCP interceptor: {interceptor_path}")
            elif interceptor is not None:
                logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
        except Exception as e:
            # Best effort: a broken interceptor must not disable MCP tools.
            logger.warning(
                f"Failed to load MCP interceptor {interceptor_path}: {e}",
                exc_info=True,
            )
    return interceptors


async def get_mcp_tools() -> list[BaseTool]:
    """Get all tools from enabled MCP servers.

    Tools whose name carries a configured server prefix are wrapped with
    persistent-session logic, so calls within the same thread go through the
    session pool. Tools that match no configured server are returned as
    loaded. Every returned tool that has a coroutine but no sync ``func`` is
    given a sync wrapper.

    Returns:
        List of LangChain tools from all enabled MCP servers. An empty list
        is returned when ``langchain-mcp-adapters`` is not installed, when no
        server is enabled, or when loading fails (the error is logged).
    """
    try:
        from langchain_mcp_adapters.client import MultiServerMCPClient
    except ImportError:
        logger.warning("langchain-mcp-adapters not installed. Install it to enable MCP tools: pip install langchain-mcp-adapters")
        return []

    # NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
    # to always read the latest configuration from disk. This ensures that changes
    # made through the Gateway API (which runs in a separate process) are immediately
    # reflected when initializing MCP tools.
    extensions_config = ExtensionsConfig.from_file()
    servers_config = build_servers_config(extensions_config)

    if not servers_config:
        logger.info("No enabled MCP servers configured")
        return []

    try:
        logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")

        # Inject initial OAuth headers for server connections (tool discovery/session init).
        initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
        for server_name, auth_header in initial_oauth_headers.items():
            if server_name not in servers_config:
                continue
            # Headers are only set for the "sse" and "http" transports.
            if servers_config[server_name].get("transport") in ("sse", "http"):
                existing_headers = dict(servers_config[server_name].get("headers", {}))
                existing_headers["Authorization"] = auth_header
                servers_config[server_name]["headers"] = existing_headers

        # The OAuth interceptor (if any) goes first, followed by custom ones.
        tool_interceptors: list[Any] = []
        oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
        if oauth_interceptor is not None:
            tool_interceptors.append(oauth_interceptor)
        tool_interceptors.extend(_load_custom_interceptors(extensions_config))

        client = MultiServerMCPClient(
            servers_config,
            tool_interceptors=tool_interceptors,
            tool_name_prefix=True,
        )

        # Get all tools from all servers (discovers tool definitions via
        # temporary sessions – the persistent-session wrapping is applied below).
        tools = await client.get_tools()
        logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")

        # Wrap each tool with persistent-session logic.
        wrapped_tools: list[BaseTool] = []
        for tool in tools:
            # Longest-prefix match keeps "foo_bar_*" tools on server "foo_bar"
            # even when a server named "foo" is configured as well.
            tool_server = _find_tool_server(tool.name, servers_config)
            if tool_server is not None:
                wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
            else:
                wrapped_tools.append(tool)

        # Patch tools to support sync invocation, as deerflow client streams synchronously.
        for tool in wrapped_tools:
            if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
                tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)

        return wrapped_tools

    except Exception as e:
        # Top-level boundary: MCP failures degrade to "no MCP tools".
        logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
        return []
|