mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-12 02:15:58 +00:00
Merge remote-tracking branch 'origin/main' into codex/im-channel-connections
This commit is contained in:
@@ -69,6 +69,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -79,6 +80,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -89,6 +91,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -99,6 +102,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -109,6 +113,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -119,6 +124,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.community.aio_sandbox.local_backend import (
|
||||
LocalContainerBackend,
|
||||
_format_container_command_for_log,
|
||||
@@ -234,3 +237,99 @@ def test_start_container_keeps_apple_container_port_format(monkeypatch):
|
||||
captured_cmd = _capture_start_container_command(monkeypatch, backend, runtime="container")
|
||||
|
||||
assert captured_cmd[captured_cmd.index("-p") + 1] == "18080:8080"
|
||||
|
||||
|
||||
def _backend_for_inspect_tests() -> LocalContainerBackend:
|
||||
backend = LocalContainerBackend(
|
||||
image="sandbox:latest",
|
||||
base_port=8080,
|
||||
container_prefix="sandbox",
|
||||
config_mounts=[],
|
||||
environment={},
|
||||
)
|
||||
backend._runtime = "docker"
|
||||
return backend
|
||||
|
||||
|
||||
def test_is_container_running_false_when_container_missing(monkeypatch):
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Error: No such object: sandbox-missing", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend._is_container_running("sandbox-missing") is False
|
||||
|
||||
|
||||
def test_is_container_running_raises_on_runtime_error(monkeypatch):
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"):
|
||||
backend._is_container_running("sandbox-busy")
|
||||
|
||||
|
||||
def test_is_container_running_raises_on_timeout(monkeypatch):
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"])
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Timed out checking container sandbox-timeout"):
|
||||
backend._is_container_running("sandbox-timeout")
|
||||
|
||||
|
||||
def test_discover_returns_none_when_runtime_check_fails(monkeypatch):
|
||||
"""A transient daemon error during discovery must fall through to create, not fail acquire."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend.discover("sandbox-blip") is None
|
||||
|
||||
|
||||
def test_discover_returns_none_when_runtime_check_times_out(monkeypatch):
|
||||
"""An inspect timeout during discovery must not propagate out of discover()."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"])
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend.discover("sandbox-timeout") is None
|
||||
|
||||
|
||||
def test_is_container_running_false_on_apple_container_not_found(monkeypatch):
|
||||
"""Apple Container's generic "not found" is trusted when it names the container."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr='Error: not found: "sandbox-apple"', returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend._is_container_running("sandbox-apple") is False
|
||||
|
||||
|
||||
def test_is_container_running_raises_on_unrelated_not_found_error(monkeypatch):
|
||||
"""Transient errors whose text contains "not found" must not be misread as a dead container."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Error: credential helper not found in $PATH", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"):
|
||||
backend._is_container_running("sandbox-busy")
|
||||
|
||||
@@ -317,6 +317,28 @@ async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path,
|
||||
pytest.fail("provider thread lock was not released after successor acquire_async")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_acquire_internal_async_offloads_cached_reuse_health_check(tmp_path, monkeypatch):
|
||||
"""Async cached reuse must keep backend health checks off the event loop."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider, _sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-cached-async")
|
||||
provider._thread_sandboxes = {"thread-cached-async": "sandbox-cached-async"}
|
||||
provider._backend.is_alive = MagicMock(return_value=True)
|
||||
|
||||
to_thread_calls: list[tuple[object, tuple[object, ...]]] = []
|
||||
|
||||
async def fake_to_thread(func, /, *args, **kwargs):
|
||||
to_thread_calls.append((func, args))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread)
|
||||
|
||||
sandbox_id = await provider._acquire_internal_async("thread-cached-async")
|
||||
|
||||
assert sandbox_id == "sandbox-cached-async"
|
||||
assert to_thread_calls == [(provider._reuse_in_process_sandbox, ("thread-cached-async",))]
|
||||
|
||||
|
||||
def test_remote_backend_create_forwards_effective_user_id(monkeypatch):
|
||||
"""Provisioner mode must receive user_id so PVC subPath matches user isolation."""
|
||||
remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend")
|
||||
@@ -424,6 +446,136 @@ def test_release_swallows_close_errors(tmp_path, caplog):
|
||||
assert "sandbox-rel-err" in provider._warm_pool
|
||||
|
||||
|
||||
def test_get_uses_in_memory_registry_only(tmp_path):
|
||||
"""get() must stay event-loop safe by avoiding backend health checks."""
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead")
|
||||
provider._backend.is_alive = MagicMock(side_effect=AssertionError("get must not call backend health checks"))
|
||||
|
||||
assert provider.get("sandbox-dead") is sandbox
|
||||
|
||||
|
||||
def test_acquire_drops_dead_cached_sandbox(tmp_path, monkeypatch):
|
||||
"""acquire() must replace a stale active cache entry after its container dies."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead")
|
||||
provider._thread_locks = {}
|
||||
provider._thread_sandboxes = {"thread-dead": "sandbox-dead"}
|
||||
provider._config = {"replicas": 3}
|
||||
provider._backend.is_alive = MagicMock(return_value=False)
|
||||
provider._backend.discover = MagicMock(return_value=None)
|
||||
provider._backend.create = MagicMock(
|
||||
return_value=aio_mod.SandboxInfo(
|
||||
sandbox_id="sandbox-dead",
|
||||
sandbox_url="http://fresh-sandbox",
|
||||
container_name="deer-flow-sandbox-sandbox-dead",
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-dead")
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: [])
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True)
|
||||
|
||||
sandbox_id = provider.acquire("thread-dead")
|
||||
|
||||
assert sandbox_id == "sandbox-dead"
|
||||
sandbox.close.assert_called_once_with()
|
||||
provider._backend.destroy.assert_called_once()
|
||||
provider._backend.create.assert_called_once()
|
||||
assert provider._thread_sandboxes["thread-dead"] == "sandbox-dead"
|
||||
assert provider._sandboxes["sandbox-dead"].base_url == "http://fresh-sandbox"
|
||||
|
||||
|
||||
def test_acquire_keeps_cached_sandbox_when_health_check_errors(tmp_path):
|
||||
"""Transient backend health-check errors must not destroy a tracked sandbox."""
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-transient")
|
||||
provider._thread_locks = {}
|
||||
provider._thread_sandboxes = {"thread-transient": "sandbox-transient"}
|
||||
provider._backend.is_alive = MagicMock(side_effect=OSError("docker daemon busy"))
|
||||
|
||||
sandbox_id = provider.acquire("thread-transient")
|
||||
|
||||
assert sandbox_id == "sandbox-transient"
|
||||
sandbox.close.assert_not_called()
|
||||
provider._backend.destroy.assert_not_called()
|
||||
assert provider._sandboxes["sandbox-transient"] is sandbox
|
||||
|
||||
|
||||
def test_drop_unhealthy_sandbox_skips_recreated_entry(tmp_path):
|
||||
"""A stale health-check result must not delete a newly registered sandbox."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider = _make_provider(tmp_path)
|
||||
provider._lock = aio_mod.threading.Lock()
|
||||
provider._warm_pool = {}
|
||||
provider._last_activity = {"sandbox-toctou": 1.0}
|
||||
provider._thread_sandboxes = {"thread-toctou": "sandbox-toctou"}
|
||||
old_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://old-sandbox")
|
||||
new_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://new-sandbox")
|
||||
new_sandbox = MagicMock()
|
||||
provider._sandbox_infos = {"sandbox-toctou": new_info}
|
||||
provider._sandboxes = {"sandbox-toctou": new_sandbox}
|
||||
provider._backend = SimpleNamespace(destroy=MagicMock())
|
||||
|
||||
provider._drop_unhealthy_sandbox("sandbox-toctou", "stale health check", expected_info=old_info)
|
||||
|
||||
new_sandbox.close.assert_not_called()
|
||||
provider._backend.destroy.assert_not_called()
|
||||
assert provider._sandbox_infos["sandbox-toctou"] is new_info
|
||||
assert provider._sandboxes["sandbox-toctou"] is new_sandbox
|
||||
assert provider._thread_sandboxes == {"thread-toctou": "sandbox-toctou"}
|
||||
|
||||
|
||||
def test_acquire_skips_dead_warm_pool_sandbox(tmp_path, monkeypatch):
|
||||
"""acquire() must create a fresh sandbox when the warm-pool entry died."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider = _make_provider(tmp_path)
|
||||
provider._lock = aio_mod.threading.Lock()
|
||||
provider._thread_locks = {}
|
||||
provider._sandboxes = {}
|
||||
provider._sandbox_infos = {}
|
||||
provider._thread_sandboxes = {}
|
||||
provider._last_activity = {}
|
||||
provider._warm_pool = {
|
||||
"sandbox-warm-dead": (
|
||||
aio_mod.SandboxInfo(
|
||||
sandbox_id="sandbox-warm-dead",
|
||||
sandbox_url="http://stale-sandbox",
|
||||
container_name="deer-flow-sandbox-sandbox-warm-dead",
|
||||
),
|
||||
0.0,
|
||||
)
|
||||
}
|
||||
provider._config = {"replicas": 3}
|
||||
provider._backend = SimpleNamespace(
|
||||
is_alive=MagicMock(return_value=False),
|
||||
destroy=MagicMock(),
|
||||
discover=MagicMock(return_value=None),
|
||||
create=MagicMock(
|
||||
return_value=aio_mod.SandboxInfo(
|
||||
sandbox_id="sandbox-warm-dead",
|
||||
sandbox_url="http://fresh-sandbox",
|
||||
container_name="deer-flow-sandbox-sandbox-warm-dead",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-warm-dead")
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: [])
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True)
|
||||
|
||||
sandbox_id = provider.acquire("thread-warm-dead")
|
||||
|
||||
assert sandbox_id == "sandbox-warm-dead"
|
||||
provider._backend.destroy.assert_called_once()
|
||||
provider._backend.create.assert_called_once()
|
||||
assert provider._warm_pool == {}
|
||||
assert provider._thread_sandboxes["thread-warm-dead"] == "sandbox-warm-dead"
|
||||
assert provider._sandboxes["sandbox-warm-dead"].base_url == "http://fresh-sandbox"
|
||||
|
||||
|
||||
def test_destroy_swallows_close_errors_and_still_destroys_backend(tmp_path, caplog):
|
||||
"""A failure in sandbox.close() must not skip backend container destruction."""
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dest-err")
|
||||
|
||||
@@ -257,14 +257,38 @@ def test_provisioner_is_alive_true_only_when_status_running(monkeypatch):
|
||||
assert backend._provisioner_is_alive("abc123") is False
|
||||
|
||||
|
||||
def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch):
|
||||
def test_provisioner_is_alive_returns_false_on_404(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
return _StubResponse(status_code=404)
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
assert backend._provisioner_is_alive("abc123") is False
|
||||
|
||||
|
||||
def test_provisioner_is_alive_raises_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
raise requests.RequestException("boom")
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
assert backend._provisioner_is_alive("abc123") is False
|
||||
with pytest.raises(RuntimeError, match="Provisioner health check failed for abc123"):
|
||||
backend._provisioner_is_alive("abc123")
|
||||
|
||||
|
||||
def test_provisioner_is_alive_raises_on_server_error(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
def mock_get(url: str, timeout: int):
|
||||
response = _StubResponse(status_code=503)
|
||||
response.text = "unavailable"
|
||||
return response
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
with pytest.raises(RuntimeError, match="HTTP 503 unavailable"):
|
||||
backend._provisioner_is_alive("abc123")
|
||||
|
||||
|
||||
def test_discover_delegates_to_provisioner_discover(monkeypatch):
|
||||
|
||||
@@ -5,7 +5,10 @@ import asyncio
|
||||
import pytest
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
@@ -223,3 +226,183 @@ async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pyte
|
||||
|
||||
assert result == {"delegated": True}
|
||||
assert calls == [(state, runtime)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wrap_tool_call / awrap_tool_call: persistent sandbox state via Command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tool_call_request(state: dict) -> ToolCallRequest:
|
||||
"""Build a minimal ToolCallRequest backed by a real ToolRuntime."""
|
||||
runtime = ToolRuntime(
|
||||
state=state,
|
||||
context={},
|
||||
config={"configurable": {}},
|
||||
stream_writer=lambda _: None,
|
||||
tools=[],
|
||||
tool_call_id="call-1",
|
||||
store=None,
|
||||
)
|
||||
return ToolCallRequest(
|
||||
tool_call={"id": "call-1", "name": "bash", "args": {}},
|
||||
tool=None,
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
|
||||
def test_wrap_tool_call_emits_command_when_lazy_init_happens() -> None:
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {}
|
||||
request = _make_tool_call_request(state)
|
||||
|
||||
def handler(req: ToolCallRequest) -> ToolMessage:
|
||||
# Simulate ensure_sandbox_initialized() mutating runtime.state in-place.
|
||||
req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
|
||||
return ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
assert isinstance(result.update, dict)
|
||||
assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"}
|
||||
messages = result.update["messages"]
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content == "ok"
|
||||
assert messages[0].tool_call_id == "call-1"
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {"sandbox": {"sandbox_id": "existing"}}
|
||||
request = _make_tool_call_request(state)
|
||||
original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
def handler(req: ToolCallRequest) -> ToolMessage:
|
||||
return original
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_when_handler_did_not_initialize_sandbox() -> None:
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {}
|
||||
request = _make_tool_call_request(state)
|
||||
original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
def handler(req: ToolCallRequest) -> ToolMessage:
|
||||
return original
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
def test_wrap_tool_call_merges_with_existing_command_update() -> None:
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {}
|
||||
request = _make_tool_call_request(state)
|
||||
tool_msg = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
def handler(req: ToolCallRequest) -> Command:
|
||||
req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
|
||||
return Command(
|
||||
update={
|
||||
"messages": [tool_msg],
|
||||
"viewed_images": {"a.png": {"base64": "x", "mime_type": "image/png"}},
|
||||
},
|
||||
goto="next-node",
|
||||
)
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
assert result.goto == "next-node"
|
||||
assert isinstance(result.update, dict)
|
||||
assert result.update["messages"] == [tool_msg]
|
||||
assert result.update["viewed_images"] == {"a.png": {"base64": "x", "mime_type": "image/png"}}
|
||||
assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"}
|
||||
|
||||
|
||||
def test_wrap_tool_call_does_not_override_non_dict_update() -> None:
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {}
|
||||
request = _make_tool_call_request(state)
|
||||
cmd = Command(update=[("messages", [ToolMessage(content="x", tool_call_id="c", name="bash")])])
|
||||
|
||||
def handler(req: ToolCallRequest) -> Command:
|
||||
req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
|
||||
return cmd
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
# Non-dict update is left untouched to avoid silent data loss.
|
||||
assert result is cmd
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_awrap_tool_call_emits_command_when_lazy_init_happens() -> None:
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {}
|
||||
request = _make_tool_call_request(state)
|
||||
|
||||
async def handler(req: ToolCallRequest) -> ToolMessage:
|
||||
req.runtime.state["sandbox"] = {"sandbox_id": "async-new"}
|
||||
return ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
result = await middleware.awrap_tool_call(request, handler)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
assert isinstance(result.update, dict)
|
||||
assert result.update["sandbox"] == {"sandbox_id": "async-new"}
|
||||
messages = result.update["messages"]
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content == "ok"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_awrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {"sandbox": {"sandbox_id": "existing"}}
|
||||
request = _make_tool_call_request(state)
|
||||
original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
async def handler(req: ToolCallRequest) -> ToolMessage:
|
||||
return original
|
||||
|
||||
result = await middleware.awrap_tool_call(request, handler)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
def test_wrap_tool_call_preserves_existing_command_fields_when_merging() -> None:
|
||||
"""Regression: when merging sandbox_update into an existing Command,
|
||||
all other Command fields (e.g. graph, goto, resume) must be preserved.
|
||||
"""
|
||||
middleware = SandboxMiddleware()
|
||||
state: dict = {}
|
||||
request = _make_tool_call_request(state)
|
||||
|
||||
def handler(req: ToolCallRequest) -> Command:
|
||||
req.runtime.state["sandbox"] = {"sandbox_id": "sbx-merge"}
|
||||
return Command(
|
||||
update={"existing_key": "existing_value"},
|
||||
graph="parent",
|
||||
goto="next_node",
|
||||
resume="resume-token",
|
||||
)
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
|
||||
assert isinstance(result, Command)
|
||||
assert result.update == {
|
||||
"existing_key": "existing_value",
|
||||
"sandbox": {"sandbox_id": "sbx-merge"},
|
||||
}
|
||||
# Critical: other Command fields must NOT be dropped by the merge.
|
||||
assert result.graph == "parent"
|
||||
assert result.goto == "next_node"
|
||||
assert result.resume == "resume-token"
|
||||
|
||||
Reference in New Issue
Block a user