mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
919d8bc279
* fix(sandbox): persist lazily-acquired sandbox state via Command
ensure_sandbox_initialized mutates runtime.state in place, which is local
to the current tool invocation and is not picked up by LangGraph's channel
reducer. Subsequent graph steps and downstream consumers (such as
ToolOutputBudgetMiddleware and the sub-agent task_tool) therefore cannot
observe the sandbox id from state.
Wrap tool calls in SandboxMiddleware (wrap_tool_call / awrap_tool_call) to
detect fresh lazy initialization by diffing runtime.state before and after
the handler, and emit a proper state update via Command(update=...):
- ToolMessage results are wrapped into Command(update={sandbox, messages})
- Command results with a dict update are merged on the sandbox key while
preserving messages / goto / graph / resume
- Command results with non-dict updates are left untouched to avoid silent
data loss on unknown update shapes
Tests:
- 7 new unit tests cover lazy-init emit, passthrough, dict-update merge,
non-dict-update passthrough (sync and async)
- Refresh replay golden write_read_file.ultra.events.json: SSE 'values'
events now correctly carry the 'sandbox' key in their keys list, which
is the direct evidence that the fix is effective
Closes #3463
* refactor(sandbox): use dataclasses.replace to preserve Command fields
Address Copilot review on #3464: replace manual field-copy with
dataclasses.replace so any current or future Command fields are
preserved automatically when merging sandbox_update.
Also add a regression test that constructs a Command with non-None
graph/goto/resume to lock this behavior in.
409 lines
13 KiB
Python
409 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langchain.tools import ToolRuntime
|
|
from langchain_core.messages import ToolMessage
|
|
from langgraph.prebuilt.tool_node import ToolCallRequest
|
|
from langgraph.runtime import Runtime
|
|
from langgraph.types import Command
|
|
|
|
from deerflow.sandbox.middleware import SandboxMiddleware
|
|
from deerflow.sandbox.sandbox import Sandbox
|
|
from deerflow.sandbox.sandbox_provider import SandboxProvider, reset_sandbox_provider, set_sandbox_provider
|
|
from deerflow.sandbox.search import GrepMatch
|
|
from deerflow.sandbox.tools import ls_tool
|
|
|
|
|
|
class _SyncProvider(SandboxProvider):
    """Provider test double that implements only the synchronous ``acquire``.

    Each ``thread_id`` handed to ``acquire`` is remembered so tests can assert
    on the exact sequence of calls.
    """

    def __init__(self) -> None:
        # Thread ids seen by acquire(), in call order.
        self.thread_ids: list[str | None] = []

    def acquire(self, thread_id: str | None = None) -> str:
        """Record *thread_id* and return the fixed id ``"sync-sandbox"``."""
        self.thread_ids += [thread_id]
        return "sync-sandbox"

    def get(self, sandbox_id: str) -> Sandbox | None:
        """Never resolve a sandbox; the result is always ``None``."""
        return None

    def release(self, sandbox_id: str) -> None:
        """Do nothing on release."""
|
|
|
|
|
|
class _SandboxStub(Sandbox):
    """Sandbox test double whose operations all return canned values."""

    def execute_command(self, command: str) -> str:
        """Return ``"OK"`` regardless of *command*."""
        return "OK"

    def read_file(self, path: str) -> str:
        """Return the fixed text ``"content"``."""
        return "content"

    def download_file(self, path: str) -> bytes:
        """Return the fixed payload ``b"content"``."""
        return b"content"

    def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
        """Return a one-entry listing, independent of the arguments."""
        return ["/mnt/user-data/workspace/file.txt"]

    def write_file(self, path: str, content: str, append: bool = False) -> None:
        """Discard the write."""

    def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
        """Return no matches, with ``False`` as the second tuple element."""
        return ([], False)

    def grep(
        self,
        path: str,
        pattern: str,
        *,
        glob: str | None = None,
        literal: bool = False,
        case_sensitive: bool = False,
        max_results: int = 100,
    ) -> tuple[list[GrepMatch], bool]:
        """Return no matches, with ``False`` as the second tuple element."""
        return ([], False)

    def update_file(self, path: str, content: bytes) -> None:
        """Discard the update."""
|
|
|
|
|
|
class _AsyncOnlyProvider(SandboxProvider):
    """Provider test double that only tolerates the async acquisition path.

    Calling the synchronous ``acquire`` raises ``AssertionError``, so any test
    using this provider fails loudly if the code under test falls back to it.
    """

    def __init__(self) -> None:
        # Thread ids seen by acquire_async(), in call order.
        self.thread_ids: list[str | None] = []
        # Sandbox ids passed to release(), in call order.
        self.released_ids: list[str] = []
        # The single sandbox this provider can resolve via get().
        self.sandbox = _SandboxStub("async-sandbox")

    def acquire(self, thread_id: str | None = None) -> str:
        """Always fail: the sync path must not be used with this provider."""
        raise AssertionError("async middleware should not call sync acquire")

    async def acquire_async(self, thread_id: str | None = None) -> str:
        """Record *thread_id* and return the fixed id ``"async-sandbox"``."""
        self.thread_ids += [thread_id]
        return "async-sandbox"

    def get(self, sandbox_id: str) -> Sandbox | None:
        """Return the stub sandbox for ``"async-sandbox"``; ``None`` otherwise."""
        return self.sandbox if sandbox_id == "async-sandbox" else None

    def release(self, sandbox_id: str) -> None:
        """Record *sandbox_id* as released."""
        self.released_ids += [sandbox_id]
|
|
|
|
|
|
@pytest.mark.anyio
async def test_provider_default_acquire_async_offloads_sync_acquire(monkeypatch: pytest.MonkeyPatch) -> None:
    """``acquire_async`` on a sync-only provider must go through ``asyncio.to_thread``.

    ``asyncio.to_thread`` is replaced with a recorder that runs the callable
    inline, so the test can check both the result and what was offloaded.
    """
    offloaded: list[tuple[object, tuple[object, ...]]] = []
    provider = _SyncProvider()

    async def fake_to_thread(func, /, *args):
        # Keep a trace of the offloaded call, then execute it synchronously.
        offloaded.append((func, args))
        return func(*args)

    monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)

    acquired_id = await provider.acquire_async("thread-1")

    assert acquired_id == "sync-sandbox"
    assert provider.thread_ids == ["thread-1"]
    assert offloaded == [(provider.acquire, ("thread-1",))]
|
|
|
|
|
|
@pytest.mark.anyio
async def test_abefore_agent_uses_async_provider_acquire() -> None:
    """With ``lazy_init=False``, ``abefore_agent`` must acquire via ``acquire_async``.

    The installed provider raises on sync ``acquire``, so reaching the
    assertions already shows the async path was taken.
    """
    provider = _AsyncOnlyProvider()
    agent_runtime = Runtime(context={"thread_id": "thread-2"})

    set_sandbox_provider(provider)
    try:
        eager_middleware = SandboxMiddleware(lazy_init=False)
        state_update = await eager_middleware.abefore_agent({}, agent_runtime)
    finally:
        # Always restore the global provider, even if the call above fails.
        reset_sandbox_provider()

    assert state_update == {"sandbox": {"sandbox_id": "async-sandbox"}}
    assert provider.thread_ids == ["thread-2"]
|
|
|
|
|
|
@pytest.mark.anyio
@pytest.mark.parametrize(
    ("middleware", "state", "runtime"),
    [
        # lazy_init=True, empty state, thread id present in the context.
        (SandboxMiddleware(lazy_init=True), {}, Runtime(context={"thread_id": "thread-lazy"})),
        # lazy_init=False, empty state, no thread id in the context.
        (SandboxMiddleware(lazy_init=False), {}, Runtime(context={})),
        # lazy_init=False, but the state already carries a sandbox entry.
        (SandboxMiddleware(lazy_init=False), {"sandbox": {"sandbox_id": "existing"}}, Runtime(context={"thread_id": "thread-existing"})),
    ],
)
async def test_abefore_agent_delegates_to_super_when_not_acquiring(
    monkeypatch: pytest.MonkeyPatch,
    middleware: SandboxMiddleware,
    state: dict,
    runtime: Runtime,
) -> None:
    """For each parametrized case, ``abefore_agent`` must defer to ``AgentMiddleware``.

    The base-class hook is patched with a recorder; the test expects its return
    value to be passed through and the original arguments to be forwarded.
    """
    delegated_calls: list[tuple[dict, Runtime]] = []

    async def recording_abefore_agent(self, state_arg, runtime_arg):
        delegated_calls.append((state_arg, runtime_arg))
        return {"delegated": True}

    monkeypatch.setattr(AgentMiddleware, "abefore_agent", recording_abefore_agent)

    outcome = await middleware.abefore_agent(state, runtime)

    assert outcome == {"delegated": True}
    assert delegated_calls == [(state, runtime)]
|
|
|
|
|
|
@pytest.mark.anyio
async def test_default_lazy_tool_acquisition_uses_async_provider() -> None:
    """Invoking ``ls_tool`` asynchronously with empty state must acquire via ``acquire_async``.

    Afterwards the sandbox id is expected in both ``runtime.state`` and
    ``runtime.context``, and the tool output must be the stub's directory listing.
    """
    provider = _AsyncOnlyProvider()
    tool_runtime = ToolRuntime(
        state={},
        context={"thread_id": "thread-lazy"},
        config={"configurable": {}},
        stream_writer=lambda _: None,
        tools=[],
        tool_call_id="call-1",
        store=None,
    )
    tool_input = {
        "runtime": tool_runtime,
        "description": "list workspace",
        "path": "/mnt/user-data/workspace",
    }

    set_sandbox_provider(provider)
    try:
        listing = await ls_tool.ainvoke(tool_input)
    finally:
        # Always restore the global provider, even if the tool call fails.
        reset_sandbox_provider()

    assert listing == "/mnt/user-data/workspace/file.txt"
    assert provider.thread_ids == ["thread-lazy"]
    assert tool_runtime.state["sandbox"] == {"sandbox_id": "async-sandbox"}
    assert tool_runtime.context["sandbox_id"] == "async-sandbox"
|
|
|
|
|
|
@pytest.mark.anyio
@pytest.mark.parametrize(
    ("state", "runtime", "expected_sandbox_id"),
    [
        # Sandbox id available only in the state.
        ({"sandbox": {"sandbox_id": "state-sandbox"}}, Runtime(context={}), "state-sandbox"),
        # Sandbox id available only in the runtime context.
        ({}, Runtime(context={"sandbox_id": "context-sandbox"}), "context-sandbox"),
    ],
)
async def test_aafter_agent_releases_sandbox_off_thread(
    monkeypatch: pytest.MonkeyPatch,
    state: dict,
    runtime: Runtime,
    expected_sandbox_id: str,
) -> None:
    """``aafter_agent`` must release the sandbox through ``asyncio.to_thread``.

    ``asyncio.to_thread`` is replaced with a recorder that runs the callable
    inline, so both the release itself and the offloading can be asserted.
    """
    offloaded: list[tuple[object, tuple[object, ...]]] = []
    provider = _AsyncOnlyProvider()

    async def fake_to_thread(func, /, *args):
        # Keep a trace of the offloaded call, then execute it synchronously.
        offloaded.append((func, args))
        return func(*args)

    monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)
    set_sandbox_provider(provider)
    try:
        outcome = await SandboxMiddleware().aafter_agent(state, runtime)
    finally:
        # Always restore the global provider, even if the hook fails.
        reset_sandbox_provider()

    assert outcome is None
    assert provider.released_ids == [expected_sandbox_id]
    assert offloaded == [(provider.release, (expected_sandbox_id,))]
|
|
|
|
|
|
@pytest.mark.anyio
async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no sandbox id in state or context, ``aafter_agent`` must defer to ``AgentMiddleware``.

    The base-class hook is patched with a recorder; the test expects its return
    value to be passed through and the original arguments to be forwarded.
    """
    delegated_calls: list[tuple[dict, Runtime]] = []

    async def recording_aafter_agent(self, state_arg, runtime_arg):
        delegated_calls.append((state_arg, runtime_arg))
        return {"delegated": True}

    monkeypatch.setattr(AgentMiddleware, "aafter_agent", recording_aafter_agent)

    empty_state = {}
    empty_runtime = Runtime(context={})
    outcome = await SandboxMiddleware().aafter_agent(empty_state, empty_runtime)

    assert outcome == {"delegated": True}
    assert delegated_calls == [(empty_state, empty_runtime)]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# wrap_tool_call / awrap_tool_call: persistent sandbox state via Command
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_tool_call_request(state: dict) -> ToolCallRequest:
    """Create a bare-bones ``ToolCallRequest`` whose runtime is an actual ``ToolRuntime``.

    The same *state* object is passed as both the runtime's state and the
    request's state. The tool call is a ``bash`` call with id ``"call-1"`` and
    no arguments; no tool object is attached.
    """
    call_id = "call-1"
    return ToolCallRequest(
        tool_call={"id": call_id, "name": "bash", "args": {}},
        tool=None,
        state=state,
        runtime=ToolRuntime(
            state=state,
            context={},
            config={"configurable": {}},
            stream_writer=lambda _: None,
            tools=[],
            tool_call_id=call_id,
            store=None,
        ),
    )
|
|
|
|
|
|
def test_wrap_tool_call_emits_command_when_lazy_init_happens() -> None:
    """A ``ToolMessage`` result is wrapped into a ``Command`` when the handler adds a sandbox.

    The handler writes a ``sandbox`` entry into ``runtime.state`` (starting from
    an empty state) and the test expects that entry, plus the original message,
    in the returned ``Command.update``.
    """
    request = _make_tool_call_request({})

    def handler(req: ToolCallRequest) -> ToolMessage:
        # Stand-in for ensure_sandbox_initialized() writing into runtime.state in place.
        req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
        return ToolMessage(content="ok", tool_call_id="call-1", name="bash")

    outcome = SandboxMiddleware().wrap_tool_call(request, handler)

    assert isinstance(outcome, Command)
    assert isinstance(outcome.update, dict)
    assert outcome.update["sandbox"] == {"sandbox_id": "new-sandbox"}

    wrapped_messages = outcome.update["messages"]
    assert len(wrapped_messages) == 1
    tool_message = wrapped_messages[0]
    assert tool_message.content == "ok"
    assert tool_message.tool_call_id == "call-1"
|
|
|
|
|
|
def test_wrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
    """If the state already holds a sandbox, the handler's result is returned as-is.

    Identity (``is``) is asserted, so any re-wrapping of the message would fail the test.
    """
    request = _make_tool_call_request({"sandbox": {"sandbox_id": "existing"}})
    handler_result = ToolMessage(content="ok", tool_call_id="call-1", name="bash")

    def handler(req: ToolCallRequest) -> ToolMessage:
        # Leaves runtime.state untouched.
        return handler_result

    outcome = SandboxMiddleware().wrap_tool_call(request, handler)

    assert outcome is handler_result
|
|
|
|
|
|
def test_wrap_tool_call_passthrough_when_handler_did_not_initialize_sandbox() -> None:
    """If the handler never adds a sandbox to an empty state, its result is returned as-is.

    Identity (``is``) is asserted, so any re-wrapping of the message would fail the test.
    """
    request = _make_tool_call_request({})
    handler_result = ToolMessage(content="ok", tool_call_id="call-1", name="bash")

    def handler(req: ToolCallRequest) -> ToolMessage:
        # Leaves runtime.state untouched: no sandbox is created.
        return handler_result

    outcome = SandboxMiddleware().wrap_tool_call(request, handler)

    assert outcome is handler_result
|
|
|
|
|
|
def test_wrap_tool_call_merges_with_existing_command_update() -> None:
    """A ``Command`` with a dict update gains the ``sandbox`` key and keeps everything else.

    The handler adds a sandbox to ``runtime.state`` and returns a ``Command``
    carrying ``messages``, ``viewed_images`` and a ``goto`` target; all of these
    must survive alongside the new ``sandbox`` entry.
    """
    request = _make_tool_call_request({})
    bash_message = ToolMessage(content="ok", tool_call_id="call-1", name="bash")

    def handler(req: ToolCallRequest) -> Command:
        req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
        handler_update = {
            "messages": [bash_message],
            "viewed_images": {"a.png": {"base64": "x", "mime_type": "image/png"}},
        }
        return Command(update=handler_update, goto="next-node")

    outcome = SandboxMiddleware().wrap_tool_call(request, handler)

    assert isinstance(outcome, Command)
    assert outcome.goto == "next-node"
    assert isinstance(outcome.update, dict)

    merged = outcome.update
    assert merged["messages"] == [bash_message]
    # Compared against a fresh literal, not the dict the handler passed in.
    assert merged["viewed_images"] == {"a.png": {"base64": "x", "mime_type": "image/png"}}
    assert merged["sandbox"] == {"sandbox_id": "new-sandbox"}
|
|
|
|
|
|
def test_wrap_tool_call_does_not_override_non_dict_update() -> None:
    """A ``Command`` whose update is not a dict is returned unchanged.

    The handler still adds a sandbox to ``runtime.state``, but because the
    update is a list of tuples the very same ``Command`` object must come back.
    """
    request = _make_tool_call_request({})
    list_update = [("messages", [ToolMessage(content="x", tool_call_id="c", name="bash")])]
    handler_command = Command(update=list_update)

    def handler(req: ToolCallRequest) -> Command:
        req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
        return handler_command

    outcome = SandboxMiddleware().wrap_tool_call(request, handler)

    # Leaving unknown update shapes alone avoids silently losing data.
    assert outcome is handler_command
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_tool_call_emits_command_when_lazy_init_happens() -> None:
    """Async twin of the sync lazy-init test for ``awrap_tool_call``.

    The handler writes a ``sandbox`` entry into ``runtime.state`` (starting from
    an empty state) and returns a ``ToolMessage``. The test expects a ``Command``
    whose dict update carries that sandbox entry and exactly one message with
    the original content and ``tool_call_id``.
    """
    middleware = SandboxMiddleware()
    state: dict = {}
    request = _make_tool_call_request(state)

    async def handler(req: ToolCallRequest) -> ToolMessage:
        # Stand-in for ensure_sandbox_initialized() writing into runtime.state in place.
        req.runtime.state["sandbox"] = {"sandbox_id": "async-new"}
        return ToolMessage(content="ok", tool_call_id="call-1", name="bash")

    result = await middleware.awrap_tool_call(request, handler)

    assert isinstance(result, Command)
    assert isinstance(result.update, dict)
    assert result.update["sandbox"] == {"sandbox_id": "async-new"}
    messages = result.update["messages"]
    assert len(messages) == 1
    assert messages[0].content == "ok"
    # Same check as the sync test: the wrapped message must keep its tool_call_id.
    assert messages[0].tool_call_id == "call-1"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_awrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
    """Async path: if the state already holds a sandbox, the handler's result is returned as-is.

    Identity (``is``) is asserted, so any re-wrapping of the message would fail the test.
    """
    request = _make_tool_call_request({"sandbox": {"sandbox_id": "existing"}})
    handler_result = ToolMessage(content="ok", tool_call_id="call-1", name="bash")

    async def handler(req: ToolCallRequest) -> ToolMessage:
        # Leaves runtime.state untouched.
        return handler_result

    outcome = await SandboxMiddleware().awrap_tool_call(request, handler)

    assert outcome is handler_result
|
|
|
|
|
|
def test_wrap_tool_call_preserves_existing_command_fields_when_merging() -> None:
    """Regression: merging the sandbox update must not drop other ``Command`` fields.

    The handler returns a ``Command`` with non-``None`` ``graph``, ``goto`` and
    ``resume`` plus a dict update; after the merge all three fields and the
    pre-existing update key must still be present.
    """
    request = _make_tool_call_request({})

    def handler(req: ToolCallRequest) -> Command:
        req.runtime.state["sandbox"] = {"sandbox_id": "sbx-merge"}
        return Command(
            update={"existing_key": "existing_value"},
            graph="parent",
            goto="next_node",
            resume="resume-token",
        )

    outcome = SandboxMiddleware().wrap_tool_call(request, handler)

    assert isinstance(outcome, Command)
    expected_update = {
        "existing_key": "existing_value",
        "sandbox": {"sandbox_id": "sbx-merge"},
    }
    assert outcome.update == expected_update
    # The point of the regression test: the non-update fields must survive the merge.
    assert outcome.graph == "parent"
    assert outcome.goto == "next_node"
    assert outcome.resume == "resume-token"
|