mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-25 01:15:58 +00:00
Merge branch 'main' into release/2.0-rc
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Middleware for intercepting clarification requests and presenting them to the user."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
@@ -60,6 +61,20 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||||
context = args.get("context")
|
||||
options = args.get("options", [])
|
||||
|
||||
# Some models (e.g. Qwen3-Max) serialize array parameters as JSON strings
|
||||
# instead of native arrays. Deserialize and normalize so `options`
|
||||
# is always a list for the rendering logic below.
|
||||
if isinstance(options, str):
|
||||
try:
|
||||
options = json.loads(options)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
options = [options]
|
||||
|
||||
if options is None:
|
||||
options = []
|
||||
elif not isinstance(options, list):
|
||||
options = [options]
|
||||
|
||||
# Type-specific icons
|
||||
type_icons = {
|
||||
"missing_info": "❓",
|
||||
|
||||
@@ -33,30 +33,92 @@ _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
|
||||
|
||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||
"""Normalize tool call args to a dict plus an optional fallback key.
|
||||
|
||||
Some providers serialize ``args`` as a JSON string instead of a dict.
|
||||
We defensively parse those cases so loop detection does not crash while
|
||||
still preserving a stable fallback key for non-dict payloads.
|
||||
"""
|
||||
if isinstance(raw_args, dict):
|
||||
return raw_args, None
|
||||
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
return {}, raw_args
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed, None
|
||||
return {}, json.dumps(parsed, sort_keys=True, default=str)
|
||||
|
||||
if raw_args is None:
|
||||
return {}, None
|
||||
|
||||
return {}, json.dumps(raw_args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
|
||||
"""Derive a stable key from salient args without overfitting to noise."""
|
||||
if name == "read_file" and fallback_key is None:
|
||||
path = args.get("path") or ""
|
||||
start_line = args.get("start_line")
|
||||
end_line = args.get("end_line")
|
||||
|
||||
bucket_size = 200
|
||||
try:
|
||||
start_line = int(start_line) if start_line is not None else 1
|
||||
except (TypeError, ValueError):
|
||||
start_line = 1
|
||||
try:
|
||||
end_line = int(end_line) if end_line is not None else start_line
|
||||
except (TypeError, ValueError):
|
||||
end_line = start_line
|
||||
|
||||
start_line, end_line = sorted((start_line, end_line))
|
||||
bucket_start = max(start_line, 1)
|
||||
bucket_end = max(end_line, 1)
|
||||
bucket_start = (bucket_start - 1) // bucket_size
|
||||
bucket_end = (bucket_end - 1) // bucket_size
|
||||
return f"{path}:{bucket_start}-{bucket_end}"
|
||||
|
||||
# write_file / str_replace are content-sensitive: same path may be updated
|
||||
# with different payloads during iteration. Using only salient fields (path)
|
||||
# can collapse distinct calls, so we hash full args to reduce false positives.
|
||||
if name in {"write_file", "str_replace"}:
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
|
||||
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
|
||||
if stable_args:
|
||||
return json.dumps(stable_args, sort_keys=True, default=str)
|
||||
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
||||
"""Deterministic hash of a set of tool calls (name + args).
|
||||
"""Deterministic hash of a set of tool calls (name + stable key).
|
||||
|
||||
This is intended to be order-independent: the same multiset of tool calls
|
||||
should always produce the same hash, regardless of their input order.
|
||||
"""
|
||||
# First normalize each tool call to a minimal (name, args) structure.
|
||||
normalized: list[dict] = []
|
||||
# Normalize each tool call to a stable (name, key) structure.
|
||||
normalized: list[str] = []
|
||||
for tc in tool_calls:
|
||||
normalized.append(
|
||||
{
|
||||
"name": tc.get("name", ""),
|
||||
"args": tc.get("args", {}),
|
||||
}
|
||||
)
|
||||
name = tc.get("name", "")
|
||||
args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
|
||||
key = _stable_tool_key(name, args, fallback_key)
|
||||
|
||||
# Sort by both name and a deterministic serialization of args so that
|
||||
# permutations of the same multiset of calls yield the same ordering.
|
||||
normalized.sort(
|
||||
key=lambda tc: (
|
||||
tc["name"],
|
||||
json.dumps(tc["args"], sort_keys=True, default=str),
|
||||
)
|
||||
)
|
||||
normalized.append(f"{name}:{key}")
|
||||
|
||||
# Sort so permutations of the same multiset of calls yield the same ordering.
|
||||
normalized.sort()
|
||||
blob = json.dumps(normalized, sort_keys=True, default=str)
|
||||
return hashlib.md5(blob.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
@@ -23,25 +23,119 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Each pattern is compiled once at import time.
|
||||
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"), # rm -rf / /* ~ /home /root
|
||||
re.compile(r"(curl|wget).+\|\s*(ba)?sh"), # curl|sh, wget|sh
|
||||
# --- original rules (retained) ---
|
||||
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
|
||||
re.compile(r"dd\s+if="),
|
||||
re.compile(r"mkfs"),
|
||||
re.compile(r"cat\s+/etc/shadow"),
|
||||
re.compile(r">\s*/etc/"), # overwrite /etc/ files
|
||||
re.compile(r">+\s*/etc/"),
|
||||
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
|
||||
re.compile(r"\|\s*(ba)?sh\b"),
|
||||
# --- command substitution (targeted – only dangerous executables) ---
|
||||
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
|
||||
# --- base64 decode piped to execution ---
|
||||
re.compile(r"base64\s+.*-d.*\|"),
|
||||
# --- overwrite system binaries ---
|
||||
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
|
||||
# --- overwrite shell startup files ---
|
||||
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
|
||||
# --- process environment leakage ---
|
||||
re.compile(r"/proc/[^/]+/environ"),
|
||||
# --- dynamic linker hijack (one-step escalation) ---
|
||||
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
|
||||
# --- bash built-in networking (bypasses tool allowlists) ---
|
||||
re.compile(r"/dev/tcp/"),
|
||||
# --- fork bomb ---
|
||||
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
|
||||
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
|
||||
]
|
||||
|
||||
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
re.compile(r"chmod\s+777"), # overly permissive, but reversible
|
||||
re.compile(r"pip\s+install"),
|
||||
re.compile(r"pip3\s+install"),
|
||||
re.compile(r"chmod\s+777"),
|
||||
re.compile(r"pip3?\s+install"),
|
||||
re.compile(r"apt(-get)?\s+install"),
|
||||
# sudo/su: no-op under Docker root; warn so LLM is aware
|
||||
re.compile(r"\b(sudo|su)\b"),
|
||||
# PATH modification: long attack chain, warn rather than block
|
||||
re.compile(r"\bPATH\s*="),
|
||||
]
|
||||
|
||||
|
||||
def _classify_command(command: str) -> str:
|
||||
"""Return 'block', 'warn', or 'pass'."""
|
||||
# Normalize for matching (collapse whitespace)
|
||||
def _split_compound_command(command: str) -> list[str]:
|
||||
"""Split a compound command into sub-commands (quote-aware).
|
||||
|
||||
Scans the raw command string so unquoted shell control operators are
|
||||
recognised even when they are not surrounded by whitespace
|
||||
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
|
||||
quotes are ignored. If the command ends with an unclosed quote or a
|
||||
dangling escape, return the whole command unchanged (fail-closed —
|
||||
safer to classify the unsplit string than silently drop parts).
|
||||
"""
|
||||
parts: list[str] = []
|
||||
current: list[str] = []
|
||||
in_single_quote = False
|
||||
in_double_quote = False
|
||||
escaping = False
|
||||
index = 0
|
||||
|
||||
while index < len(command):
|
||||
char = command[index]
|
||||
|
||||
if escaping:
|
||||
current.append(char)
|
||||
escaping = False
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "\\" and not in_single_quote:
|
||||
current.append(char)
|
||||
escaping = True
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "'" and not in_double_quote:
|
||||
in_single_quote = not in_single_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not in_single_quote:
|
||||
in_double_quote = not in_double_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if not in_single_quote and not in_double_quote:
|
||||
if command.startswith("&&", index) or command.startswith("||", index):
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 2
|
||||
continue
|
||||
if char == ";":
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 1
|
||||
continue
|
||||
|
||||
current.append(char)
|
||||
index += 1
|
||||
|
||||
# Unclosed quote or dangling escape → fail-closed, return whole command
|
||||
if in_single_quote or in_double_quote or escaping:
|
||||
return [command]
|
||||
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
return parts if parts else [command]
|
||||
|
||||
|
||||
def _classify_single_command(command: str) -> str:
|
||||
"""Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'."""
|
||||
normalized = " ".join(command.split())
|
||||
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
@@ -66,6 +160,35 @@ def _classify_command(command: str) -> str:
|
||||
return "pass"
|
||||
|
||||
|
||||
def _classify_command(command: str) -> str:
|
||||
"""Return 'block', 'warn', or 'pass'.
|
||||
|
||||
Strategy:
|
||||
1. First scan the *whole* raw command against high-risk patterns. This
|
||||
catches structural attacks like ``while true; do bash & done`` or
|
||||
``:(){ :|:& };:`` that span multiple shell statements — splitting them
|
||||
on ``;`` would destroy the pattern context.
|
||||
2. Then split compound commands (e.g. ``cmd1 && cmd2 ; cmd3``) and
|
||||
classify each sub-command independently. The most severe verdict wins.
|
||||
"""
|
||||
# Pass 1: whole-command high-risk scan (catches multi-statement patterns)
|
||||
normalized = " ".join(command.split())
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "block"
|
||||
|
||||
# Pass 2: per-sub-command classification
|
||||
sub_commands = _split_compound_command(command)
|
||||
worst = "pass"
|
||||
for sub in sub_commands:
|
||||
verdict = _classify_single_command(sub)
|
||||
if verdict == "block":
|
||||
return "block" # short-circuit: can't get worse
|
||||
if verdict == "warn":
|
||||
worst = "warn"
|
||||
return worst
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user