import asyncio
import logging
from collections.abc import Awaitable, Callable
from dataclasses import replace as dc_replace
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.runtime import Runtime
from langgraph.types import Command

from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.sandbox import get_sandbox_provider

logger = logging.getLogger(__name__)


class SandboxMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    # Sandbox descriptor; this middleware reads and writes its ``sandbox_id`` key.
    sandbox: NotRequired[SandboxState | None]
    # Declared for schema compatibility only; not read by this middleware.
    thread_data: NotRequired[ThreadDataState | None]


class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
    """Create a sandbox environment and assign it to an agent.

    Lifecycle Management:
    - With lazy_init=True (default): this middleware acquires nothing itself;
      the sandbox is acquired on first tool call and the tool-call wrappers
      persist the resulting id into the graph state.
    - With lazy_init=False: Sandbox is acquired on first agent invocation
      (before_agent / abefore_agent), provided ``runtime.context`` carries a
      ``thread_id``.
    - A sandbox id already present in the state is reused; no new sandbox is
      acquired for it.
    - after_agent / aafter_agent hand the sandbox id (from the state, else
      from ``runtime.context``) back to the provider via ``release()``.

    NOTE(review): an earlier revision of this docstring stated that the
    sandbox is not released after each agent call and that cleanup only
    happens in ``SandboxProvider.shutdown()``. The hooks below do call
    ``release()``; whether that tears the sandbox down or merely returns it
    to a pool depends on the provider -- confirm against the provider
    implementation.
    """

    state_schema = SandboxMiddlewareState

    def __init__(self, lazy_init: bool = True):
        """Initialize sandbox middleware.

        Args:
            lazy_init: If True, defer sandbox acquisition until first tool call.
                If False, acquire sandbox eagerly in before_agent().
                Default is True for optimal performance.
        """
        super().__init__()
        self._lazy_init = lazy_init

    def _acquire_sandbox(self, thread_id: str) -> str:
        """Acquire a sandbox for ``thread_id`` synchronously and return its id."""
        provider = get_sandbox_provider()
        sandbox_id = provider.acquire(thread_id)
        logger.info(f"Acquiring sandbox {sandbox_id}")
        return sandbox_id

    async def _acquire_sandbox_async(self, thread_id: str) -> str:
        """Acquire a sandbox for ``thread_id`` via the provider's async hook."""
        provider = get_sandbox_provider()
        sandbox_id = await provider.acquire_async(thread_id)
        logger.info(f"Acquiring sandbox {sandbox_id}")
        return sandbox_id

    async def _release_sandbox_async(self, sandbox_id: str) -> None:
        """Run the provider's synchronous ``release`` in a worker thread."""
        await asyncio.to_thread(get_sandbox_provider().release, sandbox_id)

    def _thread_id_for_eager_init(self, state: SandboxMiddlewareState, runtime: Runtime) -> str | None:
        """Return the thread id to acquire a sandbox for, or None to skip.

        Acquisition is skipped when lazy init is enabled, when the state
        already carries a ``sandbox`` entry, or when ``runtime.context`` has
        no ``thread_id``.
        """
        if self._lazy_init:
            return None
        if state.get("sandbox") is not None:
            return None
        return (runtime.context or {}).get("thread_id")

    @staticmethod
    def _resolve_release_target(state: SandboxMiddlewareState, runtime: Runtime) -> str | None:
        """Find the sandbox id to release and log the release.

        The id stored in the state takes precedence; ``runtime.context`` is
        the fallback. A ``sandbox`` entry without a usable ``sandbox_id`` is
        treated as "no id in state" instead of raising ``KeyError`` or
        releasing ``None``.

        Returns:
            The sandbox id to release, or None when there is nothing to release.
        """
        sandbox = state.get("sandbox")
        sandbox_id = sandbox.get("sandbox_id") if sandbox is not None else None
        if sandbox_id is not None:
            logger.info(f"Releasing sandbox {sandbox_id}")
            return sandbox_id

        sandbox_id = (runtime.context or {}).get("sandbox_id")
        if sandbox_id is not None:
            logger.info(f"Releasing sandbox {sandbox_id} from context")
            return sandbox_id

        return None

    @override
    def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Eagerly acquire a sandbox (sync) unless lazy init applies.

        Returns:
            A ``{"sandbox": {"sandbox_id": ...}}`` state update when a sandbox
            was acquired, otherwise whatever the base middleware returns.
        """
        thread_id = self._thread_id_for_eager_init(state, runtime)
        if thread_id is None:
            return super().before_agent(state, runtime)

        sandbox_id = self._acquire_sandbox(thread_id)
        logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
        return {"sandbox": {"sandbox_id": sandbox_id}}

    @override
    async def abefore_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Async counterpart of :meth:`before_agent`.

        Uses the async provider hook so blocking sandbox startup/polling runs
        outside the event loop.
        """
        thread_id = self._thread_id_for_eager_init(state, runtime)
        if thread_id is None:
            return await super().abefore_agent(state, runtime)

        sandbox_id = await self._acquire_sandbox_async(thread_id)
        logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
        return {"sandbox": {"sandbox_id": sandbox_id}}

    @override
    def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Release the sandbox recorded in the state or context (sync).

        Returns:
            None after a release; otherwise the base middleware's result.
        """
        sandbox_id = self._resolve_release_target(state, runtime)
        if sandbox_id is None:
            # No sandbox to release
            return super().after_agent(state, runtime)

        get_sandbox_provider().release(sandbox_id)
        return None

    @override
    async def aafter_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Async counterpart of :meth:`after_agent`; releases in a worker thread."""
        sandbox_id = self._resolve_release_target(state, runtime)
        if sandbox_id is None:
            # No sandbox to release
            return await super().aafter_agent(state, runtime)

        await self._release_sandbox_async(sandbox_id)
        return None

    # ------------------------------------------------------------------
    # Tool-call wrappers: persist lazily-acquired sandbox state into the
    # graph state via Command(update=...).
    #
    # Background:
    #   ``ensure_sandbox_initialized*`` in ``deerflow.sandbox.tools`` mutates
    #   ``runtime.state["sandbox"]`` directly. That mutation is local to the
    #   current tool invocation and is NOT picked up by LangGraph's channel
    #   reducer, so subsequent graph steps (and downstream consumers such as
    #   ``ToolOutputBudgetMiddleware`` and the sub-agent ``task_tool``)
    #   cannot observe the sandbox id. Wrapping the tool call lets us detect
    #   a fresh lazy init by diffing the state snapshot before/after the
    #   handler and emit a proper state update via ``Command``.
    # ------------------------------------------------------------------

    @staticmethod
    def _read_sandbox_id_from_state(state: object) -> str | None:
        """Return ``state["sandbox"]["sandbox_id"]`` if it is a string, else None.

        Any unexpected shape (non-dict state, non-dict sandbox entry,
        non-string id) yields None rather than raising.
        """
        if not isinstance(state, dict):
            return None
        sandbox_state = state.get("sandbox")
        if not isinstance(sandbox_state, dict):
            return None
        sandbox_id = sandbox_state.get("sandbox_id")
        return sandbox_id if isinstance(sandbox_id, str) else None

    @staticmethod
    def _attach_sandbox_update(result: ToolMessage | Command, sandbox_id: str) -> ToolMessage | Command:
        """Wrap or merge ``result`` so that ``sandbox.sandbox_id`` is persisted.

        - ``ToolMessage`` -> ``Command(update={"sandbox": ..., "messages": [msg]})``
        - ``Command`` with dict update -> merge ``sandbox`` key, preserve all
          existing fields (``messages``, ``goto``, ``graph``, ``resume``, ...).
        - ``Command`` with non-dict / None update -> leave it untouched to
          avoid silent data loss on unknown update shapes.
        """
        sandbox_update = {"sandbox": {"sandbox_id": sandbox_id}}
        if isinstance(result, ToolMessage):
            return Command(update={**sandbox_update, "messages": [result]})

        existing_update = result.update
        if isinstance(existing_update, dict):
            # The sandbox key is written last so it overrides any stale value.
            merged_update = {**existing_update, **sandbox_update}
            return dc_replace(result, update=merged_update)
        return result

    @staticmethod
    def _read_sandbox_id_from_request(request: ToolCallRequest) -> str | None:
        """Read sandbox_id from runtime.state (where ensure_sandbox_initialized writes)."""
        runtime = request.runtime
        if runtime is None or runtime.state is None:
            return None
        return SandboxMiddleware._read_sandbox_id_from_state(runtime.state)

    def _persist_lazy_sandbox(
        self,
        request: ToolCallRequest,
        prev_sandbox_id: str | None,
        result: ToolMessage | Command,
    ) -> ToolMessage | Command:
        """Attach a sandbox state update when the handler just initialized one.

        Args:
            request: The tool call request whose runtime state is inspected.
            prev_sandbox_id: Sandbox id observed before the handler ran.
            result: The handler's result.

        Returns:
            ``result`` unchanged when a sandbox id was already known before the
            call or none exists after it; otherwise ``result`` carrying the
            new ``sandbox`` update.
        """
        if prev_sandbox_id is not None:
            return result
        curr_sandbox_id = self._read_sandbox_id_from_request(request)
        if curr_sandbox_id is None:
            return result
        return self._attach_sandbox_update(result, curr_sandbox_id)

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Run the tool and persist a sandbox id that appeared during the call."""
        # Snapshot before the handler runs so a fresh lazy init can be detected.
        prev_sandbox_id = self._read_sandbox_id_from_request(request)
        result = handler(request)
        return self._persist_lazy_sandbox(request, prev_sandbox_id, result)

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Async counterpart of :meth:`wrap_tool_call`."""
        # Snapshot before the handler runs so a fresh lazy init can be detected.
        prev_sandbox_id = self._read_sandbox_id_from_request(request)
        result = await handler(request)
        return self._persist_lazy_sandbox(request, prev_sandbox_id, result)