From 919d8bc27919ab87eee3edb4fcda5855a75c678d Mon Sep 17 00:00:00 2001 From: Huixin615 Date: Thu, 11 Jun 2026 17:50:36 +0800 Subject: [PATCH] fix(sandbox): persist lazily-acquired sandbox state via Command (#3464) * fix(sandbox): persist lazily-acquired sandbox state via Command ensure_sandbox_initialized mutates runtime.state in place, which is local to the current tool invocation and is not picked up by LangGraph's channel reducer. Subsequent graph steps and downstream consumers (such as ToolOutputBudgetMiddleware and the sub-agent task_tool) therefore cannot observe the sandbox id from state. Wrap tool calls in SandboxMiddleware (wrap_tool_call / awrap_tool_call) to detect fresh lazy initialization by diffing runtime.state before and after the handler, and emit a proper state update via Command(update=...): - ToolMessage results are wrapped into Command(update={sandbox, messages}) - Command results with a dict update are merged on the sandbox key while preserving messages / goto / graph / resume - Command results with non-dict updates are left untouched to avoid silent data loss on unknown update shapes Tests: - 7 new unit tests cover lazy-init emit, passthrough, dict-update merge, non-dict-update passthrough (sync and async) - Refresh replay golden write_read_file.ultra.events.json: SSE 'values' events now correctly carry the 'sandbox' key in their keys list, which is the direct evidence that the fix is effective Closes #3463 * refactor(sandbox): use dataclasses.replace to preserve Command fields Address Copilot review on #3464: replace manual field-copy with dataclasses.replace so any current or future Command fields are preserved automatically when merging sandbox_update. Also add a regression test that constructs a Command with non-None graph/goto/resume to lock this behavior in. --- .../harness/deerflow/sandbox/middleware.py | 89 +++++++++ .../replay/write_read_file.ultra.events.json | 6 + backend/tests/test_sandbox_middleware.py | 183 ++++++++++++++++++ 3 files changed, 278 insertions(+) diff --git a/backend/packages/harness/deerflow/sandbox/middleware.py b/backend/packages/harness/deerflow/sandbox/middleware.py index f40781333..5bdb5a700 100644 --- a/backend/packages/harness/deerflow/sandbox/middleware.py +++ b/backend/packages/harness/deerflow/sandbox/middleware.py @@ -1,10 +1,15 @@ import asyncio import logging +from collections.abc import Awaitable, Callable +from dataclasses import replace as dc_replace from typing import NotRequired, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import ToolMessage +from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.runtime import Runtime +from langgraph.types import Command from deerflow.agents.thread_state import SandboxState, ThreadDataState from deerflow.sandbox import get_sandbox_provider @@ -126,3 +131,87 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): # No sandbox to release return await super().aafter_agent(state, runtime) + + # ------------------------------------------------------------------ + # Tool-call wrappers: persist lazily-acquired sandbox state into the + # graph state via Command(update=...). + # + # Background: + # ``ensure_sandbox_initialized*`` in ``deerflow.sandbox.tools`` mutates + # ``runtime.state["sandbox"]`` directly. That mutation is local to the + # current tool invocation and is NOT picked up by LangGraph's channel + # reducer, so subsequent graph steps (and downstream consumers such as + # ``ToolOutputBudgetMiddleware`` and the sub-agent ``task_tool``) + # cannot observe the sandbox id. Wrapping the tool call lets us detect + # a fresh lazy init by diffing the state snapshot before/after the + # handler and emit a proper state update via ``Command``. + # ------------------------------------------------------------------ + + @staticmethod + def _read_sandbox_id_from_state(state: object) -> str | None: + if not isinstance(state, dict): + return None + sandbox_state = state.get("sandbox") + if not isinstance(sandbox_state, dict): + return None + sandbox_id = sandbox_state.get("sandbox_id") + return sandbox_id if isinstance(sandbox_id, str) else None + + @staticmethod + def _attach_sandbox_update(result: ToolMessage | Command, sandbox_id: str) -> ToolMessage | Command: + """Wrap or merge ``result`` so that ``sandbox.sandbox_id`` is persisted. + + - ``ToolMessage`` -> ``Command(update={"sandbox": ..., "messages": [msg]})`` + - ``Command`` with dict update -> merge ``sandbox`` key, preserve all + existing fields (``messages``, ``goto``, ``graph``, ``resume``, ...). + - ``Command`` with non-dict / None update -> leave it untouched to + avoid silent data loss on unknown update shapes. + """ + sandbox_update = {"sandbox": {"sandbox_id": sandbox_id}} + + if isinstance(result, ToolMessage): + return Command(update={**sandbox_update, "messages": [result]}) + + existing_update = result.update + if isinstance(existing_update, dict): + merged_update = {**existing_update, **sandbox_update} + return dc_replace(result, update=merged_update) + return result + + @staticmethod + def _read_sandbox_id_from_request(request: ToolCallRequest) -> str | None: + """Read sandbox_id from runtime.state (where ensure_sandbox_initialized writes).""" + runtime = request.runtime + if runtime is None or runtime.state is None: + return None + return SandboxMiddleware._read_sandbox_id_from_state(runtime.state) + + @override + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + prev_sandbox_id = self._read_sandbox_id_from_request(request) + result = handler(request) + if prev_sandbox_id is not None: + return result + curr_sandbox_id = self._read_sandbox_id_from_request(request) + if curr_sandbox_id is None: + return result + return self._attach_sandbox_update(result, curr_sandbox_id) + + @override + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + ) -> ToolMessage | Command: + prev_sandbox_id = self._read_sandbox_id_from_request(request) + result = await handler(request) + if prev_sandbox_id is not None: + return result + curr_sandbox_id = self._read_sandbox_id_from_request(request) + if curr_sandbox_id is None: + return result + return self._attach_sandbox_update(result, curr_sandbox_id) diff --git a/backend/tests/fixtures/replay/write_read_file.ultra.events.json b/backend/tests/fixtures/replay/write_read_file.ultra.events.json index babf90d9d..83976df10 100644 --- a/backend/tests/fixtures/replay/write_read_file.ultra.events.json +++ b/backend/tests/fixtures/replay/write_read_file.ultra.events.json @@ -69,6 +69,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -79,6 +80,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -89,6 +91,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -99,6 +102,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -109,6 +113,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -119,6 +124,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" diff --git a/backend/tests/test_sandbox_middleware.py b/backend/tests/test_sandbox_middleware.py index e3daa3088..c584c759a 100644 --- a/backend/tests/test_sandbox_middleware.py +++ b/backend/tests/test_sandbox_middleware.py @@ -5,7 +5,10 @@ import asyncio import pytest from langchain.agents.middleware import AgentMiddleware from langchain.tools import ToolRuntime +from langchain_core.messages import ToolMessage +from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.runtime import Runtime +from langgraph.types import Command from deerflow.sandbox.middleware import SandboxMiddleware from deerflow.sandbox.sandbox import Sandbox @@ -223,3 +226,183 @@ async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pyte assert result == {"delegated": True} assert calls == [(state, runtime)] + + +# --------------------------------------------------------------------------- +# wrap_tool_call / awrap_tool_call: persistent sandbox state via Command +# --------------------------------------------------------------------------- + + +def _make_tool_call_request(state: dict) -> ToolCallRequest: + """Build a minimal ToolCallRequest backed by a real ToolRuntime.""" + runtime = ToolRuntime( + state=state, + context={}, + config={"configurable": {}}, + stream_writer=lambda _: None, + tools=[], + tool_call_id="call-1", + store=None, + ) + return ToolCallRequest( + tool_call={"id": "call-1", "name": "bash", "args": {}}, + tool=None, + state=state, + runtime=runtime, + ) + + +def test_wrap_tool_call_emits_command_when_lazy_init_happens() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + + def handler(req: ToolCallRequest) -> ToolMessage: + # Simulate ensure_sandbox_initialized() mutating runtime.state in-place. + req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"} + return ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + result = middleware.wrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert isinstance(result.update, dict) + assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"} + messages = result.update["messages"] + assert len(messages) == 1 + assert messages[0].content == "ok" + assert messages[0].tool_call_id == "call-1" + + +def test_wrap_tool_call_passthrough_when_sandbox_already_in_state() -> None: + middleware = SandboxMiddleware() + state: dict = {"sandbox": {"sandbox_id": "existing"}} + request = _make_tool_call_request(state) + original = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + def handler(req: ToolCallRequest) -> ToolMessage: + return original + + result = middleware.wrap_tool_call(request, handler) + + assert result is original + + +def test_wrap_tool_call_passthrough_when_handler_did_not_initialize_sandbox() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + original = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + def handler(req: ToolCallRequest) -> ToolMessage: + return original + + result = middleware.wrap_tool_call(request, handler) + + assert result is original + + +def test_wrap_tool_call_merges_with_existing_command_update() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + tool_msg = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + def handler(req: ToolCallRequest) -> Command: + req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"} + return Command( + update={ + "messages": [tool_msg], + "viewed_images": {"a.png": {"base64": "x", "mime_type": "image/png"}}, + }, + goto="next-node", + ) + + result = middleware.wrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert result.goto == "next-node" + assert isinstance(result.update, dict) + assert result.update["messages"] == [tool_msg] + assert result.update["viewed_images"] == {"a.png": {"base64": "x", "mime_type": "image/png"}} + assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"} + + +def test_wrap_tool_call_does_not_override_non_dict_update() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + cmd = Command(update=[("messages", [ToolMessage(content="x", tool_call_id="c", name="bash")])]) + + def handler(req: ToolCallRequest) -> Command: + req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"} + return cmd + + result = middleware.wrap_tool_call(request, handler) + + # Non-dict update is left untouched to avoid silent data loss. + assert result is cmd + + +@pytest.mark.anyio +async def test_awrap_tool_call_emits_command_when_lazy_init_happens() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + + async def handler(req: ToolCallRequest) -> ToolMessage: + req.runtime.state["sandbox"] = {"sandbox_id": "async-new"} + return ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + result = await middleware.awrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert isinstance(result.update, dict) + assert result.update["sandbox"] == {"sandbox_id": "async-new"} + messages = result.update["messages"] + assert len(messages) == 1 + assert messages[0].content == "ok" + + +@pytest.mark.anyio +async def test_awrap_tool_call_passthrough_when_sandbox_already_in_state() -> None: + middleware = SandboxMiddleware() + state: dict = {"sandbox": {"sandbox_id": "existing"}} + request = _make_tool_call_request(state) + original = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + async def handler(req: ToolCallRequest) -> ToolMessage: + return original + + result = await middleware.awrap_tool_call(request, handler) + + assert result is original + + +def test_wrap_tool_call_preserves_existing_command_fields_when_merging() -> None: + """Regression: when merging sandbox_update into an existing Command, + all other Command fields (e.g. graph, goto, resume) must be preserved. + """ + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + + def handler(req: ToolCallRequest) -> Command: + req.runtime.state["sandbox"] = {"sandbox_id": "sbx-merge"} + return Command( + update={"existing_key": "existing_value"}, + graph="parent", + goto="next_node", + resume="resume-token", + ) + + result = middleware.wrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert result.update == { + "existing_key": "existing_value", + "sandbox": {"sandbox_id": "sbx-merge"}, + } + # Critical: other Command fields must NOT be dropped by the merge. + assert result.graph == "parent" + assert result.goto == "next_node" + assert result.resume == "resume-token"