diff --git a/backend/packages/harness/deerflow/sandbox/middleware.py b/backend/packages/harness/deerflow/sandbox/middleware.py
index f40781333..5bdb5a700 100644
--- a/backend/packages/harness/deerflow/sandbox/middleware.py
+++ b/backend/packages/harness/deerflow/sandbox/middleware.py
@@ -1,10 +1,15 @@
 import asyncio
 import logging
+from collections.abc import Awaitable, Callable
+from dataclasses import replace as dc_replace
 from typing import NotRequired, override
 
 from langchain.agents import AgentState
 from langchain.agents.middleware import AgentMiddleware
+from langchain_core.messages import ToolMessage
+from langgraph.prebuilt.tool_node import ToolCallRequest
 from langgraph.runtime import Runtime
+from langgraph.types import Command
 
 from deerflow.agents.thread_state import SandboxState, ThreadDataState
 from deerflow.sandbox import get_sandbox_provider
@@ -126,3 +131,119 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
 
         # No sandbox to release
         return await super().aafter_agent(state, runtime)
+
+    # ------------------------------------------------------------------
+    # Tool-call wrappers: persist lazily-acquired sandbox state into the
+    # graph state via Command(update=...).
+    #
+    # Background:
+    # ``ensure_sandbox_initialized*`` in ``deerflow.sandbox.tools`` mutates
+    # ``runtime.state["sandbox"]`` directly. That mutation is local to the
+    # current tool invocation and is NOT picked up by LangGraph's channel
+    # reducer, so subsequent graph steps (and downstream consumers such as
+    # ``ToolOutputBudgetMiddleware`` and the sub-agent ``task_tool``)
+    # cannot observe the sandbox id. Wrapping the tool call lets us detect
+    # a fresh lazy init by reading the sandbox id before and after the
+    # handler runs, and emit a proper state update via ``Command``.
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _read_sandbox_id_from_state(state: object) -> str | None:
+        """Return ``state["sandbox"]["sandbox_id"]`` when it is a string.
+
+        Returns ``None`` when ``state`` or its ``"sandbox"`` entry is not a
+        dict, or when ``sandbox_id`` is missing or not a ``str``.
+        """
+        if not isinstance(state, dict):
+            return None
+        sandbox_state = state.get("sandbox")
+        if not isinstance(sandbox_state, dict):
+            return None
+        sandbox_id = sandbox_state.get("sandbox_id")
+        return sandbox_id if isinstance(sandbox_id, str) else None
+
+    @staticmethod
+    def _attach_sandbox_update(result: ToolMessage | Command, sandbox_id: str) -> ToolMessage | Command:
+        """Wrap or merge ``result`` so that ``sandbox.sandbox_id`` is persisted.
+
+        - ``ToolMessage`` -> ``Command(update={"sandbox": ..., "messages": [msg]})``
+        - ``Command`` with dict update -> merge ``sandbox`` key, preserve all
+          existing fields (``messages``, ``goto``, ``graph``, ``resume``, ...).
+        - ``Command`` with non-dict / None update -> leave it untouched to
+          avoid silent data loss on unknown update shapes. The sandbox id
+          is therefore not persisted for such results.
+        """
+        sandbox_update = {"sandbox": {"sandbox_id": sandbox_id}}
+
+        if isinstance(result, ToolMessage):
+            # Wrap the message in a Command whose update carries both the
+            # sandbox id and the message itself.
+            return Command(update={**sandbox_update, "messages": [result]})
+
+        existing_update = result.update
+        if isinstance(existing_update, dict):
+            # ``sandbox_update`` is spread last, so it replaces any ``sandbox``
+            # entry already present in the tool's own update.
+            merged_update = {**existing_update, **sandbox_update}
+            # Copy the Command with only ``update`` swapped; other fields are kept.
+            return dc_replace(result, update=merged_update)
+        # Non-dict (including ``None``) update: return the Command unchanged.
+        return result
+
+    @staticmethod
+    def _read_sandbox_id_from_request(request: ToolCallRequest) -> str | None:
+        """Read sandbox_id from runtime.state (where ensure_sandbox_initialized writes).
+
+        Returns ``None`` when the request has no runtime, the runtime has no
+        state, or the state holds no string ``sandbox_id``.
+        """
+        runtime = request.runtime
+        if runtime is None or runtime.state is None:
+            return None
+        return SandboxMiddleware._read_sandbox_id_from_state(runtime.state)
+
+    @override
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Run ``handler`` and persist a sandbox id that it lazily initialised.
+
+        The sandbox id is read from ``request.runtime.state`` before and after
+        ``handler`` runs. If it was absent before and present afterwards, the
+        result goes through ``_attach_sandbox_update``; in every other case the
+        handler's result is returned unchanged.
+
+        NOTE(review): every call that sees no id before and an id afterwards
+        emits its own ``sandbox`` update. If several tool calls can run in one
+        graph step, confirm the ``sandbox`` channel accepts multiple writes.
+        """
+        prev_sandbox_id = self._read_sandbox_id_from_request(request)
+        result = handler(request)
+        if prev_sandbox_id is not None:
+            return result
+        curr_sandbox_id = self._read_sandbox_id_from_request(request)
+        if curr_sandbox_id is None:
+            return result
+        return self._attach_sandbox_update(result, curr_sandbox_id)
+
+    @override
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Async variant of ``wrap_tool_call``: awaits ``handler``.
+
+        Applies the same before/after sandbox-id comparison and returns the
+        result unchanged unless the handler lazily initialised the sandbox.
+        """
+        prev_sandbox_id = self._read_sandbox_id_from_request(request)
+        result = await handler(request)
+        if prev_sandbox_id is not None:
+            return result
+        curr_sandbox_id = self._read_sandbox_id_from_request(request)
+        if curr_sandbox_id is None:
+            return result
+        return self._attach_sandbox_update(result, curr_sandbox_id)
diff --git a/backend/tests/fixtures/replay/write_read_file.ultra.events.json b/backend/tests/fixtures/replay/write_read_file.ultra.events.json
index babf90d9d..83976df10 100644
--- a/backend/tests/fixtures/replay/write_read_file.ultra.events.json
+++ b/backend/tests/fixtures/replay/write_read_file.ultra.events.json
@@ -69,6 +69,7 @@
       "keys": [
         "artifacts",
         "messages",
+        "sandbox",
         "thread_data",
         "title",
         "viewed_images"
@@ -79,6 +80,7 @@
       "keys": [
         "artifacts",
         "messages",
+        "sandbox",
         "thread_data",
         "title",
         "viewed_images"
@@ -89,6 +91,7 @@
       "keys": [
         "artifacts",
         "messages",
+        "sandbox",
         "thread_data",
         "title",
         "viewed_images"
@@ -99,6 +102,7 @@
       "keys": [
         "artifacts",
         "messages",
+        "sandbox",
         "thread_data",
         "title",
         "viewed_images"
@@ -109,6 +113,7 @@
       "keys": [
         "artifacts",
         "messages",
+        "sandbox",
         "thread_data",
         "title",
         "viewed_images"
@@ -119,6 +124,7 @@
       "keys": [
         "artifacts",
         "messages",
+        "sandbox",
         "thread_data",
         "title",
         "viewed_images"
diff --git a/backend/tests/test_sandbox_middleware.py b/backend/tests/test_sandbox_middleware.py
index e3daa3088..c584c759a 100644
--- a/backend/tests/test_sandbox_middleware.py
+++ b/backend/tests/test_sandbox_middleware.py
@@ -5,7 +5,10 @@ import asyncio
 import pytest
 from langchain.agents.middleware import AgentMiddleware
 from langchain.tools import ToolRuntime
+from langchain_core.messages import ToolMessage
+from langgraph.prebuilt.tool_node import ToolCallRequest
 from langgraph.runtime import Runtime
+from langgraph.types import Command
 
 from deerflow.sandbox.middleware import SandboxMiddleware
 from deerflow.sandbox.sandbox import Sandbox
@@ -223,3 +226,194 @@ async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pyte
 
     assert result == {"delegated": True}
     assert calls == [(state, runtime)]
+
+
+# ---------------------------------------------------------------------------
+# wrap_tool_call / awrap_tool_call: persistent sandbox state via Command
+# ---------------------------------------------------------------------------
+
+
+def _make_tool_call_request(state: dict) -> ToolCallRequest:
+    """Build a minimal ToolCallRequest backed by a real ToolRuntime.
+
+    The same ``state`` dict is handed to both the runtime and the request, so
+    a handler that mutates ``req.runtime.state`` is visible to the caller.
+    """
+    runtime = ToolRuntime(
+        state=state,
+        context={},
+        config={"configurable": {}},
+        stream_writer=lambda _: None,
+        tools=[],
+        tool_call_id="call-1",
+        store=None,
+    )
+    return ToolCallRequest(
+        tool_call={"id": "call-1", "name": "bash", "args": {}},
+        tool=None,
+        state=state,
+        runtime=runtime,
+    )
+
+
+def test_wrap_tool_call_emits_command_when_lazy_init_happens() -> None:
+    """A ToolMessage result is wrapped in a Command once the handler sets the sandbox id."""
+    middleware = SandboxMiddleware()
+    state: dict = {}
+    request = _make_tool_call_request(state)
+
+    def handler(req: ToolCallRequest) -> ToolMessage:
+        # Simulate ensure_sandbox_initialized() mutating runtime.state in-place.
+        req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
+        return ToolMessage(content="ok", tool_call_id="call-1", name="bash")
+
+    result = middleware.wrap_tool_call(request, handler)
+
+    assert isinstance(result, Command)
+    assert isinstance(result.update, dict)
+    assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"}
+    messages = result.update["messages"]
+    assert len(messages) == 1
+    assert messages[0].content == "ok"
+    assert messages[0].tool_call_id == "call-1"
+
+
+def test_wrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
+    """With a sandbox id already in state, the handler result is returned as-is."""
+    middleware = SandboxMiddleware()
+    state: dict = {"sandbox": {"sandbox_id": "existing"}}
+    request = _make_tool_call_request(state)
+    original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
+
+    def handler(req: ToolCallRequest) -> ToolMessage:
+        return original
+
+    result = middleware.wrap_tool_call(request, handler)
+
+    assert result is original
+
+
+def test_wrap_tool_call_passthrough_when_handler_did_not_initialize_sandbox() -> None:
+    """Without any sandbox id before or after, the handler result is returned as-is."""
+    middleware = SandboxMiddleware()
+    state: dict = {}
+    request = _make_tool_call_request(state)
+    original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
+
+    def handler(req: ToolCallRequest) -> ToolMessage:
+        return original
+
+    result = middleware.wrap_tool_call(request, handler)
+
+    assert result is original
+
+
+def test_wrap_tool_call_merges_with_existing_command_update() -> None:
+    """A Command with a dict update keeps its keys and ``goto`` and gains ``sandbox``."""
+    middleware = SandboxMiddleware()
+    state: dict = {}
+    request = _make_tool_call_request(state)
+    tool_msg = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
+
+    def handler(req: ToolCallRequest) -> Command:
+        req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
+        return Command(
+            update={
+                "messages": [tool_msg],
+                "viewed_images": {"a.png": {"base64": "x", "mime_type": "image/png"}},
+            },
+            goto="next-node",
+        )
+
+    result = middleware.wrap_tool_call(request, handler)
+
+    assert isinstance(result, Command)
+    assert result.goto == "next-node"
+    assert isinstance(result.update, dict)
+    assert result.update["messages"] == [tool_msg]
+    assert result.update["viewed_images"] == {"a.png": {"base64": "x", "mime_type": "image/png"}}
+    assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"}
+
+
+def test_wrap_tool_call_does_not_override_non_dict_update() -> None:
+    """A Command whose update is a list is returned as the same object."""
+    middleware = SandboxMiddleware()
+    state: dict = {}
+    request = _make_tool_call_request(state)
+    cmd = Command(update=[("messages", [ToolMessage(content="x", tool_call_id="c", name="bash")])])
+
+    def handler(req: ToolCallRequest) -> Command:
+        req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
+        return cmd
+
+    result = middleware.wrap_tool_call(request, handler)
+
+    # Non-dict update is left untouched to avoid silent data loss.
+    assert result is cmd
+
+
+@pytest.mark.anyio
+async def test_awrap_tool_call_emits_command_when_lazy_init_happens() -> None:
+    """Async path: a ToolMessage result is wrapped in a Command after lazy init."""
+    middleware = SandboxMiddleware()
+    state: dict = {}
+    request = _make_tool_call_request(state)
+
+    async def handler(req: ToolCallRequest) -> ToolMessage:
+        req.runtime.state["sandbox"] = {"sandbox_id": "async-new"}
+        return ToolMessage(content="ok", tool_call_id="call-1", name="bash")
+
+    result = await middleware.awrap_tool_call(request, handler)
+
+    assert isinstance(result, Command)
+    assert isinstance(result.update, dict)
+    assert result.update["sandbox"] == {"sandbox_id": "async-new"}
+    messages = result.update["messages"]
+    assert len(messages) == 1
+    assert messages[0].content == "ok"
+
+
+@pytest.mark.anyio
+async def test_awrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
+    """Async path: with a sandbox id already in state, the result is returned as-is."""
+    middleware = SandboxMiddleware()
+    state: dict = {"sandbox": {"sandbox_id": "existing"}}
+    request = _make_tool_call_request(state)
+    original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
+
+    async def handler(req: ToolCallRequest) -> ToolMessage:
+        return original
+
+    result = await middleware.awrap_tool_call(request, handler)
+
+    assert result is original
+
+
+def test_wrap_tool_call_preserves_existing_command_fields_when_merging() -> None:
+    """Regression: when merging sandbox_update into an existing Command,
+    all other Command fields (e.g. graph, goto, resume) must be preserved.
+    """
+    middleware = SandboxMiddleware()
+    state: dict = {}
+    request = _make_tool_call_request(state)
+
+    def handler(req: ToolCallRequest) -> Command:
+        req.runtime.state["sandbox"] = {"sandbox_id": "sbx-merge"}
+        return Command(
+            update={"existing_key": "existing_value"},
+            graph="parent",
+            goto="next_node",
+            resume="resume-token",
+        )
+
+    result = middleware.wrap_tool_call(request, handler)
+
+    assert isinstance(result, Command)
+    assert result.update == {
+        "existing_key": "existing_value",
+        "sandbox": {"sandbox_id": "sbx-merge"},
+    }
+    # Critical: other Command fields must NOT be dropped by the merge.
+    assert result.graph == "parent"
+    assert result.goto == "next_node"
+    assert result.resume == "resume-token"