mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
Merge branch 'main' into release/2.0-rc
This commit is contained in:
@@ -2,8 +2,14 @@ from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointe
|
||||
from .factory import create_deerflow_agent
|
||||
from .features import Next, Prev, RuntimeFeatures
|
||||
from .lead_agent import make_lead_agent
|
||||
from .lead_agent.prompt import prime_enabled_skills_cache
|
||||
from .thread_state import SandboxState, ThreadState
|
||||
|
||||
# LangGraph imports deerflow.agents when registering the graph. Prime the
|
||||
# enabled-skills cache here so the request path can usually read a warm cache
|
||||
# without forcing synchronous filesystem work during prompt module import.
|
||||
prime_enabled_skills_cache()
|
||||
|
||||
__all__ = [
|
||||
"create_deerflow_agent",
|
||||
"RuntimeFeatures",
|
||||
|
||||
@@ -17,6 +17,7 @@ For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
@@ -54,7 +55,7 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
|
||||
@@ -289,14 +289,14 @@ def make_lead_agent(config: RunnableConfig):
|
||||
agent_name = cfg.get("agent_name")
|
||||
|
||||
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
|
||||
# Custom agent model or fallback to global/default model resolution
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else _resolve_model_name()
|
||||
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||
|
||||
# Final model name resolution with request override, then agent config, then global default
|
||||
model_name = requested_model_name or agent_model_name
|
||||
# Final model name resolution: request → agent config → global default, with fallback for unknown names
|
||||
model_name = _resolve_model_name(requested_model_name or agent_model_name)
|
||||
|
||||
app_config = get_app_config()
|
||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||
model_config = app_config.get_model_config(model_name)
|
||||
|
||||
if model_config is None:
|
||||
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
|
||||
|
||||
@@ -1,20 +1,167 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
|
||||
from deerflow.config.agents_config import load_agent_soul
|
||||
from deerflow.skills import load_skills
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.subagents import get_available_subagent_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
|
||||
_enabled_skills_lock = threading.Lock()
|
||||
_enabled_skills_cache: list[Skill] | None = None
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event = threading.Event()
|
||||
|
||||
|
||||
def _load_enabled_skills_sync() -> list[Skill]:
|
||||
return list(load_skills(enabled_only=True))
|
||||
|
||||
|
||||
def _start_enabled_skills_refresh_thread() -> None:
|
||||
threading.Thread(
|
||||
target=_refresh_enabled_skills_cache_worker,
|
||||
name="deerflow-enabled-skills-loader",
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
|
||||
def _refresh_enabled_skills_cache_worker() -> None:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active
|
||||
|
||||
while True:
|
||||
with _enabled_skills_lock:
|
||||
target_version = _enabled_skills_refresh_version
|
||||
|
||||
try:
|
||||
skills = _load_enabled_skills_sync()
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
skills = []
|
||||
|
||||
with _enabled_skills_lock:
|
||||
if _enabled_skills_refresh_version == target_version:
|
||||
_enabled_skills_cache = skills
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_event.set()
|
||||
return
|
||||
|
||||
# A newer invalidation happened while loading. Keep the worker alive
|
||||
# and loop again so the cache always converges on the latest version.
|
||||
_enabled_skills_cache = None
|
||||
|
||||
|
||||
def _ensure_enabled_skills_cache() -> threading.Event:
|
||||
global _enabled_skills_refresh_active
|
||||
|
||||
with _enabled_skills_lock:
|
||||
if _enabled_skills_cache is not None:
|
||||
_enabled_skills_refresh_event.set()
|
||||
return _enabled_skills_refresh_event
|
||||
if _enabled_skills_refresh_active:
|
||||
return _enabled_skills_refresh_event
|
||||
_enabled_skills_refresh_active = True
|
||||
_enabled_skills_refresh_event.clear()
|
||||
|
||||
_start_enabled_skills_refresh_thread()
|
||||
return _enabled_skills_refresh_event
|
||||
|
||||
|
||||
def _invalidate_enabled_skills_cache() -> threading.Event:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
|
||||
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = None
|
||||
_enabled_skills_refresh_version += 1
|
||||
_enabled_skills_refresh_event.clear()
|
||||
if _enabled_skills_refresh_active:
|
||||
return _enabled_skills_refresh_event
|
||||
_enabled_skills_refresh_active = True
|
||||
|
||||
_start_enabled_skills_refresh_thread()
|
||||
return _enabled_skills_refresh_event
|
||||
|
||||
|
||||
def prime_enabled_skills_cache() -> None:
|
||||
_ensure_enabled_skills_cache()
|
||||
|
||||
|
||||
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
|
||||
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
|
||||
return True
|
||||
|
||||
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
|
||||
return False
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_cache
|
||||
|
||||
if cached is not None:
|
||||
return list(cached)
|
||||
|
||||
_ensure_enabled_skills_cache()
|
||||
return []
|
||||
|
||||
|
||||
def _skill_mutability_label(category: str) -> str:
|
||||
return "[custom, editable]" if category == "custom" else "[built-in]"
|
||||
|
||||
|
||||
def clear_skills_system_prompt_cache() -> None:
|
||||
_invalidate_enabled_skills_cache()
|
||||
|
||||
|
||||
async def refresh_skills_system_prompt_cache_async() -> None:
|
||||
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
|
||||
|
||||
|
||||
def _reset_skills_system_prompt_cache_state() -> None:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
|
||||
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = None
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event.clear()
|
||||
|
||||
|
||||
def _refresh_enabled_skills_cache() -> None:
|
||||
"""Backward-compatible test helper for direct synchronous reload."""
|
||||
try:
|
||||
return list(load_skills(enabled_only=True))
|
||||
skills = _load_enabled_skills_sync()
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
return []
|
||||
skills = []
|
||||
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = skills
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_event.set()
|
||||
|
||||
|
||||
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
|
||||
if not skill_evolution_enabled:
|
||||
return ""
|
||||
return """
|
||||
## Skill Self-Evolution
|
||||
After completing a task, consider creating or updating a skill when:
|
||||
- The task required 5+ tool calls to resolve
|
||||
- You overcame non-obvious errors or pitfalls
|
||||
- The user corrected your approach and the corrected version worked
|
||||
- You discovered a non-trivial, recurring workflow
|
||||
If you used a skill and encountered issues not covered by it, patch it immediately.
|
||||
Prefer patch over edit. Before creating a new skill, confirm with the user first.
|
||||
Skip simple one-off tasks.
|
||||
"""
|
||||
|
||||
|
||||
def _skill_mutability_label(category: str) -> str:
|
||||
@@ -294,6 +441,9 @@ You: "Deploying to staging..." [proceed]
|
||||
- Use `read_file` tool to read uploaded files using their paths from the list
|
||||
- For PDF, PPT, Excel, and Word files, converted Markdown versions (*.md) are available alongside originals
|
||||
- All temporary work happens in `/mnt/user-data/workspace`
|
||||
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
|
||||
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
|
||||
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
|
||||
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
|
||||
{acp_section}
|
||||
</working_directory>
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
@@ -18,7 +18,7 @@ class ConversationContext:
|
||||
|
||||
thread_id: str
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
@@ -4,7 +4,7 @@ import abc
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -15,11 +15,16 @@ from deerflow.config.paths import get_paths
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def utc_now_iso_z() -> str:
|
||||
"""Current UTC time as ISO-8601 with ``Z`` suffix (matches prior naive-UTC output)."""
|
||||
return datetime.now(UTC).isoformat().removesuffix("+00:00") + "Z"
|
||||
|
||||
|
||||
def create_empty_memory() -> dict[str, Any]:
|
||||
"""Create an empty memory structure."""
|
||||
return {
|
||||
"version": "1.0",
|
||||
"lastUpdated": datetime.utcnow().isoformat() + "Z",
|
||||
"lastUpdated": utc_now_iso_z(),
|
||||
"user": {
|
||||
"workContext": {"summary": "", "updatedAt": ""},
|
||||
"personalContext": {"summary": "", "updatedAt": ""},
|
||||
@@ -137,7 +142,7 @@ class FileMemoryStorage(MemoryStorage):
|
||||
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
memory_data["lastUpdated"] = datetime.utcnow().isoformat() + "Z"
|
||||
memory_data["lastUpdated"] = utc_now_iso_z()
|
||||
|
||||
temp_path = file_path.with_suffix(".tmp")
|
||||
with open(temp_path, "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -5,14 +5,17 @@ import logging
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from deerflow.agents.memory.prompt import (
|
||||
MEMORY_UPDATE_PROMPT,
|
||||
format_conversation_for_update,
|
||||
)
|
||||
from deerflow.agents.memory.storage import create_empty_memory, get_memory_storage
|
||||
from deerflow.agents.memory.storage import (
|
||||
create_empty_memory,
|
||||
get_memory_storage,
|
||||
utc_now_iso_z,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
@@ -86,7 +89,7 @@ def create_memory_fact(
|
||||
|
||||
normalized_category = category.strip() or "context"
|
||||
validated_confidence = _validate_confidence(confidence)
|
||||
now = datetime.utcnow().isoformat() + "Z"
|
||||
now = utc_now_iso_z()
|
||||
memory_data = get_memory_data(agent_name)
|
||||
updated_memory = dict(memory_data)
|
||||
facts = list(memory_data.get("facts", []))
|
||||
@@ -376,7 +379,7 @@ class MemoryUpdater:
|
||||
Updated memory data.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
now = datetime.utcnow().isoformat() + "Z"
|
||||
now = utc_now_iso_z()
|
||||
|
||||
# Update user sections
|
||||
user_updates = update_data.get("user", {})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Middleware for intercepting clarification requests and presenting them to the user."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
@@ -60,6 +61,20 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||||
context = args.get("context")
|
||||
options = args.get("options", [])
|
||||
|
||||
# Some models (e.g. Qwen3-Max) serialize array parameters as JSON strings
|
||||
# instead of native arrays. Deserialize and normalize so `options`
|
||||
# is always a list for the rendering logic below.
|
||||
if isinstance(options, str):
|
||||
try:
|
||||
options = json.loads(options)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
options = [options]
|
||||
|
||||
if options is None:
|
||||
options = []
|
||||
elif not isinstance(options, list):
|
||||
options = [options]
|
||||
|
||||
# Type-specific icons
|
||||
type_icons = {
|
||||
"missing_info": "❓",
|
||||
|
||||
@@ -33,30 +33,92 @@ _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
|
||||
|
||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||
"""Normalize tool call args to a dict plus an optional fallback key.
|
||||
|
||||
Some providers serialize ``args`` as a JSON string instead of a dict.
|
||||
We defensively parse those cases so loop detection does not crash while
|
||||
still preserving a stable fallback key for non-dict payloads.
|
||||
"""
|
||||
if isinstance(raw_args, dict):
|
||||
return raw_args, None
|
||||
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
return {}, raw_args
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed, None
|
||||
return {}, json.dumps(parsed, sort_keys=True, default=str)
|
||||
|
||||
if raw_args is None:
|
||||
return {}, None
|
||||
|
||||
return {}, json.dumps(raw_args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
|
||||
"""Derive a stable key from salient args without overfitting to noise."""
|
||||
if name == "read_file" and fallback_key is None:
|
||||
path = args.get("path") or ""
|
||||
start_line = args.get("start_line")
|
||||
end_line = args.get("end_line")
|
||||
|
||||
bucket_size = 200
|
||||
try:
|
||||
start_line = int(start_line) if start_line is not None else 1
|
||||
except (TypeError, ValueError):
|
||||
start_line = 1
|
||||
try:
|
||||
end_line = int(end_line) if end_line is not None else start_line
|
||||
except (TypeError, ValueError):
|
||||
end_line = start_line
|
||||
|
||||
start_line, end_line = sorted((start_line, end_line))
|
||||
bucket_start = max(start_line, 1)
|
||||
bucket_end = max(end_line, 1)
|
||||
bucket_start = (bucket_start - 1) // bucket_size
|
||||
bucket_end = (bucket_end - 1) // bucket_size
|
||||
return f"{path}:{bucket_start}-{bucket_end}"
|
||||
|
||||
# write_file / str_replace are content-sensitive: same path may be updated
|
||||
# with different payloads during iteration. Using only salient fields (path)
|
||||
# can collapse distinct calls, so we hash full args to reduce false positives.
|
||||
if name in {"write_file", "str_replace"}:
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
|
||||
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
|
||||
if stable_args:
|
||||
return json.dumps(stable_args, sort_keys=True, default=str)
|
||||
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
||||
"""Deterministic hash of a set of tool calls (name + args).
|
||||
"""Deterministic hash of a set of tool calls (name + stable key).
|
||||
|
||||
This is intended to be order-independent: the same multiset of tool calls
|
||||
should always produce the same hash, regardless of their input order.
|
||||
"""
|
||||
# First normalize each tool call to a minimal (name, args) structure.
|
||||
normalized: list[dict] = []
|
||||
# Normalize each tool call to a stable (name, key) structure.
|
||||
normalized: list[str] = []
|
||||
for tc in tool_calls:
|
||||
normalized.append(
|
||||
{
|
||||
"name": tc.get("name", ""),
|
||||
"args": tc.get("args", {}),
|
||||
}
|
||||
)
|
||||
name = tc.get("name", "")
|
||||
args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
|
||||
key = _stable_tool_key(name, args, fallback_key)
|
||||
|
||||
# Sort by both name and a deterministic serialization of args so that
|
||||
# permutations of the same multiset of calls yield the same ordering.
|
||||
normalized.sort(
|
||||
key=lambda tc: (
|
||||
tc["name"],
|
||||
json.dumps(tc["args"], sort_keys=True, default=str),
|
||||
)
|
||||
)
|
||||
normalized.append(f"{name}:{key}")
|
||||
|
||||
# Sort so permutations of the same multiset of calls yield the same ordering.
|
||||
normalized.sort()
|
||||
blob = json.dumps(normalized, sort_keys=True, default=str)
|
||||
return hashlib.md5(blob.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
@@ -23,25 +23,119 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Each pattern is compiled once at import time.
|
||||
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"), # rm -rf / /* ~ /home /root
|
||||
re.compile(r"(curl|wget).+\|\s*(ba)?sh"), # curl|sh, wget|sh
|
||||
# --- original rules (retained) ---
|
||||
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
|
||||
re.compile(r"dd\s+if="),
|
||||
re.compile(r"mkfs"),
|
||||
re.compile(r"cat\s+/etc/shadow"),
|
||||
re.compile(r">\s*/etc/"), # overwrite /etc/ files
|
||||
re.compile(r">+\s*/etc/"),
|
||||
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
|
||||
re.compile(r"\|\s*(ba)?sh\b"),
|
||||
# --- command substitution (targeted – only dangerous executables) ---
|
||||
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
|
||||
# --- base64 decode piped to execution ---
|
||||
re.compile(r"base64\s+.*-d.*\|"),
|
||||
# --- overwrite system binaries ---
|
||||
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
|
||||
# --- overwrite shell startup files ---
|
||||
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
|
||||
# --- process environment leakage ---
|
||||
re.compile(r"/proc/[^/]+/environ"),
|
||||
# --- dynamic linker hijack (one-step escalation) ---
|
||||
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
|
||||
# --- bash built-in networking (bypasses tool allowlists) ---
|
||||
re.compile(r"/dev/tcp/"),
|
||||
# --- fork bomb ---
|
||||
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
|
||||
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
|
||||
]
|
||||
|
||||
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
re.compile(r"chmod\s+777"), # overly permissive, but reversible
|
||||
re.compile(r"pip\s+install"),
|
||||
re.compile(r"pip3\s+install"),
|
||||
re.compile(r"chmod\s+777"),
|
||||
re.compile(r"pip3?\s+install"),
|
||||
re.compile(r"apt(-get)?\s+install"),
|
||||
# sudo/su: no-op under Docker root; warn so LLM is aware
|
||||
re.compile(r"\b(sudo|su)\b"),
|
||||
# PATH modification: long attack chain, warn rather than block
|
||||
re.compile(r"\bPATH\s*="),
|
||||
]
|
||||
|
||||
|
||||
def _classify_command(command: str) -> str:
|
||||
"""Return 'block', 'warn', or 'pass'."""
|
||||
# Normalize for matching (collapse whitespace)
|
||||
def _split_compound_command(command: str) -> list[str]:
|
||||
"""Split a compound command into sub-commands (quote-aware).
|
||||
|
||||
Scans the raw command string so unquoted shell control operators are
|
||||
recognised even when they are not surrounded by whitespace
|
||||
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
|
||||
quotes are ignored. If the command ends with an unclosed quote or a
|
||||
dangling escape, return the whole command unchanged (fail-closed —
|
||||
safer to classify the unsplit string than silently drop parts).
|
||||
"""
|
||||
parts: list[str] = []
|
||||
current: list[str] = []
|
||||
in_single_quote = False
|
||||
in_double_quote = False
|
||||
escaping = False
|
||||
index = 0
|
||||
|
||||
while index < len(command):
|
||||
char = command[index]
|
||||
|
||||
if escaping:
|
||||
current.append(char)
|
||||
escaping = False
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "\\" and not in_single_quote:
|
||||
current.append(char)
|
||||
escaping = True
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "'" and not in_double_quote:
|
||||
in_single_quote = not in_single_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not in_single_quote:
|
||||
in_double_quote = not in_double_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if not in_single_quote and not in_double_quote:
|
||||
if command.startswith("&&", index) or command.startswith("||", index):
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 2
|
||||
continue
|
||||
if char == ";":
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 1
|
||||
continue
|
||||
|
||||
current.append(char)
|
||||
index += 1
|
||||
|
||||
# Unclosed quote or dangling escape → fail-closed, return whole command
|
||||
if in_single_quote or in_double_quote or escaping:
|
||||
return [command]
|
||||
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
return parts if parts else [command]
|
||||
|
||||
|
||||
def _classify_single_command(command: str) -> str:
|
||||
"""Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'."""
|
||||
normalized = " ".join(command.split())
|
||||
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
@@ -66,6 +160,35 @@ def _classify_command(command: str) -> str:
|
||||
return "pass"
|
||||
|
||||
|
||||
def _classify_command(command: str) -> str:
|
||||
"""Return 'block', 'warn', or 'pass'.
|
||||
|
||||
Strategy:
|
||||
1. First scan the *whole* raw command against high-risk patterns. This
|
||||
catches structural attacks like ``while true; do bash & done`` or
|
||||
``:(){ :|:& };:`` that span multiple shell statements — splitting them
|
||||
on ``;`` would destroy the pattern context.
|
||||
2. Then split compound commands (e.g. ``cmd1 && cmd2 ; cmd3``) and
|
||||
classify each sub-command independently. The most severe verdict wins.
|
||||
"""
|
||||
# Pass 1: whole-command high-risk scan (catches multi-statement patterns)
|
||||
normalized = " ".join(command.split())
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "block"
|
||||
|
||||
# Pass 2: per-sub-command classification
|
||||
sub_commands = _split_compound_command(command)
|
||||
worst = "pass"
|
||||
for sub in sub_commands:
|
||||
verdict = _classify_single_command(sub)
|
||||
if verdict == "block":
|
||||
return "block" # short-circuit: can't get worse
|
||||
if verdict == "warn":
|
||||
worst = "warn"
|
||||
return worst
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -25,7 +25,7 @@ import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
@@ -55,6 +55,9 @@ from deerflow.uploads.manager import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
StreamEventType = Literal["values", "messages-tuple", "custom", "end"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamEvent:
|
||||
"""A single event from the streaming agent response.
|
||||
@@ -69,7 +72,7 @@ class StreamEvent:
|
||||
data: Event payload. Contents vary by type.
|
||||
"""
|
||||
|
||||
type: str
|
||||
type: StreamEventType
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -254,13 +257,53 @@ class DeerFlowClient:
|
||||
|
||||
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_tool_calls(tool_calls) -> list[dict]:
|
||||
"""Reshape LangChain tool_calls into the wire format used in events."""
|
||||
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
|
||||
|
||||
@staticmethod
|
||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
|
||||
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
||||
if usage:
|
||||
data["usage_metadata"] = usage
|
||||
return StreamEvent(type="messages-tuple", data=data)
|
||||
|
||||
@staticmethod
|
||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI tool-calls event."""
|
||||
return StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` tool-result event from a ToolMessage."""
|
||||
return StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "tool",
|
||||
"content": DeerFlowClient._extract_text(msg.content),
|
||||
"name": msg.name,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
"id": msg.id,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_message(msg) -> dict:
|
||||
"""Serialize a LangChain message to a plain dict for values events."""
|
||||
if isinstance(msg, AIMessage):
|
||||
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if msg.tool_calls:
|
||||
d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls]
|
||||
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
|
||||
if getattr(msg, "usage_metadata", None):
|
||||
d["usage_metadata"] = msg.usage_metadata
|
||||
return d
|
||||
@@ -315,6 +358,108 @@ class DeerFlowClient:
|
||||
return "\n".join(pieces) if pieces else ""
|
||||
return str(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — threads
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_threads(self, limit: int = 10) -> dict:
|
||||
"""List the recent N threads.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of threads to return. Default is 10.
|
||||
|
||||
Returns:
|
||||
Dict with "thread_list" key containing list of thread info dicts,
|
||||
sorted by thread creation time descending.
|
||||
"""
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
thread_info_map = {}
|
||||
|
||||
for cp in checkpointer.list(config=None, limit=limit):
|
||||
cfg = cp.config.get("configurable", {})
|
||||
thread_id = cfg.get("thread_id")
|
||||
if not thread_id:
|
||||
continue
|
||||
|
||||
ts = cp.checkpoint.get("ts")
|
||||
checkpoint_id = cfg.get("checkpoint_id")
|
||||
|
||||
if thread_id not in thread_info_map:
|
||||
channel_values = cp.checkpoint.get("channel_values", {})
|
||||
thread_info_map[thread_id] = {
|
||||
"thread_id": thread_id,
|
||||
"created_at": ts,
|
||||
"updated_at": ts,
|
||||
"latest_checkpoint_id": checkpoint_id,
|
||||
"title": channel_values.get("title"),
|
||||
}
|
||||
else:
|
||||
# Explicitly compare timestamps to ensure accuracy when iterating over unordered namespaces.
|
||||
# Treat None as "missing" and only compare when existing values are non-None.
|
||||
if ts is not None:
|
||||
current_created = thread_info_map[thread_id]["created_at"]
|
||||
if current_created is None or ts < current_created:
|
||||
thread_info_map[thread_id]["created_at"] = ts
|
||||
|
||||
current_updated = thread_info_map[thread_id]["updated_at"]
|
||||
if current_updated is None or ts > current_updated:
|
||||
thread_info_map[thread_id]["updated_at"] = ts
|
||||
thread_info_map[thread_id]["latest_checkpoint_id"] = checkpoint_id
|
||||
channel_values = cp.checkpoint.get("channel_values", {})
|
||||
thread_info_map[thread_id]["title"] = channel_values.get("title")
|
||||
|
||||
threads = list(thread_info_map.values())
|
||||
threads.sort(key=lambda x: x.get("created_at") or "", reverse=True)
|
||||
|
||||
return {"thread_list": threads[:limit]}
|
||||
|
||||
def get_thread(self, thread_id: str) -> dict:
|
||||
"""Get the complete thread record, including all node execution records.
|
||||
|
||||
Args:
|
||||
thread_id: Thread ID.
|
||||
|
||||
Returns:
|
||||
Dict containing the thread's full checkpoint history.
|
||||
"""
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
checkpoints = []
|
||||
|
||||
for cp in checkpointer.list(config):
|
||||
channel_values = dict(cp.checkpoint.get("channel_values", {}))
|
||||
if "messages" in channel_values:
|
||||
channel_values["messages"] = [self._serialize_message(m) if hasattr(m, "content") else m for m in channel_values["messages"]]
|
||||
|
||||
cfg = cp.config.get("configurable", {})
|
||||
parent_cfg = cp.parent_config.get("configurable", {}) if cp.parent_config else {}
|
||||
|
||||
checkpoints.append(
|
||||
{
|
||||
"checkpoint_id": cfg.get("checkpoint_id"),
|
||||
"parent_checkpoint_id": parent_cfg.get("checkpoint_id"),
|
||||
"ts": cp.checkpoint.get("ts"),
|
||||
"metadata": cp.metadata,
|
||||
"values": channel_values,
|
||||
"pending_writes": [{"task_id": w[0], "channel": w[1], "value": w[2]} for w in getattr(cp, "pending_writes", [])],
|
||||
}
|
||||
)
|
||||
|
||||
# Sort globally by timestamp to prevent partial ordering issues caused by different namespaces (e.g., subgraphs)
|
||||
checkpoints.sort(key=lambda x: x["ts"] if x["ts"] else "")
|
||||
|
||||
return {"thread_id": thread_id, "checkpoints": checkpoints}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — conversation
|
||||
# ------------------------------------------------------------------
|
||||
@@ -336,6 +481,53 @@ class DeerFlowClient:
|
||||
consumers can switch between HTTP streaming and embedded mode
|
||||
without changing their event-handling logic.
|
||||
|
||||
Token-level streaming
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
This method subscribes to LangGraph's ``messages`` stream mode, so
|
||||
``messages-tuple`` events for AI text are emitted as **deltas** as
|
||||
the model generates tokens, not as one cumulative dump at node
|
||||
completion. Each delta carries a stable ``id`` — consumers that
|
||||
want the full text must accumulate ``content`` per ``id``.
|
||||
``chat()`` already does this for you.
|
||||
|
||||
Tool calls and tool results are still emitted once per logical
|
||||
message. ``values`` events continue to carry full state snapshots
|
||||
after each graph node finishes; AI text already delivered via the
|
||||
``messages`` stream is **not** re-synthesized from the snapshot to
|
||||
avoid duplicate deliveries.
|
||||
|
||||
Why not reuse Gateway's ``run_agent``?
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Gateway (``runtime/runs/worker.py``) has a complete streaming
|
||||
pipeline: ``run_agent`` → ``StreamBridge`` → ``sse_consumer``. It
|
||||
looks like this client duplicates that work, but the two paths
|
||||
serve different audiences and **cannot** share execution:
|
||||
|
||||
* ``run_agent`` is ``async def`` and uses ``agent.astream()``;
|
||||
this method is a sync generator using ``agent.stream()`` so
|
||||
callers can write ``for event in client.stream(...)`` without
|
||||
touching asyncio. Bridging the two would require spinning up
|
||||
an event loop + thread per call.
|
||||
* Gateway events are JSON-serialized by ``serialize()`` for SSE
|
||||
wire transmission. This client yields in-process stream event
|
||||
payloads directly as Python data structures (``StreamEvent``
|
||||
with ``data`` as a plain ``dict``), without the extra
|
||||
JSON/SSE serialization layer used for HTTP delivery.
|
||||
* ``StreamBridge`` is an asyncio-queue decoupling producers from
|
||||
consumers across an HTTP boundary (``Last-Event-ID`` replay,
|
||||
heartbeats, multi-subscriber fan-out). A single in-process
|
||||
caller with a direct iterator needs none of that.
|
||||
|
||||
So ``DeerFlowClient.stream()`` is a parallel, sync, in-process
|
||||
consumer of the same ``create_agent()`` factory — not a wrapper
|
||||
around Gateway. The two paths **should** stay in sync on which
|
||||
LangGraph stream modes they subscribe to; that invariant is
|
||||
enforced by ``tests/test_client.py::test_messages_mode_emits_token_deltas``
|
||||
rather than by a shared constant, because the three layers
|
||||
(Graph, Platform SDK, HTTP) each use their own naming
|
||||
(``messages`` vs ``messages-tuple``) and cannot literally share
|
||||
a string.
|
||||
|
||||
Args:
|
||||
message: User message text.
|
||||
thread_id: Thread ID for conversation context. Auto-generated if None.
|
||||
@@ -346,8 +538,8 @@ class DeerFlowClient:
|
||||
StreamEvent with one of:
|
||||
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
|
||||
- type="custom" data={...}
|
||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str}
|
||||
- type="messages-tuple" data={"type": "ai", "content": str, "id": str, "usage_metadata": {...}}
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
||||
@@ -364,13 +556,47 @@ class DeerFlowClient:
|
||||
context["agent_name"] = self._agent_name
|
||||
|
||||
seen_ids: set[str] = set()
|
||||
# Cross-mode handoff: ids already streamed via LangGraph ``messages``
|
||||
# mode so the ``values`` path skips re-synthesis of the same message.
|
||||
streamed_ids: set[str] = set()
|
||||
# The same message id carries identical cumulative ``usage_metadata``
|
||||
# in both the final ``messages`` chunk and the values snapshot —
|
||||
# count it only on whichever arrives first.
|
||||
counted_usage_ids: set[str] = set()
|
||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
|
||||
"""Add *usage* to cumulative totals if this id has not been counted.
|
||||
|
||||
``usage`` is a ``langchain_core.messages.UsageMetadata`` TypedDict
|
||||
or ``None``; typed as ``Any`` because TypedDicts are not
|
||||
structurally assignable to plain ``dict`` under strict type
|
||||
checking. Returns the normalized usage dict (for attaching
|
||||
to an event) when we accepted it, otherwise ``None``.
|
||||
"""
|
||||
if not usage:
|
||||
return None
|
||||
if msg_id and msg_id in counted_usage_ids:
|
||||
return None
|
||||
if msg_id:
|
||||
counted_usage_ids.add(msg_id)
|
||||
input_tokens = usage.get("input_tokens", 0) or 0
|
||||
output_tokens = usage.get("output_tokens", 0) or 0
|
||||
total_tokens = usage.get("total_tokens", 0) or 0
|
||||
cumulative_usage["input_tokens"] += input_tokens
|
||||
cumulative_usage["output_tokens"] += output_tokens
|
||||
cumulative_usage["total_tokens"] += total_tokens
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
for item in self._agent.stream(
|
||||
state,
|
||||
config=config,
|
||||
context=context,
|
||||
stream_mode=["values", "custom"],
|
||||
stream_mode=["values", "messages", "custom"],
|
||||
):
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
mode, chunk = item
|
||||
@@ -382,6 +608,36 @@ class DeerFlowClient:
|
||||
yield StreamEvent(type="custom", data=chunk)
|
||||
continue
|
||||
|
||||
if mode == "messages":
|
||||
# LangGraph ``messages`` mode emits ``(message_chunk, metadata)``.
|
||||
if isinstance(chunk, tuple) and len(chunk) == 2:
|
||||
msg_chunk, _metadata = chunk
|
||||
else:
|
||||
msg_chunk = chunk
|
||||
|
||||
msg_id = getattr(msg_chunk, "id", None)
|
||||
|
||||
if isinstance(msg_chunk, AIMessage):
|
||||
text = self._extract_text(msg_chunk.content)
|
||||
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
|
||||
|
||||
if text:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
|
||||
if msg_chunk.tool_calls:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
|
||||
|
||||
elif isinstance(msg_chunk, ToolMessage):
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
yield self._tool_message_event(msg_chunk)
|
||||
continue
|
||||
|
||||
# mode == "values"
|
||||
messages = chunk.get("messages", [])
|
||||
|
||||
for msg in messages:
|
||||
@@ -391,47 +647,25 @@ class DeerFlowClient:
|
||||
if msg_id:
|
||||
seen_ids.add(msg_id)
|
||||
|
||||
# Already streamed via ``messages`` mode; only (defensively)
|
||||
# capture usage here and skip re-synthesizing the event.
|
||||
if msg_id and msg_id in streamed_ids:
|
||||
if isinstance(msg, AIMessage):
|
||||
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
||||
continue
|
||||
|
||||
if isinstance(msg, AIMessage):
|
||||
# Track token usage from AI messages
|
||||
usage = getattr(msg, "usage_metadata", None)
|
||||
if usage:
|
||||
cumulative_usage["input_tokens"] += usage.get("input_tokens", 0) or 0
|
||||
cumulative_usage["output_tokens"] += usage.get("output_tokens", 0) or 0
|
||||
cumulative_usage["total_tokens"] += usage.get("total_tokens", 0) or 0
|
||||
counted_usage = _account_usage(msg_id, msg.usage_metadata)
|
||||
|
||||
if msg.tool_calls:
|
||||
yield StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls],
|
||||
},
|
||||
)
|
||||
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
|
||||
|
||||
text = self._extract_text(msg.content)
|
||||
if text:
|
||||
event_data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
||||
if usage:
|
||||
event_data["usage_metadata"] = {
|
||||
"input_tokens": usage.get("input_tokens", 0) or 0,
|
||||
"output_tokens": usage.get("output_tokens", 0) or 0,
|
||||
"total_tokens": usage.get("total_tokens", 0) or 0,
|
||||
}
|
||||
yield StreamEvent(type="messages-tuple", data=event_data)
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
|
||||
elif isinstance(msg, ToolMessage):
|
||||
yield StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "tool",
|
||||
"content": self._extract_text(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": msg_id,
|
||||
},
|
||||
)
|
||||
yield self._tool_message_event(msg)
|
||||
|
||||
# Emit a values event for each state snapshot
|
||||
yield StreamEvent(
|
||||
@@ -448,10 +682,12 @@ class DeerFlowClient:
|
||||
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
|
||||
"""Send a message and return the final text response.
|
||||
|
||||
Convenience wrapper around :meth:`stream` that returns only the
|
||||
**last** AI text from ``messages-tuple`` events. If the agent emits
|
||||
multiple text segments in one turn, intermediate segments are
|
||||
discarded. Use :meth:`stream` directly to capture all events.
|
||||
Convenience wrapper around :meth:`stream` that accumulates delta
|
||||
``messages-tuple`` events per ``id`` and returns the text of the
|
||||
**last** AI message to complete. Intermediate AI messages (e.g.
|
||||
planner drafts) are discarded — only the final id's accumulated
|
||||
text is returned. Use :meth:`stream` directly if you need every
|
||||
delta as it arrives.
|
||||
|
||||
Args:
|
||||
message: User message text.
|
||||
@@ -459,15 +695,21 @@ class DeerFlowClient:
|
||||
**kwargs: Override client defaults (same as stream()).
|
||||
|
||||
Returns:
|
||||
The last AI message text, or empty string if no response.
|
||||
The accumulated text of the last AI message, or empty string
|
||||
if no AI text was produced.
|
||||
"""
|
||||
last_text = ""
|
||||
# Per-id delta lists joined once at the end — avoids the O(n²) cost
|
||||
# of repeated ``str + str`` on a growing buffer for long responses.
|
||||
chunks: dict[str, list[str]] = {}
|
||||
last_id: str = ""
|
||||
for event in self.stream(message, thread_id=thread_id, **kwargs):
|
||||
if event.type == "messages-tuple" and event.data.get("type") == "ai":
|
||||
content = event.data.get("content", "")
|
||||
if content:
|
||||
last_text = content
|
||||
return last_text
|
||||
msg_id = event.data.get("id") or ""
|
||||
delta = event.data.get("content", "")
|
||||
if delta:
|
||||
chunks.setdefault(msg_id, []).append(delta)
|
||||
last_id = msg_id
|
||||
return "".join(chunks.get(last_id, ()))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API — configuration queries
|
||||
|
||||
@@ -112,6 +112,9 @@ class AioSandboxProvider(SandboxProvider):
|
||||
atexit.register(self.shutdown)
|
||||
self._register_signal_handlers()
|
||||
|
||||
# Reconcile orphaned containers from previous process lifecycles
|
||||
self._reconcile_orphans()
|
||||
|
||||
# Start idle checker if enabled
|
||||
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
|
||||
self._start_idle_checker()
|
||||
@@ -175,6 +178,51 @@ class AioSandboxProvider(SandboxProvider):
|
||||
resolved[key] = str(value)
|
||||
return resolved
|
||||
|
||||
# ── Startup reconciliation ────────────────────────────────────────────
|
||||
|
||||
def _reconcile_orphans(self) -> None:
|
||||
"""Reconcile orphaned containers left by previous process lifecycles.
|
||||
|
||||
On startup, enumerate all running containers matching our prefix
|
||||
and adopt them all into the warm pool. The idle checker will reclaim
|
||||
containers that nobody re-acquires within ``idle_timeout``.
|
||||
|
||||
All containers are adopted unconditionally because we cannot
|
||||
distinguish "orphaned" from "actively used by another process"
|
||||
based on age alone — ``idle_timeout`` represents inactivity, not
|
||||
uptime. Adopting into the warm pool and letting the idle checker
|
||||
decide avoids destroying containers that a concurrent process may
|
||||
still be using.
|
||||
|
||||
This closes the fundamental gap where in-memory state loss (process
|
||||
restart, crash, SIGKILL) leaves Docker containers running forever.
|
||||
"""
|
||||
try:
|
||||
running = self._backend.list_running()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enumerate running containers during startup reconciliation: {e}")
|
||||
return
|
||||
|
||||
if not running:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
adopted = 0
|
||||
|
||||
for info in running:
|
||||
age = current_time - info.created_at if info.created_at > 0 else float("inf")
|
||||
# Single lock acquisition per container: atomic check-and-insert.
|
||||
# Avoids a TOCTOU window between the "already tracked?" check and
|
||||
# the warm-pool insert.
|
||||
with self._lock:
|
||||
if info.sandbox_id in self._sandboxes or info.sandbox_id in self._warm_pool:
|
||||
continue
|
||||
self._warm_pool[info.sandbox_id] = (info, current_time)
|
||||
adopted += 1
|
||||
logger.info(f"Adopted container {info.sandbox_id} into warm pool (age: {age:.0f}s)")
|
||||
|
||||
logger.info(f"Startup reconciliation complete: {adopted} adopted into warm pool, {len(running)} total found")
|
||||
|
||||
# ── Deterministic ID ─────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
@@ -316,13 +364,23 @@ class AioSandboxProvider(SandboxProvider):
|
||||
# ── Signal handling ──────────────────────────────────────────────────
|
||||
|
||||
def _register_signal_handlers(self) -> None:
|
||||
"""Register signal handlers for graceful shutdown."""
|
||||
"""Register signal handlers for graceful shutdown.
|
||||
|
||||
Handles SIGTERM, SIGINT, and SIGHUP (terminal close) to ensure
|
||||
sandbox containers are cleaned up even when the user closes the terminal.
|
||||
"""
|
||||
self._original_sigterm = signal.getsignal(signal.SIGTERM)
|
||||
self._original_sigint = signal.getsignal(signal.SIGINT)
|
||||
self._original_sighup = signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
self.shutdown()
|
||||
original = self._original_sigterm if signum == signal.SIGTERM else self._original_sigint
|
||||
if signum == signal.SIGTERM:
|
||||
original = self._original_sigterm
|
||||
elif hasattr(signal, "SIGHUP") and signum == signal.SIGHUP:
|
||||
original = self._original_sighup
|
||||
else:
|
||||
original = self._original_sigint
|
||||
if callable(original):
|
||||
original(signum, frame)
|
||||
elif original == signal.SIG_DFL:
|
||||
@@ -332,6 +390,8 @@ class AioSandboxProvider(SandboxProvider):
|
||||
try:
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, signal_handler)
|
||||
except ValueError:
|
||||
logger.debug("Could not register signal handlers (not main thread)")
|
||||
|
||||
|
||||
@@ -96,3 +96,19 @@ class SandboxBackend(ABC):
|
||||
SandboxInfo if found and healthy, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
def list_running(self) -> list[SandboxInfo]:
|
||||
"""Enumerate all running sandboxes managed by this backend.
|
||||
|
||||
Used for startup reconciliation: when the process restarts, it needs
|
||||
to discover containers started by previous processes so they can be
|
||||
adopted into the warm pool or destroyed if idle too long.
|
||||
|
||||
The default implementation returns an empty list, which is correct
|
||||
for backends that don't manage local containers (e.g., RemoteSandboxBackend
|
||||
delegates lifecycle to the provisioner which handles its own cleanup).
|
||||
|
||||
Returns:
|
||||
A list of SandboxInfo for all currently running sandboxes.
|
||||
"""
|
||||
return []
|
||||
|
||||
@@ -6,9 +6,11 @@ Handles container lifecycle, port allocation, and cross-process container discov
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
from deerflow.utils.network import get_free_port, release_port
|
||||
|
||||
@@ -18,6 +20,52 @@ from .sandbox_info import SandboxInfo
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_docker_timestamp(raw: str) -> float:
|
||||
"""Parse Docker's ISO 8601 timestamp into a Unix epoch float.
|
||||
|
||||
Docker returns timestamps with nanosecond precision and a trailing ``Z``
|
||||
(e.g. ``2026-04-08T01:22:50.123456789Z``). Python's ``fromisoformat``
|
||||
accepts at most microseconds and (pre-3.11) does not accept ``Z``, so the
|
||||
string is normalized before parsing. Returns ``0.0`` on empty input or
|
||||
parse failure so callers can use ``0.0`` as a sentinel for "unknown age".
|
||||
"""
|
||||
if not raw:
|
||||
return 0.0
|
||||
try:
|
||||
s = raw.strip()
|
||||
if "." in s:
|
||||
dot_pos = s.index(".")
|
||||
tz_start = dot_pos + 1
|
||||
while tz_start < len(s) and s[tz_start].isdigit():
|
||||
tz_start += 1
|
||||
frac = s[dot_pos + 1 : tz_start][:6] # truncate to microseconds
|
||||
tz_suffix = s[tz_start:]
|
||||
s = s[: dot_pos + 1] + frac + tz_suffix
|
||||
if s.endswith("Z"):
|
||||
s = s[:-1] + "+00:00"
|
||||
return datetime.fromisoformat(s).timestamp()
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.debug(f"Could not parse docker timestamp {raw!r}: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def _extract_host_port(inspect_entry: dict, container_port: int) -> int | None:
|
||||
"""Extract the host port mapped to ``container_port/tcp`` from a docker inspect entry.
|
||||
|
||||
Returns None if the container has no port mapping for that port.
|
||||
"""
|
||||
try:
|
||||
ports = (inspect_entry.get("NetworkSettings") or {}).get("Ports") or {}
|
||||
bindings = ports.get(f"{container_port}/tcp") or []
|
||||
if bindings:
|
||||
host_port = bindings[0].get("HostPort")
|
||||
if host_port:
|
||||
return int(host_port)
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _format_container_mount(runtime: str, host_path: str, container_path: str, read_only: bool) -> list[str]:
|
||||
"""Format a bind-mount argument for the selected runtime.
|
||||
|
||||
@@ -172,8 +220,12 @@ class LocalContainerBackend(SandboxBackend):
|
||||
|
||||
def destroy(self, info: SandboxInfo) -> None:
|
||||
"""Stop the container and release its port."""
|
||||
if info.container_id:
|
||||
self._stop_container(info.container_id)
|
||||
# Prefer container_id, fall back to container_name (both accepted by docker stop).
|
||||
# This ensures containers discovered via list_running() (which only has the name)
|
||||
# can also be stopped.
|
||||
stop_target = info.container_id or info.container_name
|
||||
if stop_target:
|
||||
self._stop_container(stop_target)
|
||||
# Extract port from sandbox_url for release
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
@@ -222,6 +274,129 @@ class LocalContainerBackend(SandboxBackend):
|
||||
container_name=container_name,
|
||||
)
|
||||
|
||||
def list_running(self) -> list[SandboxInfo]:
|
||||
"""Enumerate all running containers matching the configured prefix.
|
||||
|
||||
Uses a single ``docker ps`` call to list container names, then a
|
||||
single batched ``docker inspect`` call to retrieve creation timestamp
|
||||
and port mapping for all containers at once. Total subprocess calls:
|
||||
2 (down from 2N+1 in the naive per-container approach).
|
||||
|
||||
Note: Docker's ``--filter name=`` performs *substring* matching,
|
||||
so a secondary ``startswith`` check is applied to ensure only
|
||||
containers with the exact prefix are included.
|
||||
|
||||
Containers without port mappings are still included (with empty
|
||||
sandbox_url) so that startup reconciliation can adopt orphans
|
||||
regardless of their port state.
|
||||
"""
|
||||
# Step 1: enumerate container names via docker ps
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
self._runtime,
|
||||
"ps",
|
||||
"--filter",
|
||||
f"name={self._container_prefix}-",
|
||||
"--format",
|
||||
"{{.Names}}",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip()
|
||||
logger.warning(
|
||||
"Failed to list running containers with %s ps (returncode=%s, stderr=%s)",
|
||||
self._runtime,
|
||||
result.returncode,
|
||||
stderr or "<empty>",
|
||||
)
|
||||
return []
|
||||
if not result.stdout.strip():
|
||||
return []
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
|
||||
logger.warning(f"Failed to list running containers: {e}")
|
||||
return []
|
||||
|
||||
# Filter to names matching our exact prefix (docker filter is substring-based)
|
||||
container_names = [name.strip() for name in result.stdout.strip().splitlines() if name.strip().startswith(self._container_prefix + "-")]
|
||||
if not container_names:
|
||||
return []
|
||||
|
||||
# Step 2: batched docker inspect — single subprocess call for all containers
|
||||
inspections = self._batch_inspect(container_names)
|
||||
|
||||
infos: list[SandboxInfo] = []
|
||||
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
|
||||
for container_name in container_names:
|
||||
data = inspections.get(container_name)
|
||||
if data is None:
|
||||
# Container disappeared between ps and inspect, or inspect failed
|
||||
continue
|
||||
created_at, host_port = data
|
||||
sandbox_id = container_name[len(self._container_prefix) + 1 :]
|
||||
sandbox_url = f"http://{sandbox_host}:{host_port}" if host_port else ""
|
||||
|
||||
infos.append(
|
||||
SandboxInfo(
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_url=sandbox_url,
|
||||
container_name=container_name,
|
||||
created_at=created_at,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(infos)} running sandbox container(s)")
|
||||
return infos
|
||||
|
||||
def _batch_inspect(self, container_names: list[str]) -> dict[str, tuple[float, int | None]]:
|
||||
"""Batch-inspect containers in a single subprocess call.
|
||||
|
||||
Returns a mapping of ``container_name -> (created_at, host_port)``.
|
||||
Missing containers or parse failures are silently dropped from the result.
|
||||
"""
|
||||
if not container_names:
|
||||
return {}
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[self._runtime, "inspect", *container_names],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
|
||||
logger.warning(f"Failed to batch-inspect containers: {e}")
|
||||
return {}
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip()
|
||||
logger.warning(
|
||||
"Failed to batch-inspect containers with %s inspect (returncode=%s, stderr=%s)",
|
||||
self._runtime,
|
||||
result.returncode,
|
||||
stderr or "<empty>",
|
||||
)
|
||||
return {}
|
||||
|
||||
try:
|
||||
payload = json.loads(result.stdout or "[]")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse docker inspect output as JSON: {e}")
|
||||
return {}
|
||||
|
||||
out: dict[str, tuple[float, int | None]] = {}
|
||||
for entry in payload:
|
||||
# ``Name`` is prefixed with ``/`` in the docker inspect response
|
||||
name = (entry.get("Name") or "").lstrip("/")
|
||||
if not name:
|
||||
continue
|
||||
created_at = _parse_docker_timestamp(entry.get("Created", ""))
|
||||
host_port = _extract_host_port(entry, 8080)
|
||||
out[name] = (created_at, host_port)
|
||||
return out
|
||||
|
||||
# ── Container operations ─────────────────────────────────────────────
|
||||
|
||||
def _start_container(
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import json
|
||||
|
||||
from exa_py import Exa
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
|
||||
def _get_exa_client(tool_name: str = "web_search") -> Exa:
|
||||
config = get_app_config().get_tool_config(tool_name)
|
||||
api_key = None
|
||||
if config is not None and "api_key" in config.model_extra:
|
||||
api_key = config.model_extra.get("api_key")
|
||||
return Exa(api_key=api_key)
|
||||
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
def web_search_tool(query: str) -> str:
|
||||
"""Search the web.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
"""
|
||||
try:
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
max_results = 5
|
||||
search_type = "auto"
|
||||
contents_max_characters = 1000
|
||||
if config is not None:
|
||||
max_results = config.model_extra.get("max_results", max_results)
|
||||
search_type = config.model_extra.get("search_type", search_type)
|
||||
contents_max_characters = config.model_extra.get("contents_max_characters", contents_max_characters)
|
||||
|
||||
client = _get_exa_client()
|
||||
res = client.search(
|
||||
query,
|
||||
type=search_type,
|
||||
num_results=max_results,
|
||||
contents={"highlights": {"max_characters": contents_max_characters}},
|
||||
)
|
||||
|
||||
normalized_results = [
|
||||
{
|
||||
"title": result.title or "",
|
||||
"url": result.url or "",
|
||||
"snippet": "\n".join(result.highlights) if result.highlights else "",
|
||||
}
|
||||
for result in res.results
|
||||
]
|
||||
json_results = json.dumps(normalized_results, indent=2, ensure_ascii=False)
|
||||
return json_results
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
|
||||
@tool("web_fetch", parse_docstring=True)
|
||||
def web_fetch_tool(url: str) -> str:
|
||||
"""Fetch the contents of a web page at a given URL.
|
||||
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
||||
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
||||
Do NOT add www. to URLs that do NOT have them.
|
||||
URLs must include the schema: https://example.com is a valid URL while example.com is an invalid URL.
|
||||
|
||||
Args:
|
||||
url: The URL to fetch the contents of.
|
||||
"""
|
||||
try:
|
||||
client = _get_exa_client("web_fetch")
|
||||
res = client.get_contents([url], text={"max_characters": 4096})
|
||||
|
||||
if res.results:
|
||||
result = res.results[0]
|
||||
title = result.title or "Untitled"
|
||||
text = result.text or ""
|
||||
return f"# {title}\n\n{text[:4096]}"
|
||||
else:
|
||||
return "Error: No results found"
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
@@ -6,10 +6,10 @@ from langchain.tools import tool
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
|
||||
def _get_firecrawl_client() -> FirecrawlApp:
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
|
||||
config = get_app_config().get_tool_config(tool_name)
|
||||
api_key = None
|
||||
if config is not None:
|
||||
if config is not None and "api_key" in config.model_extra:
|
||||
api_key = config.model_extra.get("api_key")
|
||||
return FirecrawlApp(api_key=api_key) # type: ignore[arg-type]
|
||||
|
||||
@@ -27,7 +27,7 @@ def web_search_tool(query: str) -> str:
|
||||
if config is not None:
|
||||
max_results = config.model_extra.get("max_results", max_results)
|
||||
|
||||
client = _get_firecrawl_client()
|
||||
client = _get_firecrawl_client("web_search")
|
||||
result = client.search(query, limit=max_results)
|
||||
|
||||
# result.web contains list of SearchResultWeb objects
|
||||
@@ -58,7 +58,7 @@ def web_fetch_tool(url: str) -> str:
|
||||
url: The URL to fetch the contents of.
|
||||
"""
|
||||
try:
|
||||
client = _get_firecrawl_client()
|
||||
client = _get_firecrawl_client("web_fetch")
|
||||
result = client.scrape(url, formats=["markdown"])
|
||||
|
||||
markdown_content = result.markdown or ""
|
||||
|
||||
@@ -27,6 +27,10 @@ class ModelConfig(BaseModel):
|
||||
default_factory=lambda: None,
|
||||
description="Extra settings to be passed to the model when thinking is enabled",
|
||||
)
|
||||
when_thinking_disabled: dict | None = Field(
|
||||
default_factory=lambda: None,
|
||||
description="Extra settings to be passed to the model when thinking is disabled",
|
||||
)
|
||||
supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs")
|
||||
thinking: dict | None = Field(
|
||||
default_factory=lambda: None,
|
||||
|
||||
@@ -56,6 +56,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
"supports_thinking",
|
||||
"supports_reasoning_effort",
|
||||
"when_thinking_enabled",
|
||||
"when_thinking_disabled",
|
||||
"thinking",
|
||||
"supports_vision",
|
||||
},
|
||||
@@ -72,21 +73,24 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
raise ValueError(f"Model {name} does not support thinking. Set `supports_thinking` to true in the `config.yaml` to enable thinking.") from None
|
||||
if effective_wte:
|
||||
model_settings_from_config.update(effective_wte)
|
||||
if not thinking_enabled and has_thinking_settings:
|
||||
if effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
|
||||
if not thinking_enabled:
|
||||
if model_config.when_thinking_disabled is not None:
|
||||
# User-provided disable settings take full precedence
|
||||
model_settings_from_config.update(model_config.when_thinking_disabled)
|
||||
elif has_thinking_settings and effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
|
||||
# OpenAI-compatible gateway: thinking is nested under extra_body
|
||||
model_settings_from_config["extra_body"] = _deep_merge_dicts(
|
||||
model_settings_from_config.get("extra_body"),
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
model_settings_from_config["reasoning_effort"] = "minimal"
|
||||
elif disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {}):
|
||||
elif has_thinking_settings and (disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {})):
|
||||
# vLLM uses chat template kwargs to switch thinking on/off.
|
||||
model_settings_from_config["extra_body"] = _deep_merge_dicts(
|
||||
model_settings_from_config.get("extra_body"),
|
||||
{"chat_template_kwargs": disable_chat_template_kwargs},
|
||||
)
|
||||
elif effective_wte.get("thinking", {}).get("type"):
|
||||
elif has_thinking_settings and effective_wte.get("thinking", {}).get("type"):
|
||||
# Native langchain_anthropic: thinking is a direct constructor parameter
|
||||
model_settings_from_config["thinking"] = {"type": "disabled"}
|
||||
if not model_config.supports_reasoning_effort:
|
||||
|
||||
@@ -48,6 +48,10 @@ class CodexChatModel(BaseChatModel):
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "codex-responses"
|
||||
@@ -216,18 +220,48 @@ class CodexChatModel(BaseChatModel):
|
||||
def _stream_response(self, headers: dict, payload: dict) -> dict:
|
||||
"""Stream SSE from Codex API and collect the final response."""
|
||||
completed_response = None
|
||||
streamed_output_items: dict[int, dict[str, Any]] = {}
|
||||
|
||||
with httpx.Client(timeout=300) as client:
|
||||
with client.stream("POST", f"{CODEX_BASE_URL}/responses", headers=headers, json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
for line in resp.iter_lines():
|
||||
data = self._parse_sse_data_line(line)
|
||||
if data and data.get("type") == "response.completed":
|
||||
if not data:
|
||||
continue
|
||||
|
||||
event_type = data.get("type")
|
||||
if event_type == "response.output_item.done":
|
||||
output_index = data.get("output_index")
|
||||
output_item = data.get("item")
|
||||
if isinstance(output_index, int) and isinstance(output_item, dict):
|
||||
streamed_output_items[output_index] = output_item
|
||||
elif event_type == "response.completed":
|
||||
completed_response = data["response"]
|
||||
|
||||
if not completed_response:
|
||||
raise RuntimeError("Codex API stream ended without response.completed event")
|
||||
|
||||
# ChatGPT Codex can emit the final assistant content only in stream events.
|
||||
# When response.completed arrives, response.output may still be empty.
|
||||
if streamed_output_items:
|
||||
merged_output = []
|
||||
response_output = completed_response.get("output")
|
||||
if isinstance(response_output, list):
|
||||
merged_output = list(response_output)
|
||||
|
||||
max_index = max(max(streamed_output_items), len(merged_output) - 1)
|
||||
if max_index >= 0 and len(merged_output) <= max_index:
|
||||
merged_output.extend([None] * (max_index + 1 - len(merged_output)))
|
||||
|
||||
for output_index, output_item in streamed_output_items.items():
|
||||
existing_item = merged_output[output_index]
|
||||
if not isinstance(existing_item, dict):
|
||||
merged_output[output_index] = output_item
|
||||
|
||||
completed_response = dict(completed_response)
|
||||
completed_response["output"] = [item for item in merged_output if isinstance(item, dict)]
|
||||
|
||||
return completed_response
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -23,6 +23,14 @@ class PatchedChatDeepSeek(ChatDeepSeek):
|
||||
request payload.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"api_key": "DEEPSEEK_API_KEY", "openai_api_key": "DEEPSEEK_API_KEY"}
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
|
||||
@@ -16,6 +16,8 @@ internal checkpoint callbacks that are not exposed in the Python public API.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
@@ -79,6 +81,9 @@ async def run_agent(
|
||||
run_id = record.run_id
|
||||
thread_id = record.thread_id
|
||||
requested_modes: set[str] = set(stream_modes or ["values"])
|
||||
pre_run_checkpoint_id: str | None = None
|
||||
pre_run_snapshot: dict[str, Any] | None = None
|
||||
snapshot_capture_failed = False
|
||||
|
||||
# Initialize RunJournal for event capture
|
||||
journal = None
|
||||
@@ -120,15 +125,23 @@ async def run_agent(
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
# Record pre-run checkpoint_id to support rollback (Phase 2).
|
||||
pre_run_checkpoint_id = None
|
||||
try:
|
||||
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
|
||||
if ckpt_tuple is not None:
|
||||
pre_run_checkpoint_id = getattr(ckpt_tuple, "config", {}).get("configurable", {}).get("checkpoint_id")
|
||||
except Exception:
|
||||
logger.debug("Could not get pre-run checkpoint_id for run %s", run_id)
|
||||
# Snapshot the latest pre-run checkpoint so rollback can restore it.
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
|
||||
if ckpt_tuple is not None:
|
||||
ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {})
|
||||
pre_run_checkpoint_id = ckpt_config.get("checkpoint_id")
|
||||
pre_run_snapshot = {
|
||||
"checkpoint_ns": ckpt_config.get("checkpoint_ns", ""),
|
||||
"checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})),
|
||||
"metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})),
|
||||
"pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []),
|
||||
}
|
||||
except Exception:
|
||||
snapshot_capture_failed = True
|
||||
logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True)
|
||||
|
||||
# 2. Publish metadata — useStream needs both run_id AND thread_id
|
||||
await bridge.publish(
|
||||
@@ -234,17 +247,18 @@ async def run_agent(
|
||||
action = record.abort_action
|
||||
if action == "rollback":
|
||||
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
|
||||
# TODO(Phase 2): Implement full checkpoint rollback.
|
||||
# Use pre_run_checkpoint_id to revert the thread's checkpoint
|
||||
# to the state before this run started. Requires a
|
||||
# checkpointer.adelete() or equivalent API.
|
||||
try:
|
||||
if checkpointer is not None and pre_run_checkpoint_id is not None:
|
||||
# Phase 2: roll back to pre_run_checkpoint_id
|
||||
pass
|
||||
logger.info("Run %s rolled back", run_id)
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
pre_run_checkpoint_id=pre_run_checkpoint_id,
|
||||
pre_run_snapshot=pre_run_snapshot,
|
||||
snapshot_capture_failed=snapshot_capture_failed,
|
||||
)
|
||||
logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to rollback checkpoint for run %s", run_id)
|
||||
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
|
||||
else:
|
||||
await run_manager.set_status(run_id, RunStatus.interrupted)
|
||||
else:
|
||||
@@ -254,7 +268,18 @@ async def run_agent(
|
||||
action = record.abort_action
|
||||
if action == "rollback":
|
||||
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
|
||||
logger.info("Run %s was cancelled (rollback)", run_id)
|
||||
try:
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
pre_run_checkpoint_id=pre_run_checkpoint_id,
|
||||
pre_run_snapshot=pre_run_snapshot,
|
||||
snapshot_capture_failed=snapshot_capture_failed,
|
||||
)
|
||||
logger.info("Run %s was cancelled and rolled back", run_id)
|
||||
except Exception:
|
||||
logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True)
|
||||
else:
|
||||
await run_manager.set_status(run_id, RunStatus.interrupted)
|
||||
logger.info("Run %s was cancelled", run_id)
|
||||
@@ -313,6 +338,104 @@ async def run_agent(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call a checkpointer method, supporting async and sync variants."""
|
||||
method = getattr(checkpointer, async_name, None) or getattr(checkpointer, sync_name, None)
|
||||
if method is None:
|
||||
raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}")
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
async def _rollback_to_pre_run_checkpoint(
|
||||
*,
|
||||
checkpointer: Any,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
pre_run_checkpoint_id: str | None,
|
||||
pre_run_snapshot: dict[str, Any] | None,
|
||||
snapshot_capture_failed: bool,
|
||||
) -> None:
|
||||
"""Restore thread state to the checkpoint snapshot captured before run start."""
|
||||
if checkpointer is None:
|
||||
logger.info("Run %s rollback requested but no checkpointer is configured", run_id)
|
||||
return
|
||||
|
||||
if snapshot_capture_failed:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id)
|
||||
return
|
||||
|
||||
if pre_run_snapshot is None:
|
||||
await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id)
|
||||
logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id)
|
||||
return
|
||||
|
||||
checkpoint_to_restore = None
|
||||
metadata_to_restore: dict[str, Any] = {}
|
||||
checkpoint_ns = ""
|
||||
checkpoint = pre_run_snapshot.get("checkpoint")
|
||||
if not isinstance(checkpoint, dict):
|
||||
logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id)
|
||||
return
|
||||
checkpoint_to_restore = checkpoint
|
||||
if checkpoint_to_restore.get("id") is None and pre_run_checkpoint_id is not None:
|
||||
checkpoint_to_restore = {**checkpoint_to_restore, "id": pre_run_checkpoint_id}
|
||||
if checkpoint_to_restore.get("id") is None:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||
return
|
||||
metadata = pre_run_snapshot.get("metadata", {})
|
||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||
checkpoint_ns = raw_checkpoint_ns if isinstance(raw_checkpoint_ns, str) else ""
|
||||
|
||||
channel_versions = checkpoint_to_restore.get("channel_versions")
|
||||
new_versions = dict(channel_versions) if isinstance(channel_versions, dict) else {}
|
||||
|
||||
restore_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
|
||||
restored_config = await _call_checkpointer_method(
|
||||
checkpointer,
|
||||
"aput",
|
||||
"put",
|
||||
restore_config,
|
||||
checkpoint_to_restore,
|
||||
metadata_to_restore if isinstance(metadata_to_restore, dict) else {},
|
||||
new_versions,
|
||||
)
|
||||
if not isinstance(restored_config, dict):
|
||||
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict")
|
||||
restored_configurable = restored_config.get("configurable", {})
|
||||
if not isinstance(restored_configurable, dict):
|
||||
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload")
|
||||
restored_checkpoint_id = restored_configurable.get("checkpoint_id")
|
||||
if not restored_checkpoint_id:
|
||||
raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id")
|
||||
|
||||
pending_writes = pre_run_snapshot.get("pending_writes", [])
|
||||
if not pending_writes:
|
||||
return
|
||||
|
||||
writes_by_task: dict[str, list[tuple[str, Any]]] = {}
|
||||
for item in pending_writes:
|
||||
if not isinstance(item, (tuple, list)) or len(item) != 3:
|
||||
raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {item!r}")
|
||||
task_id, channel, value = item
|
||||
if not isinstance(channel, str):
|
||||
raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}")
|
||||
writes_by_task.setdefault(str(task_id), []).append((channel, value))
|
||||
|
||||
for task_id, writes in writes_by_task.items():
|
||||
await _call_checkpointer_method(
|
||||
checkpointer,
|
||||
"aput_writes",
|
||||
"put_writes",
|
||||
restored_config,
|
||||
writes,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
|
||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
|
||||
_FILE_OPERATION_LOCKS: dict[tuple[str, str], threading.Lock] = {}
|
||||
# Use WeakValueDictionary to prevent memory leak in long-running processes.
|
||||
# Locks are automatically removed when no longer referenced by any thread.
|
||||
_LockKey = tuple[str, str]
|
||||
_FILE_OPERATION_LOCKS: weakref.WeakValueDictionary[_LockKey, threading.Lock] = weakref.WeakValueDictionary()
|
||||
_FILE_OPERATION_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,8 @@ Do NOT use for simple single commands - use bash tool directly instead.""",
|
||||
- Use parallel execution when commands are independent
|
||||
- Report both stdout and stderr when relevant
|
||||
- Handle errors gracefully and explain what went wrong
|
||||
- Use absolute paths for file operations
|
||||
- Use workspace-relative paths for files under the default workspace, uploads, and outputs directories
|
||||
- Use absolute paths only when the task references deployment-configured custom mounts outside the default workspace layout
|
||||
- Be cautious with destructive operations (rm, overwrite, etc.)
|
||||
</guidelines>
|
||||
|
||||
@@ -38,6 +39,8 @@ You have access to the sandbox environment:
|
||||
- User workspace: `/mnt/user-data/workspace`
|
||||
- Output files: `/mnt/user-data/outputs`
|
||||
- Deployment-configured custom mounts may also be available at other absolute container paths; use them directly when the task references those mounted directories
|
||||
- Treat `/mnt/user-data/workspace` as the default working directory for file IO
|
||||
- Prefer relative paths from the workspace, such as `hello.txt`, `../uploads/input.csv`, and `../outputs/result.md`, when composing commands or helper scripts
|
||||
</working_directory>
|
||||
""",
|
||||
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
|
||||
|
||||
@@ -39,6 +39,8 @@ You have access to the same sandbox environment as the parent agent:
|
||||
- User workspace: `/mnt/user-data/workspace`
|
||||
- Output files: `/mnt/user-data/outputs`
|
||||
- Deployment-configured custom mounts may also be available at other absolute container paths; use them directly when the task references those mounted directories
|
||||
- Treat `/mnt/user-data/workspace` as the default working directory for coding and file IO
|
||||
- Prefer relative paths from the workspace, such as `hello.txt`, `../uploads/input.csv`, and `../outputs/result.md`, when writing scripts or shell commands
|
||||
</working_directory>
|
||||
""",
|
||||
tools=None, # Inherit all tools from parent
|
||||
|
||||
@@ -6,7 +6,7 @@ import threading
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
@@ -30,6 +30,7 @@ class SubagentStatus(Enum):
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
TIMED_OUT = "timed_out"
|
||||
|
||||
|
||||
@@ -56,6 +57,7 @@ class SubagentResult:
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
ai_messages: list[dict[str, Any]] | None = None
|
||||
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize mutable defaults."""
|
||||
@@ -74,6 +76,9 @@ _scheduler_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent
|
||||
# Larger pool to avoid blocking when scheduler submits execution tasks
|
||||
_execution_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-exec-")
|
||||
|
||||
# Dedicated pool for sync execute() calls made from an already-running event loop.
|
||||
_isolated_loop_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-isolated-")
|
||||
|
||||
|
||||
def _filter_tools(
|
||||
all_tools: list[BaseTool],
|
||||
@@ -241,7 +246,31 @@ class SubagentExecutor:
|
||||
# Use stream instead of invoke to get real-time updates
|
||||
# This allows us to collect AI messages as they are generated
|
||||
final_state = None
|
||||
|
||||
# Pre-check: bail out immediately if already cancelled before streaming starts
|
||||
if result.cancel_event.is_set():
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled before streaming")
|
||||
with _background_tasks_lock:
|
||||
if result.status == SubagentStatus.RUNNING:
|
||||
result.status = SubagentStatus.CANCELLED
|
||||
result.error = "Cancelled by user"
|
||||
result.completed_at = datetime.now()
|
||||
return result
|
||||
|
||||
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
||||
# Cooperative cancellation: check if parent requested stop.
|
||||
# Note: cancellation is only detected at astream iteration boundaries,
|
||||
# so long-running tool calls within a single iteration will not be
|
||||
# interrupted until the next chunk is yielded.
|
||||
if result.cancel_event.is_set():
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled by parent")
|
||||
with _background_tasks_lock:
|
||||
if result.status == SubagentStatus.RUNNING:
|
||||
result.status = SubagentStatus.CANCELLED
|
||||
result.error = "Cancelled by user"
|
||||
result.completed_at = datetime.now()
|
||||
return result
|
||||
|
||||
final_state = chunk
|
||||
|
||||
# Extract AI messages from the current state
|
||||
@@ -348,12 +377,55 @@ class SubagentExecutor:
|
||||
|
||||
return result
|
||||
|
||||
def _execute_in_isolated_loop(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||
"""Execute the subagent in a completely fresh event loop.
|
||||
|
||||
This method is designed to run in a separate thread to ensure complete
|
||||
isolation from any parent event loop, preventing conflicts with asyncio
|
||||
primitives that may be bound to the parent loop (e.g., httpx clients).
|
||||
"""
|
||||
try:
|
||||
previous_loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
previous_loop = None
|
||||
|
||||
# Create and set a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(self._aexecute(task, result_holder))
|
||||
finally:
|
||||
try:
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
for task_obj in pending:
|
||||
task_obj.cancel()
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.run_until_complete(loop.shutdown_default_executor())
|
||||
except Exception:
|
||||
logger.debug(
|
||||
f"[trace={self.trace_id}] Failed while cleaning up isolated event loop for subagent {self.config.name}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
loop.close()
|
||||
finally:
|
||||
asyncio.set_event_loop(previous_loop)
|
||||
|
||||
def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||
"""Execute a task synchronously (wrapper around async execution).
|
||||
|
||||
This method runs the async execution in a new event loop, allowing
|
||||
asynchronous tools (like MCP tools) to be used within the thread pool.
|
||||
|
||||
When called from within an already-running event loop (e.g., when the
|
||||
parent agent is async), this method isolates the subagent execution in
|
||||
a separate thread to avoid event loop conflicts with shared async
|
||||
primitives like httpx clients.
|
||||
|
||||
Args:
|
||||
task: The task description for the subagent.
|
||||
result_holder: Optional pre-created result object to update during execution.
|
||||
@@ -361,16 +433,18 @@ class SubagentExecutor:
|
||||
Returns:
|
||||
SubagentResult with the execution result.
|
||||
"""
|
||||
# Run the async execution in a new event loop
|
||||
# This is necessary because:
|
||||
# 1. We may have async-only tools (like MCP tools)
|
||||
# 2. We're running inside a ThreadPoolExecutor which doesn't have an event loop
|
||||
#
|
||||
# Note: _aexecute() catches all exceptions internally, so this outer
|
||||
# try-except only handles asyncio.run() failures (e.g., if called from
|
||||
# an async context where an event loop already exists). Subagent execution
|
||||
# errors are handled within _aexecute() and returned as FAILED status.
|
||||
try:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop is not None and loop.is_running():
|
||||
logger.debug(f"[trace={self.trace_id}] Subagent {self.config.name} detected running event loop, using isolated thread")
|
||||
future = _isolated_loop_pool.submit(self._execute_in_isolated_loop, task, result_holder)
|
||||
return future.result()
|
||||
|
||||
# Standard path: no running event loop, use asyncio.run
|
||||
return asyncio.run(self._aexecute(task, result_holder))
|
||||
except Exception as e:
|
||||
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} execution failed")
|
||||
@@ -437,10 +511,12 @@ class SubagentExecutor:
|
||||
except FuturesTimeoutError:
|
||||
logger.error(f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s")
|
||||
with _background_tasks_lock:
|
||||
_background_tasks[task_id].status = SubagentStatus.TIMED_OUT
|
||||
_background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds"
|
||||
_background_tasks[task_id].completed_at = datetime.now()
|
||||
# Cancel the future (best effort - may not stop the actual execution)
|
||||
if _background_tasks[task_id].status == SubagentStatus.RUNNING:
|
||||
_background_tasks[task_id].status = SubagentStatus.TIMED_OUT
|
||||
_background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds"
|
||||
_background_tasks[task_id].completed_at = datetime.now()
|
||||
# Signal cooperative cancellation and cancel the future
|
||||
result_holder.cancel_event.set()
|
||||
execution_future.cancel()
|
||||
except Exception as e:
|
||||
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
|
||||
@@ -456,6 +532,24 @@ class SubagentExecutor:
|
||||
MAX_CONCURRENT_SUBAGENTS = 3
|
||||
|
||||
|
||||
def request_cancel_background_task(task_id: str) -> None:
|
||||
"""Signal a running background task to stop.
|
||||
|
||||
Sets the cancel_event on the task, which is checked cooperatively
|
||||
by ``_aexecute`` during ``agent.astream()`` iteration. This allows
|
||||
subagent threads — which cannot be force-killed via ``Future.cancel()``
|
||||
— to stop at the next iteration boundary.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to cancel.
|
||||
"""
|
||||
with _background_tasks_lock:
|
||||
result = _background_tasks.get(task_id)
|
||||
if result is not None:
|
||||
result.cancel_event.set()
|
||||
logger.info("Requested cancellation for background task %s", task_id)
|
||||
|
||||
|
||||
def get_background_task_result(task_id: str) -> SubagentResult | None:
|
||||
"""Get the result of a background task.
|
||||
|
||||
@@ -503,6 +597,7 @@ def cleanup_background_task(task_id: str) -> None:
|
||||
is_terminal_status = result.status in {
|
||||
SubagentStatus.COMPLETED,
|
||||
SubagentStatus.FAILED,
|
||||
SubagentStatus.CANCELLED,
|
||||
SubagentStatus.TIMED_OUT,
|
||||
}
|
||||
if is_terminal_status or result.completed_at is not None:
|
||||
|
||||
@@ -14,7 +14,7 @@ from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result
|
||||
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -182,6 +182,11 @@ async def task_tool(
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task failed. Error: {result.error}"
|
||||
elif result.status == SubagentStatus.CANCELLED:
|
||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return "Task cancelled by user."
|
||||
elif result.status == SubagentStatus.TIMED_OUT:
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||
@@ -204,6 +209,11 @@ async def task_tool(
|
||||
writer({"type": "task_timed_out", "task_id": task_id})
|
||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||
except asyncio.CancelledError:
|
||||
# Signal the background subagent thread to stop cooperatively.
|
||||
# Without this, the thread (running in ThreadPoolExecutor with its
|
||||
# own event loop via asyncio.run) would continue executing even
|
||||
# after the parent task is cancelled.
|
||||
request_cancel_background_task(task_id)
|
||||
|
||||
async def cleanup_when_done() -> None:
|
||||
max_cleanup_polls = max_poll_count
|
||||
@@ -214,7 +224,7 @@ async def task_tool(
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
|
||||
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
|
||||
cleanup_background_task(task_id)
|
||||
return
|
||||
|
||||
|
||||
@@ -11,7 +11,11 @@ from weakref import WeakValueDictionary
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
<<<<<<< HEAD
|
||||
from deerflow.agents.lead_agent.prompt import clear_skills_system_prompt_cache
|
||||
=======
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
>>>>>>> main
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.mcp.tools import _make_sync_tool_wrapper
|
||||
from deerflow.skills.manager import (
|
||||
@@ -115,7 +119,11 @@ async def _skill_manage_impl(
|
||||
name,
|
||||
_history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
clear_skills_system_prompt_cache()
|
||||
=======
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
>>>>>>> main
|
||||
return f"Created custom skill '{name}'."
|
||||
|
||||
if action == "edit":
|
||||
@@ -132,7 +140,11 @@ async def _skill_manage_impl(
|
||||
name,
|
||||
_history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
clear_skills_system_prompt_cache()
|
||||
=======
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
>>>>>>> main
|
||||
return f"Updated custom skill '{name}'."
|
||||
|
||||
if action == "patch":
|
||||
@@ -156,7 +168,11 @@ async def _skill_manage_impl(
|
||||
name,
|
||||
_history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
clear_skills_system_prompt_cache()
|
||||
=======
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
>>>>>>> main
|
||||
return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)."
|
||||
|
||||
if action == "delete":
|
||||
@@ -169,7 +185,11 @@ async def _skill_manage_impl(
|
||||
_history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
|
||||
)
|
||||
await _to_thread(shutil.rmtree, skill_dir)
|
||||
<<<<<<< HEAD
|
||||
clear_skills_system_prompt_cache()
|
||||
=======
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
>>>>>>> main
|
||||
return f"Deleted custom skill '{name}'."
|
||||
|
||||
if action == "write_file":
|
||||
|
||||
@@ -7,6 +7,7 @@ dependencies = [
|
||||
"agent-client-protocol>=0.4.0",
|
||||
"agent-sandbox>=0.0.19",
|
||||
"dotenv>=0.9.9",
|
||||
"exa-py>=1.0.0",
|
||||
"httpx>=0.28.0",
|
||||
"kubernetes>=30.0.0",
|
||||
"langchain>=1.2.3",
|
||||
@@ -44,6 +45,7 @@ postgres = [
|
||||
"psycopg[binary]>=3.3.3",
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
ollama = ["langchain-ollama>=0.3.0"]
|
||||
pymupdf = ["pymupdf4llm>=0.0.17"]
|
||||
|
||||
[build-system]
|
||||
|
||||
Reference in New Issue
Block a user