Files
deer-flow/backend/packages/harness/deerflow/sandbox/middleware.py
T
Huixin615 919d8bc279 fix(sandbox): persist lazily-acquired sandbox state via Command (#3464)
* fix(sandbox): persist lazily-acquired sandbox state via Command

ensure_sandbox_initialized mutates runtime.state in place, which is local
to the current tool invocation and is not picked up by LangGraph's channel
reducer. Subsequent graph steps and downstream consumers (such as
ToolOutputBudgetMiddleware and the sub-agent task_tool) therefore cannot
observe the sandbox id from state.

Wrap tool calls in SandboxMiddleware (wrap_tool_call / awrap_tool_call) to
detect fresh lazy initialization by diffing runtime.state before and after
the handler, and emit a proper state update via Command(update=...):

- ToolMessage results are wrapped into Command(update={sandbox, messages})
- Command results with a dict update are merged on the sandbox key while
  preserving messages / goto / graph / resume
- Command results with non-dict updates are left untouched to avoid silent
  data loss on unknown update shapes

Tests:
- 7 new unit tests cover lazy-init emit, passthrough, dict-update merge,
  non-dict-update passthrough (sync and async)
- Refresh replay golden write_read_file.ultra.events.json: SSE 'values'
  events now correctly carry the 'sandbox' key in their keys list, which
  is the direct evidence that the fix is effective

Closes #3463

* refactor(sandbox): use dataclasses.replace to preserve Command fields

Address Copilot review on #3464: replace manual field-copy with
dataclasses.replace so any current or future Command fields are
preserved automatically when merging sandbox_update.

Also add a regression test that constructs a Command with non-None
graph/goto/resume to lock this behavior in.
2026-06-11 17:50:36 +08:00

218 lines
9.1 KiB
Python

import asyncio
import logging
from collections.abc import Awaitable, Callable
from dataclasses import replace as dc_replace
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.runtime import Runtime
from langgraph.types import Command
from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.sandbox import get_sandbox_provider
# Module-level logger for this middleware module.
logger = logging.getLogger(__name__)


class SandboxMiddlewareState(AgentState):
    """Agent state extended with the sandbox and thread-data entries.

    Compatible with the `ThreadState` schema.
    """

    # Holds {"sandbox_id": ...} once a sandbox has been acquired for the
    # thread (eagerly in before_agent, or lazily during a tool call).
    sandbox: NotRequired[SandboxState | None]
    # Thread-level data; not read or written by SandboxMiddleware in this module.
    thread_data: NotRequired[ThreadDataState | None]


class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
    """Create a sandbox environment and assign it to an agent.

    Lifecycle Management:
    - With lazy_init=True (default): the sandbox is acquired on the first
      tool call (by ``ensure_sandbox_initialized*`` in the sandbox tools);
      ``wrap_tool_call`` / ``awrap_tool_call`` then persist the new id into
      graph state via ``Command(update=...)``.
    - With lazy_init=False: the sandbox is acquired on the first agent
      invocation (``before_agent`` / ``abefore_agent``).
    - ``after_agent`` / ``aafter_agent`` hand the sandbox id (taken from
      state, or failing that from ``runtime.context["sandbox_id"]``) to the
      provider's ``release()``.

    NOTE(review): an earlier version of this docstring stated that the
    sandbox is NOT released after each agent call and is only cleaned up at
    application shutdown via ``SandboxProvider.shutdown()``. Whether
    ``release()`` tears the sandbox down or keeps it for reuse is decided by
    the provider implementation -- confirm there.
    """

    state_schema = SandboxMiddlewareState

    def __init__(self, lazy_init: bool = True):
        """Initialize sandbox middleware.

        Args:
            lazy_init: If True, defer sandbox acquisition until the first tool
                call. If False, acquire the sandbox eagerly in
                ``before_agent()`` / ``abefore_agent()``. Defaults to True.
        """
        super().__init__()
        self._lazy_init = lazy_init

    # ------------------------------------------------------------------
    # Provider helpers
    # ------------------------------------------------------------------

    def _acquire_sandbox(self, thread_id: str) -> str:
        """Acquire a sandbox for ``thread_id`` via the provider's blocking ``acquire``."""
        sandbox_id = get_sandbox_provider().acquire(thread_id)
        logger.info("Acquiring sandbox %s", sandbox_id)
        return sandbox_id

    async def _acquire_sandbox_async(self, thread_id: str) -> str:
        """Acquire a sandbox for ``thread_id`` via the provider's ``acquire_async`` hook."""
        sandbox_id = await get_sandbox_provider().acquire_async(thread_id)
        logger.info("Acquiring sandbox %s", sandbox_id)
        return sandbox_id

    async def _release_sandbox_async(self, sandbox_id: str) -> None:
        """Run the provider's synchronous ``release`` in a worker thread."""
        await asyncio.to_thread(get_sandbox_provider().release, sandbox_id)

    # ------------------------------------------------------------------
    # Agent lifecycle hooks
    # ------------------------------------------------------------------

    def _thread_id_for_eager_acquire(self, state: SandboxMiddlewareState, runtime: Runtime) -> str | None:
        """Return the thread id to eagerly acquire a sandbox for, or None to skip.

        Acquisition is skipped when lazy init is enabled, when the state
        already carries a ``sandbox`` entry, or when ``runtime.context`` has
        no ``thread_id``.
        """
        if self._lazy_init:
            return None
        if state.get("sandbox") is not None:
            return None
        return (runtime.context or {}).get("thread_id")

    @override
    def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Eagerly acquire a sandbox when ``lazy_init`` is False.

        Returns:
            ``{"sandbox": {"sandbox_id": ...}}`` when a sandbox was acquired,
            otherwise whatever the base implementation returns.
        """
        thread_id = self._thread_id_for_eager_acquire(state, runtime)
        if thread_id is None:
            return super().before_agent(state, runtime)
        sandbox_id = self._acquire_sandbox(thread_id)
        logger.info("Assigned sandbox %s to thread %s", sandbox_id, thread_id)
        return {"sandbox": {"sandbox_id": sandbox_id}}

    @override
    async def abefore_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Async variant of :meth:`before_agent`."""
        thread_id = self._thread_id_for_eager_acquire(state, runtime)
        if thread_id is None:
            return await super().abefore_agent(state, runtime)
        # Use the async provider hook so blocking sandbox startup/polling
        # runs outside the event loop.
        sandbox_id = await self._acquire_sandbox_async(thread_id)
        logger.info("Assigned sandbox %s to thread %s", sandbox_id, thread_id)
        return {"sandbox": {"sandbox_id": sandbox_id}}

    @staticmethod
    def _sandbox_id_to_release(state: SandboxMiddlewareState, runtime: Runtime) -> str | None:
        """Pick (and log) the sandbox id to release at the end of an agent run.

        Prefers ``state["sandbox"]["sandbox_id"]`` and falls back to
        ``runtime.context["sandbox_id"]``. Returns None when neither is set.
        A ``sandbox`` entry that lacks an id (or whose id is None) falls
        through to the context lookup instead of raising ``KeyError`` or
        releasing ``None``.
        """
        sandbox = state.get("sandbox")
        if sandbox is not None:
            sandbox_id = sandbox.get("sandbox_id")
            if sandbox_id is not None:
                logger.info("Releasing sandbox %s", sandbox_id)
                return sandbox_id
        context_sandbox_id = (runtime.context or {}).get("sandbox_id")
        if context_sandbox_id is not None:
            logger.info("Releasing sandbox %s from context", context_sandbox_id)
            return context_sandbox_id
        return None

    @override
    def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Release the sandbox recorded in state or context, if any."""
        sandbox_id = self._sandbox_id_to_release(state, runtime)
        if sandbox_id is None:
            # No sandbox to release
            return super().after_agent(state, runtime)
        get_sandbox_provider().release(sandbox_id)
        return None

    @override
    async def aafter_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Async variant of :meth:`after_agent`; the blocking release runs in a thread."""
        sandbox_id = self._sandbox_id_to_release(state, runtime)
        if sandbox_id is None:
            # No sandbox to release
            return await super().aafter_agent(state, runtime)
        await self._release_sandbox_async(sandbox_id)
        return None

    # ------------------------------------------------------------------
    # Tool-call wrappers: persist lazily-acquired sandbox state into the
    # graph state via Command(update=...).
    #
    # Background:
    # ``ensure_sandbox_initialized*`` in ``deerflow.sandbox.tools`` mutates
    # ``runtime.state["sandbox"]`` directly. That mutation is local to the
    # current tool invocation and is NOT picked up by LangGraph's channel
    # reducer, so subsequent graph steps (and downstream consumers such as
    # ``ToolOutputBudgetMiddleware`` and the sub-agent ``task_tool``)
    # cannot observe the sandbox id. Wrapping the tool call lets us detect
    # a fresh lazy init by diffing the state snapshot before/after the
    # handler and emit a proper state update via ``Command``.
    # ------------------------------------------------------------------

    @staticmethod
    def _read_sandbox_id_from_state(state: object) -> str | None:
        """Return ``state["sandbox"]["sandbox_id"]`` if it is a string, else None."""
        if not isinstance(state, dict):
            return None
        sandbox_state = state.get("sandbox")
        if not isinstance(sandbox_state, dict):
            return None
        sandbox_id = sandbox_state.get("sandbox_id")
        return sandbox_id if isinstance(sandbox_id, str) else None

    @staticmethod
    def _attach_sandbox_update(result: ToolMessage | Command, sandbox_id: str) -> ToolMessage | Command:
        """Wrap or merge ``result`` so that ``sandbox.sandbox_id`` is persisted.

        - ``ToolMessage`` -> ``Command(update={"sandbox": ..., "messages": [msg]})``
        - ``Command`` with dict update -> merge ``sandbox`` key, preserve all
          existing fields (``messages``, ``goto``, ``graph``, ``resume``, ...).
        - ``Command`` with non-dict / None update, or any other result without
          an ``update`` attribute -> returned untouched to avoid silent data
          loss on unknown update shapes.
        """
        sandbox_update = {"sandbox": {"sandbox_id": sandbox_id}}
        if isinstance(result, ToolMessage):
            return Command(update={**sandbox_update, "messages": [result]})
        existing_update = getattr(result, "update", None)
        if isinstance(existing_update, dict):
            # dataclasses.replace keeps every other Command field intact.
            return dc_replace(result, update={**existing_update, **sandbox_update})
        return result

    @staticmethod
    def _read_sandbox_id_from_request(request: ToolCallRequest) -> str | None:
        """Read sandbox_id from runtime.state (where ensure_sandbox_initialized writes)."""
        runtime = request.runtime
        if runtime is None or runtime.state is None:
            return None
        return SandboxMiddleware._read_sandbox_id_from_state(runtime.state)

    def _persist_lazy_sandbox(
        self,
        request: ToolCallRequest,
        prev_sandbox_id: str | None,
        result: ToolMessage | Command,
    ) -> ToolMessage | Command:
        """Attach a sandbox state update if the handler just initialized one.

        Only the "no sandbox before, sandbox after" transition is persisted;
        if a sandbox id was already present before the call, ``result`` is
        returned unchanged.
        """
        if prev_sandbox_id is not None:
            return result
        curr_sandbox_id = self._read_sandbox_id_from_request(request)
        if curr_sandbox_id is None:
            return result
        return self._attach_sandbox_update(result, curr_sandbox_id)

    @override
    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Run the tool and persist any sandbox it lazily acquired."""
        prev_sandbox_id = self._read_sandbox_id_from_request(request)
        result = handler(request)
        return self._persist_lazy_sandbox(request, prev_sandbox_id, result)

    @override
    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Async variant of :meth:`wrap_tool_call`."""
        prev_sandbox_id = self._read_sandbox_id_from_request(request)
        result = await handler(request)
        return self._persist_lazy_sandbox(request, prev_sandbox_id, result)