mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-17 13:05:58 +00:00
f212da9f89
* fix(sandbox): create shell session before retrying on a fresh id The AIO sandbox recovery path generated a UUID and passed it straight to exec_command(id=...). The sandbox image only auto-creates a session when exec_command is called with *no* id; an exec carrying an unknown id returns HTTP 404 "Session not found". So every ErrorObservation recovery itself 404'd, turning a transient session lapse into an unrecoverable tool error that looped the run up to the LangGraph recursion limit. Explicitly create_session(id=fresh_id) before targeting that id on retry. create_session is idempotent (returns the existing session if the id already exists), so this is safe under the serializing lock. Updated the regression test to assert the retry targets exactly the created session id rather than a fabricated, uncreated one. * fix(sandbox): release the one-shot recovery session after retry The fresh session created on the ErrorObservation recovery path is used for exactly one command -- the next execute_command runs with no id and returns to the default session. Under persistent session corruption every command would create another session that is never reused or released, accumulating sessions on the container. Release it best-effort with cleanup_session() in a finally, swallowing any cleanup error so it never masks a successful retry. Addresses review feedback on #3577. --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
350 lines
15 KiB
Python
350 lines
15 KiB
Python
import base64
|
|
import errno
|
|
import logging
|
|
import shlex
|
|
import threading
|
|
import uuid
|
|
|
|
from agent_sandbox import Sandbox as AioSandboxClient
|
|
|
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
|
from deerflow.sandbox.sandbox import Sandbox
|
|
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Upper bound on the number of bytes download_file() will buffer before it
# aborts the transfer with EFBIG.
_MAX_DOWNLOAD_SIZE = 100 * 1024**2  # 100 MB

# Text that shows up in a command's output when the sandbox's shell session is
# corrupted; execute_command() looks for it to decide whether to retry on a
# fresh session.
_ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'"
|
|
|
|
class AioSandbox(Sandbox):
    """Sandbox implementation using the agent-infra/sandbox Docker container.

    This sandbox connects to a running AIO sandbox container via HTTP API.
    A threading lock serializes shell commands to prevent concurrent requests
    from corrupting the container's single persistent session (see #1433).
    """

    def __init__(self, id: str, base_url: str, home_dir: str | None = None):
        """Initialize the AIO sandbox.

        Args:
            id: Unique identifier for this sandbox instance.
            base_url: URL of the sandbox API (e.g., http://localhost:8080).
            home_dir: Home directory inside the sandbox. If None, will be fetched from the sandbox.
        """
        super().__init__(id)
        self._base_url = base_url
        self._client = AioSandboxClient(base_url=base_url, timeout=600)
        self._home_dir = home_dir
        # Guards the shell session, the locked file operations and close().
        self._lock = threading.Lock()
        self._closed = False

    @property
    def base_url(self) -> str:
        """URL of the sandbox API this instance was created with."""
        return self._base_url

    def close(self) -> None:
        """Best-effort close of the host-side HTTP client owned by this sandbox.

        The agent_sandbox SDK is Fern-generated and exposes no ``close()`` /
        ``__exit__``, so we reach the socket-owning ``httpx.Client`` explicitly
        through its attribute chain::

            Sandbox._client_wrapper   -> SyncClientWrapper
              .httpx_client           -> Fern HttpClient (a wrapper, NOT httpx.Client)
                .httpx_client         -> httpx.Client   <- the real socket owner

        Closing it releases pooled sockets so long-running provider lifecycles
        do not accumulate unreclaimed host-side resources (#2872).

        Resolution is most-specific-first with graceful degradation: if a future
        SDK adds a top-level ``Sandbox.close()`` it is picked up automatically
        without changing this code. Idempotent, thread-safe, and non-fatal:
        failures during teardown are logged and swallowed so provider/backend
        cleanup is never blocked.
        """
        with self._lock:
            if self._closed:
                return
            self._closed = True
            client = self._client
            # Drop the reference under the lock for use-after-close safety: any
            # later command on this instance fails loudly instead of reusing a
            # half-closed client.
            self._client = None

        if client is None:
            return

        # Walk from the real httpx.Client up to the top-level client, picking the
        # first object that actually exposes close().
        wrapper = getattr(client, "_client_wrapper", None)
        fern_http = getattr(wrapper, "httpx_client", None)
        real_httpx = getattr(fern_http, "httpx_client", None)
        target = next(
            (c for c in (real_httpx, fern_http, client) if c is not None and hasattr(c, "close")),
            None,
        )
        if target is None:
            logger.debug("AioSandbox %s: no closable client found, nothing to release", self.id)
            return

        try:
            target.close()
        except Exception as e:
            logger.warning(f"Error closing AioSandbox client for {self.id}: {e}")

    @property
    def home_dir(self) -> str:
        """Get the home directory inside the sandbox.

        The value is fetched from the sandbox context on first access (when it
        was not supplied to the constructor) and cached afterwards.
        """
        if self._home_dir is None:
            context = self._client.sandbox.get_context()
            self._home_dir = context.home_dir
        return self._home_dir

    # Default no_change_timeout for exec_command (seconds). Matches the
    # client-level timeout so that long-running commands which produce no
    # output are not prematurely terminated by the sandbox's built-in 120 s
    # default.
    _DEFAULT_NO_CHANGE_TIMEOUT = 600

    def execute_command(self, command: str) -> str:
        """Execute a shell command in the sandbox.

        Uses a lock to serialize concurrent requests. The AIO sandbox
        container maintains a single persistent shell session that
        corrupts when hit with concurrent exec_command calls (returns
        ``ErrorObservation`` instead of real output). If corruption is
        detected despite the lock (e.g. multiple processes sharing a
        sandbox), the command is retried on a fresh session.

        Args:
            command: The command to execute.

        Returns:
            The output of the command, ``"(no output)"`` when the command
            produced none, or ``"Error: <message>"`` when the call failed.
        """
        with self._lock:
            try:
                result = self._client.shell.exec_command(command=command, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
                output = result.data.output if result.data else ""

                if output and _ERROR_OBSERVATION_SIGNATURE in output:
                    logger.warning("ErrorObservation detected in sandbox output, retrying on a fresh session")
                    # exec_command only auto-creates a session when called with
                    # no id, so the recovery session must be created explicitly
                    # before we target it on retry.
                    fresh_id = str(uuid.uuid4())
                    self._client.shell.create_session(id=fresh_id)
                    try:
                        result = self._client.shell.exec_command(command=command, id=fresh_id, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
                        output = result.data.output if result.data else ""
                    finally:
                        # Release the one-shot recovery session, best-effort, so
                        # repeated corruption can't accumulate sessions.
                        try:
                            self._client.shell.cleanup_session(fresh_id)
                        except Exception as cleanup_error:
                            logger.warning(f"Failed to release recovery session {fresh_id}: {cleanup_error}")

                return output if output else "(no output)"
            except Exception as e:
                logger.error(f"Failed to execute command in sandbox: {e}")
                return f"Error: {e}"

    def read_file(self, path: str) -> str:
        """Read the content of a file in the sandbox.

        Args:
            path: The absolute path of the file to read.

        Returns:
            The content of the file, or ``"Error: <message>"`` when the read
            failed. Callers cannot tell a failed read from a file whose text
            really starts with ``"Error:"``, so code that needs the distinction
            should not rely on this return value.
        """
        try:
            result = self._client.file.read_file(file=path)
            return result.data.content if result.data else ""
        except Exception as e:
            logger.error(f"Failed to read file in sandbox: {e}")
            return f"Error: {e}"

    def download_file(self, path: str) -> bytes:
        """Download file bytes from the sandbox.

        Args:
            path: Path of the file inside the sandbox; must be under
                ``VIRTUAL_PATH_PREFIX``.

        Returns:
            The complete file content.

        Raises:
            PermissionError: If the path contains '..' traversal segments or is
                outside ``VIRTUAL_PATH_PREFIX``.
            OSError: If the file cannot be retrieved from the sandbox, or it is
                larger than ``_MAX_DOWNLOAD_SIZE`` (``errno.EFBIG``).
        """
        # Reject path traversal before sending to the container API.
        # LocalSandbox gets this implicitly via _resolve_path;
        # here the path is forwarded verbatim so we must check explicitly.
        normalised = path.replace("\\", "/")
        for segment in normalised.split("/"):
            if segment == "..":
                logger.error(f"Refused download due to path traversal: {path}")
                raise PermissionError(f"Access denied: path traversal detected in '{path}'")

        stripped_path = normalised.lstrip("/")
        allowed_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
        if stripped_path != allowed_prefix and not stripped_path.startswith(f"{allowed_prefix}/"):
            logger.error("Refused download outside allowed directory: path=%s, allowed_prefix=%s", path, VIRTUAL_PATH_PREFIX)
            raise PermissionError(f"Access denied: path must be under '{VIRTUAL_PATH_PREFIX}': '{path}'")

        with self._lock:
            try:
                chunks: list[bytes] = []
                total = 0
                for chunk in self._client.file.download_file(path=path):
                    total += len(chunk)
                    # Enforce the cap while streaming so an oversized file is
                    # never fully buffered in memory.
                    if total > _MAX_DOWNLOAD_SIZE:
                        raise OSError(
                            errno.EFBIG,
                            f"File exceeds maximum download size of {_MAX_DOWNLOAD_SIZE} bytes",
                            path,
                        )
                    chunks.append(chunk)
                return b"".join(chunks)
            except OSError:
                # Already the documented exception type (includes the EFBIG above).
                raise
            except Exception as e:
                logger.error(f"Failed to download file in sandbox: {e}")
                raise OSError(f"Failed to download file '{path}' from sandbox: {e}") from e

    def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
        """List the contents of a directory in the sandbox.

        Runs ``find`` inside the sandbox shell and returns at most 500 entries.

        Args:
            path: The absolute path of the directory to list.
            max_depth: The maximum depth to traverse. Default is 2.

        Returns:
            The files and directories found, one path per element. An empty
            list is returned when nothing is found or the listing fails.
        """
        with self._lock:
            try:
                # path is shell-quoted below; max_depth is interpolated bare, so
                # force it to an integer to keep arbitrary text out of the shell
                # command. A value that is not an integer fails here and is
                # reported like any other listing error.
                depth = int(max_depth)
                result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {depth} -type f -o -type d 2>/dev/null | head -500", no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
                output = result.data.output if result.data else ""
                if output:
                    return [line.strip() for line in output.strip().split("\n") if line.strip()]
                return []
            except Exception as e:
                logger.error(f"Failed to list directory in sandbox: {e}")
                return []

    def _read_existing_for_append(self, path: str) -> str:
        """Return the current text of ``path`` for an append, or "" if it cannot be read."""
        try:
            result = self._client.file.read_file(file=path)
        except Exception as e:
            # Same outcome as before for an unreadable file (typically one that
            # does not exist yet): nothing is prepended to the new content.
            logger.error(f"Failed to read file in sandbox: {e}")
            return ""
        existing = result.data.content if result.data else ""
        return existing or ""

    def write_file(self, path: str, content: str, append: bool = False) -> None:
        """Write content to a file in the sandbox.

        Appending is implemented as read-modify-write: the current content is
        fetched, the new content is concatenated to it, and the result is
        written back in one call.

        Args:
            path: The absolute path of the file to write to.
            content: The text content to write to the file.
            append: Whether to append the content to the file.

        Raises:
            Exception: Whatever the sandbox client raises when the write fails;
                the error is logged and re-raised unchanged.
        """
        with self._lock:
            try:
                if append:
                    # Read through the client rather than self.read_file():
                    # read_file() reports failure as an "Error: ..." string,
                    # which is indistinguishable from a file whose content
                    # really starts with "Error:" and would cause that content
                    # to be dropped instead of appended to.
                    content = self._read_existing_for_append(path) + content
                self._client.file.write_file(file=path, content=content)
            except Exception as e:
                logger.error(f"Failed to write file in sandbox: {e}")
                raise

    def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
        """Find paths under ``path`` that match a glob ``pattern``.

        Args:
            path: Root directory to search.
            pattern: Glob pattern. With ``include_dirs=False`` it is passed to
                the sandbox's file finder; with ``include_dirs=True`` it is
                matched against each entry's path relative to ``path``.
            include_dirs: Whether directories may appear in the result.
            max_results: Maximum number of paths to return.

        Returns:
            A ``(paths, truncated)`` tuple. ``truncated`` is True when the
            result was cut off at ``max_results``.
        """
        if not include_dirs:
            result = self._client.file.find_files(path=path, glob=pattern)
            files = result.data.files if result.data and result.data.files else []
            filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
            truncated = len(filtered) > max_results
            return filtered[:max_results], truncated

        result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
        entries = result.data.files if result.data and result.data.files else []
        matches: list[str] = []
        root_path = path.rstrip("/") or "/"
        root_prefix = root_path if root_path == "/" else f"{root_path}/"
        for entry in entries:
            # Keep only the root itself and entries strictly below it.
            if entry.path != root_path and not entry.path.startswith(root_prefix):
                continue
            if should_ignore_path(entry.path):
                continue
            rel_path = entry.path[len(root_path) :].lstrip("/")
            if path_matches(pattern, rel_path):
                matches.append(entry.path)
                if len(matches) >= max_results:
                    return matches, True
        return matches, False

    def grep(
        self,
        path: str,
        pattern: str,
        *,
        glob: str | None = None,
        literal: bool = False,
        case_sensitive: bool = False,
        max_results: int = 100,
    ) -> tuple[list[GrepMatch], bool]:
        """Search files under ``path`` for lines matching ``pattern``.

        Args:
            path: Root directory to search.
            pattern: Regular expression, or literal text when ``literal`` is True.
            glob: Optional glob restricting which files are searched. When
                omitted, every non-directory entry under ``path`` is searched.
            literal: Treat ``pattern`` as plain text (it is regex-escaped).
            case_sensitive: Match case-sensitively. Default is case-insensitive.
            max_results: Maximum number of matches to return.

        Returns:
            A ``(matches, truncated)`` tuple. ``truncated`` is True when the
            search stopped because ``max_results`` was reached.

        Raises:
            re.error: If ``pattern`` is not a valid regular expression.
        """
        import re as _re

        regex_source = _re.escape(pattern) if literal else pattern
        # Validate the pattern locally so an invalid regex raises re.error
        # (caught by grep_tool's except re.error handler) rather than a
        # generic remote API error.
        _re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
        regex = regex_source if case_sensitive else f"(?i){regex_source}"

        if glob is not None:
            find_result = self._client.file.find_files(path=path, glob=glob)
            candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
        else:
            list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
            entries = list_result.data.files if list_result.data and list_result.data.files else []
            candidate_paths = [entry.path for entry in entries if not entry.is_directory]

        matches: list[GrepMatch] = []
        truncated = False

        for file_path in candidate_paths:
            if should_ignore_path(file_path):
                continue

            search_result = self._client.file.search_in_file(file=file_path, regex=regex)
            data = search_result.data
            if data is None:
                continue

            line_numbers = data.line_numbers or []
            matched_lines = data.matches or []
            for line_number, line in zip(line_numbers, matched_lines):
                matches.append(
                    GrepMatch(
                        path=file_path,
                        # Non-integer line numbers from the API are reported as 0.
                        line_number=line_number if isinstance(line_number, int) else 0,
                        line=truncate_line(line),
                    )
                )
                if len(matches) >= max_results:
                    truncated = True
                    return matches, truncated

        return matches, truncated

    def update_file(self, path: str, content: bytes) -> None:
        """Update a file with binary content in the sandbox.

        The bytes are sent base64-encoded, with ``encoding="base64"`` passed to
        the sandbox's write call.

        Args:
            path: The absolute path of the file to update.
            content: The binary content to write to the file.

        Raises:
            Exception: Whatever the sandbox client raises when the write fails;
                the error is logged and re-raised unchanged.
        """
        with self._lock:
            try:
                base64_content = base64.b64encode(content).decode("utf-8")
                self._client.file.write_file(file=path, content=base64_content, encoding="base64")
            except Exception as e:
                logger.error(f"Failed to update file in sandbox: {e}")
                raise