Merge branch 'main' into release/2.0-rc

This commit is contained in:
jiangfeng.11
2026-04-11 10:34:31 +08:00
152 changed files with 16060 additions and 499 deletions
@@ -2,8 +2,14 @@ from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointe
from .factory import create_deerflow_agent
from .features import Next, Prev, RuntimeFeatures
from .lead_agent import make_lead_agent
from .lead_agent.prompt import prime_enabled_skills_cache
from .thread_state import SandboxState, ThreadState
# LangGraph imports deerflow.agents when registering the graph. Prime the
# enabled-skills cache here so the request path can usually read a warm cache
# without forcing synchronous filesystem work during prompt module import.
prime_enabled_skills_cache()
__all__ = [
"create_deerflow_agent",
"RuntimeFeatures",
@@ -17,6 +17,7 @@ For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import AsyncIterator
@@ -54,7 +55,7 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str)
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
@@ -289,14 +289,14 @@ def make_lead_agent(config: RunnableConfig):
agent_name = cfg.get("agent_name")
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
# Custom agent model or fallback to global/default model resolution
agent_model_name = agent_config.model if agent_config and agent_config.model else _resolve_model_name()
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution with request override, then agent config, then global default
model_name = requested_model_name or agent_model_name
# Final model name resolution: request agent config global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
model_config = app_config.get_model_config(model_name)
if model_config is None:
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
@@ -1,20 +1,167 @@
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names
logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
_enabled_skills_lock = threading.Lock()
_enabled_skills_cache: list[Skill] | None = None
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event()
def _load_enabled_skills_sync() -> list[Skill]:
return list(load_skills(enabled_only=True))
def _start_enabled_skills_refresh_thread() -> None:
threading.Thread(
target=_refresh_enabled_skills_cache_worker,
name="deerflow-enabled-skills-loader",
daemon=True,
).start()
def _refresh_enabled_skills_cache_worker() -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active
while True:
with _enabled_skills_lock:
target_version = _enabled_skills_refresh_version
try:
skills = _load_enabled_skills_sync()
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
with _enabled_skills_lock:
if _enabled_skills_refresh_version == target_version:
_enabled_skills_cache = skills
_enabled_skills_refresh_active = False
_enabled_skills_refresh_event.set()
return
# A newer invalidation happened while loading. Keep the worker alive
# and loop again so the cache always converges on the latest version.
_enabled_skills_cache = None
def _ensure_enabled_skills_cache() -> threading.Event:
global _enabled_skills_refresh_active
with _enabled_skills_lock:
if _enabled_skills_cache is not None:
_enabled_skills_refresh_event.set()
return _enabled_skills_refresh_event
if _enabled_skills_refresh_active:
return _enabled_skills_refresh_event
_enabled_skills_refresh_active = True
_enabled_skills_refresh_event.clear()
_start_enabled_skills_refresh_thread()
return _enabled_skills_refresh_event
def _invalidate_enabled_skills_cache() -> threading.Event:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_refresh_version += 1
_enabled_skills_refresh_event.clear()
if _enabled_skills_refresh_active:
return _enabled_skills_refresh_event
_enabled_skills_refresh_active = True
_start_enabled_skills_refresh_thread()
return _enabled_skills_refresh_event
def prime_enabled_skills_cache() -> None:
_ensure_enabled_skills_cache()
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
return True
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
return False
def _get_enabled_skills():
with _enabled_skills_lock:
cached = _enabled_skills_cache
if cached is not None:
return list(cached)
_ensure_enabled_skills_cache()
return []
def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]"
def clear_skills_system_prompt_cache() -> None:
_invalidate_enabled_skills_cache()
async def refresh_skills_system_prompt_cache_async() -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
def _reset_skills_system_prompt_cache_state() -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event.clear()
def _refresh_enabled_skills_cache() -> None:
"""Backward-compatible test helper for direct synchronous reload."""
try:
return list(load_skills(enabled_only=True))
skills = _load_enabled_skills_sync()
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
return []
skills = []
with _enabled_skills_lock:
_enabled_skills_cache = skills
_enabled_skills_refresh_active = False
_enabled_skills_refresh_event.set()
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
if not skill_evolution_enabled:
return ""
return """
## Skill Self-Evolution
After completing a task, consider creating or updating a skill when:
- The task required 5+ tool calls to resolve
- You overcame non-obvious errors or pitfalls
- The user corrected your approach and the corrected version worked
- You discovered a non-trivial, recurring workflow
If you used a skill and encountered issues not covered by it, patch it immediately.
Prefer patch over edit. Before creating a new skill, confirm with the user first.
Skip simple one-off tasks.
"""
def _skill_mutability_label(category: str) -> str:
@@ -294,6 +441,9 @@ You: "Deploying to staging..." [proceed]
- Use `read_file` tool to read uploaded files using their paths from the list
- For PDF, PPT, Excel, and Word files, converted Markdown versions (*.md) are available alongside originals
- All temporary work happens in `/mnt/user-data/workspace`
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
{acp_section}
</working_directory>
@@ -4,7 +4,7 @@ import logging
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime
from datetime import UTC, datetime
from typing import Any
from deerflow.config.memory_config import get_memory_config
@@ -18,7 +18,7 @@ class ConversationContext:
thread_id: str
messages: list[Any]
timestamp: datetime = field(default_factory=datetime.utcnow)
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
agent_name: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
@@ -4,7 +4,7 @@ import abc
import json
import logging
import threading
from datetime import datetime
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
@@ -15,11 +15,16 @@ from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
def utc_now_iso_z() -> str:
"""Current UTC time as ISO-8601 with ``Z`` suffix (matches prior naive-UTC output)."""
return datetime.now(UTC).isoformat().removesuffix("+00:00") + "Z"
def create_empty_memory() -> dict[str, Any]:
"""Create an empty memory structure."""
return {
"version": "1.0",
"lastUpdated": datetime.utcnow().isoformat() + "Z",
"lastUpdated": utc_now_iso_z(),
"user": {
"workContext": {"summary": "", "updatedAt": ""},
"personalContext": {"summary": "", "updatedAt": ""},
@@ -137,7 +142,7 @@ class FileMemoryStorage(MemoryStorage):
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
memory_data["lastUpdated"] = datetime.utcnow().isoformat() + "Z"
memory_data["lastUpdated"] = utc_now_iso_z()
temp_path = file_path.with_suffix(".tmp")
with open(temp_path, "w", encoding="utf-8") as f:
@@ -5,14 +5,17 @@ import logging
import math
import re
import uuid
from datetime import datetime
from typing import Any
from deerflow.agents.memory.prompt import (
MEMORY_UPDATE_PROMPT,
format_conversation_for_update,
)
from deerflow.agents.memory.storage import create_empty_memory, get_memory_storage
from deerflow.agents.memory.storage import (
create_empty_memory,
get_memory_storage,
utc_now_iso_z,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.models import create_chat_model
@@ -86,7 +89,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence)
now = datetime.utcnow().isoformat() + "Z"
now = utc_now_iso_z()
memory_data = get_memory_data(agent_name)
updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", []))
@@ -376,7 +379,7 @@ class MemoryUpdater:
Updated memory data.
"""
config = get_memory_config()
now = datetime.utcnow().isoformat() + "Z"
now = utc_now_iso_z()
# Update user sections
user_updates = update_data.get("user", {})
@@ -1,5 +1,6 @@
"""Middleware for intercepting clarification requests and presenting them to the user."""
import json
import logging
from collections.abc import Callable
from typing import override
@@ -60,6 +61,20 @@ class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
context = args.get("context")
options = args.get("options", [])
# Some models (e.g. Qwen3-Max) serialize array parameters as JSON strings
# instead of native arrays. Deserialize and normalize so `options`
# is always a list for the rendering logic below.
if isinstance(options, str):
try:
options = json.loads(options)
except (json.JSONDecodeError, TypeError):
options = [options]
if options is None:
options = []
elif not isinstance(options, list):
options = [options]
# Type-specific icons
type_icons = {
"missing_info": "",
@@ -33,30 +33,92 @@ _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
"""Normalize tool call args to a dict plus an optional fallback key.
Some providers serialize ``args`` as a JSON string instead of a dict.
We defensively parse those cases so loop detection does not crash while
still preserving a stable fallback key for non-dict payloads.
"""
if isinstance(raw_args, dict):
return raw_args, None
if isinstance(raw_args, str):
try:
parsed = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
return {}, raw_args
if isinstance(parsed, dict):
return parsed, None
return {}, json.dumps(parsed, sort_keys=True, default=str)
if raw_args is None:
return {}, None
return {}, json.dumps(raw_args, sort_keys=True, default=str)
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
"""Derive a stable key from salient args without overfitting to noise."""
if name == "read_file" and fallback_key is None:
path = args.get("path") or ""
start_line = args.get("start_line")
end_line = args.get("end_line")
bucket_size = 200
try:
start_line = int(start_line) if start_line is not None else 1
except (TypeError, ValueError):
start_line = 1
try:
end_line = int(end_line) if end_line is not None else start_line
except (TypeError, ValueError):
end_line = start_line
start_line, end_line = sorted((start_line, end_line))
bucket_start = max(start_line, 1)
bucket_end = max(end_line, 1)
bucket_start = (bucket_start - 1) // bucket_size
bucket_end = (bucket_end - 1) // bucket_size
return f"{path}:{bucket_start}-{bucket_end}"
# write_file / str_replace are content-sensitive: same path may be updated
# with different payloads during iteration. Using only salient fields (path)
# can collapse distinct calls, so we hash full args to reduce false positives.
if name in {"write_file", "str_replace"}:
if fallback_key is not None:
return fallback_key
return json.dumps(args, sort_keys=True, default=str)
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
if stable_args:
return json.dumps(stable_args, sort_keys=True, default=str)
if fallback_key is not None:
return fallback_key
return json.dumps(args, sort_keys=True, default=str)
def _hash_tool_calls(tool_calls: list[dict]) -> str:
"""Deterministic hash of a set of tool calls (name + args).
"""Deterministic hash of a set of tool calls (name + stable key).
This is intended to be order-independent: the same multiset of tool calls
should always produce the same hash, regardless of their input order.
"""
# First normalize each tool call to a minimal (name, args) structure.
normalized: list[dict] = []
# Normalize each tool call to a stable (name, key) structure.
normalized: list[str] = []
for tc in tool_calls:
normalized.append(
{
"name": tc.get("name", ""),
"args": tc.get("args", {}),
}
)
name = tc.get("name", "")
args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
key = _stable_tool_key(name, args, fallback_key)
# Sort by both name and a deterministic serialization of args so that
# permutations of the same multiset of calls yield the same ordering.
normalized.sort(
key=lambda tc: (
tc["name"],
json.dumps(tc["args"], sort_keys=True, default=str),
)
)
normalized.append(f"{name}:{key}")
# Sort so permutations of the same multiset of calls yield the same ordering.
normalized.sort()
blob = json.dumps(normalized, sort_keys=True, default=str)
return hashlib.md5(blob.encode()).hexdigest()[:12]
@@ -23,25 +23,119 @@ logger = logging.getLogger(__name__)
# Each pattern is compiled once at import time.
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"), # rm -rf / /* ~ /home /root
re.compile(r"(curl|wget).+\|\s*(ba)?sh"), # curl|sh, wget|sh
# --- original rules (retained) ---
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
re.compile(r"dd\s+if="),
re.compile(r"mkfs"),
re.compile(r"cat\s+/etc/shadow"),
re.compile(r">\s*/etc/"), # overwrite /etc/ files
re.compile(r">+\s*/etc/"),
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
re.compile(r"\|\s*(ba)?sh\b"),
# --- command substitution (targeted only dangerous executables) ---
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
# --- base64 decode piped to execution ---
re.compile(r"base64\s+.*-d.*\|"),
# --- overwrite system binaries ---
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
# --- overwrite shell startup files ---
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
# --- process environment leakage ---
re.compile(r"/proc/[^/]+/environ"),
# --- dynamic linker hijack (one-step escalation) ---
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
# --- bash built-in networking (bypasses tool allowlists) ---
re.compile(r"/dev/tcp/"),
# --- fork bomb ---
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
]
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"chmod\s+777"), # overly permissive, but reversible
re.compile(r"pip\s+install"),
re.compile(r"pip3\s+install"),
re.compile(r"chmod\s+777"),
re.compile(r"pip3?\s+install"),
re.compile(r"apt(-get)?\s+install"),
# sudo/su: no-op under Docker root; warn so LLM is aware
re.compile(r"\b(sudo|su)\b"),
# PATH modification: long attack chain, warn rather than block
re.compile(r"\bPATH\s*="),
]
def _classify_command(command: str) -> str:
"""Return 'block', 'warn', or 'pass'."""
# Normalize for matching (collapse whitespace)
def _split_compound_command(command: str) -> list[str]:
"""Split a compound command into sub-commands (quote-aware).
Scans the raw command string so unquoted shell control operators are
recognised even when they are not surrounded by whitespace
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
quotes are ignored. If the command ends with an unclosed quote or a
dangling escape, return the whole command unchanged (fail-closed —
safer to classify the unsplit string than silently drop parts).
"""
parts: list[str] = []
current: list[str] = []
in_single_quote = False
in_double_quote = False
escaping = False
index = 0
while index < len(command):
char = command[index]
if escaping:
current.append(char)
escaping = False
index += 1
continue
if char == "\\" and not in_single_quote:
current.append(char)
escaping = True
index += 1
continue
if char == "'" and not in_double_quote:
in_single_quote = not in_single_quote
current.append(char)
index += 1
continue
if char == '"' and not in_single_quote:
in_double_quote = not in_double_quote
current.append(char)
index += 1
continue
if not in_single_quote and not in_double_quote:
if command.startswith("&&", index) or command.startswith("||", index):
part = "".join(current).strip()
if part:
parts.append(part)
current = []
index += 2
continue
if char == ";":
part = "".join(current).strip()
if part:
parts.append(part)
current = []
index += 1
continue
current.append(char)
index += 1
# Unclosed quote or dangling escape → fail-closed, return whole command
if in_single_quote or in_double_quote or escaping:
return [command]
part = "".join(current).strip()
if part:
parts.append(part)
return parts if parts else [command]
def _classify_single_command(command: str) -> str:
"""Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'."""
normalized = " ".join(command.split())
for pattern in _HIGH_RISK_PATTERNS:
@@ -66,6 +160,35 @@ def _classify_command(command: str) -> str:
return "pass"
def _classify_command(command: str) -> str:
"""Return 'block', 'warn', or 'pass'.
Strategy:
1. First scan the *whole* raw command against high-risk patterns. This
catches structural attacks like ``while true; do bash & done`` or
``:(){ :|:& };:`` that span multiple shell statements — splitting them
on ``;`` would destroy the pattern context.
2. Then split compound commands (e.g. ``cmd1 && cmd2 ; cmd3``) and
classify each sub-command independently. The most severe verdict wins.
"""
# Pass 1: whole-command high-risk scan (catches multi-statement patterns)
normalized = " ".join(command.split())
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(normalized):
return "block"
# Pass 2: per-sub-command classification
sub_commands = _split_compound_command(command)
worst = "pass"
for sub in sub_commands:
verdict = _classify_single_command(sub)
if verdict == "block":
return "block" # short-circuit: can't get worse
if verdict == "warn":
worst = "warn"
return worst
# ---------------------------------------------------------------------------
# Middleware
# ---------------------------------------------------------------------------
+291 -49
View File
@@ -25,7 +25,7 @@ import uuid
from collections.abc import Generator, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import Any, Literal
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
@@ -55,6 +55,9 @@ from deerflow.uploads.manager import (
logger = logging.getLogger(__name__)
StreamEventType = Literal["values", "messages-tuple", "custom", "end"]
@dataclass
class StreamEvent:
"""A single event from the streaming agent response.
@@ -69,7 +72,7 @@ class StreamEvent:
data: Event payload. Contents vary by type.
"""
type: str
type: StreamEventType
data: dict[str, Any] = field(default_factory=dict)
@@ -254,13 +257,53 @@ class DeerFlowClient:
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
@staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]:
"""Reshape LangChain tool_calls into the wire format used in events."""
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage:
data["usage_metadata"] = usage
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
"""Build a ``messages-tuple`` AI tool-calls event."""
return StreamEvent(
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
},
)
@staticmethod
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
"""Build a ``messages-tuple`` tool-result event from a ToolMessage."""
return StreamEvent(
type="messages-tuple",
data={
"type": "tool",
"content": DeerFlowClient._extract_text(msg.content),
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"id": msg.id,
},
)
@staticmethod
def _serialize_message(msg) -> dict:
"""Serialize a LangChain message to a plain dict for values events."""
if isinstance(msg, AIMessage):
d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)}
if msg.tool_calls:
d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls]
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
if getattr(msg, "usage_metadata", None):
d["usage_metadata"] = msg.usage_metadata
return d
@@ -315,6 +358,108 @@ class DeerFlowClient:
return "\n".join(pieces) if pieces else ""
return str(content)
# ------------------------------------------------------------------
# Public API — threads
# ------------------------------------------------------------------
def list_threads(self, limit: int = 10) -> dict:
"""List the recent N threads.
Args:
limit: Maximum number of threads to return. Default is 10.
Returns:
Dict with "thread_list" key containing list of thread info dicts,
sorted by thread creation time descending.
"""
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
thread_info_map = {}
for cp in checkpointer.list(config=None, limit=limit):
cfg = cp.config.get("configurable", {})
thread_id = cfg.get("thread_id")
if not thread_id:
continue
ts = cp.checkpoint.get("ts")
checkpoint_id = cfg.get("checkpoint_id")
if thread_id not in thread_info_map:
channel_values = cp.checkpoint.get("channel_values", {})
thread_info_map[thread_id] = {
"thread_id": thread_id,
"created_at": ts,
"updated_at": ts,
"latest_checkpoint_id": checkpoint_id,
"title": channel_values.get("title"),
}
else:
# Explicitly compare timestamps to ensure accuracy when iterating over unordered namespaces.
# Treat None as "missing" and only compare when existing values are non-None.
if ts is not None:
current_created = thread_info_map[thread_id]["created_at"]
if current_created is None or ts < current_created:
thread_info_map[thread_id]["created_at"] = ts
current_updated = thread_info_map[thread_id]["updated_at"]
if current_updated is None or ts > current_updated:
thread_info_map[thread_id]["updated_at"] = ts
thread_info_map[thread_id]["latest_checkpoint_id"] = checkpoint_id
channel_values = cp.checkpoint.get("channel_values", {})
thread_info_map[thread_id]["title"] = channel_values.get("title")
threads = list(thread_info_map.values())
threads.sort(key=lambda x: x.get("created_at") or "", reverse=True)
return {"thread_list": threads[:limit]}
def get_thread(self, thread_id: str) -> dict:
"""Get the complete thread record, including all node execution records.
Args:
thread_id: Thread ID.
Returns:
Dict containing the thread's full checkpoint history.
"""
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
config = {"configurable": {"thread_id": thread_id}}
checkpoints = []
for cp in checkpointer.list(config):
channel_values = dict(cp.checkpoint.get("channel_values", {}))
if "messages" in channel_values:
channel_values["messages"] = [self._serialize_message(m) if hasattr(m, "content") else m for m in channel_values["messages"]]
cfg = cp.config.get("configurable", {})
parent_cfg = cp.parent_config.get("configurable", {}) if cp.parent_config else {}
checkpoints.append(
{
"checkpoint_id": cfg.get("checkpoint_id"),
"parent_checkpoint_id": parent_cfg.get("checkpoint_id"),
"ts": cp.checkpoint.get("ts"),
"metadata": cp.metadata,
"values": channel_values,
"pending_writes": [{"task_id": w[0], "channel": w[1], "value": w[2]} for w in getattr(cp, "pending_writes", [])],
}
)
# Sort globally by timestamp to prevent partial ordering issues caused by different namespaces (e.g., subgraphs)
checkpoints.sort(key=lambda x: x["ts"] if x["ts"] else "")
return {"thread_id": thread_id, "checkpoints": checkpoints}
# ------------------------------------------------------------------
# Public API — conversation
# ------------------------------------------------------------------
@@ -336,6 +481,53 @@ class DeerFlowClient:
consumers can switch between HTTP streaming and embedded mode
without changing their event-handling logic.
Token-level streaming
~~~~~~~~~~~~~~~~~~~~~
This method subscribes to LangGraph's ``messages`` stream mode, so
``messages-tuple`` events for AI text are emitted as **deltas** as
the model generates tokens, not as one cumulative dump at node
completion. Each delta carries a stable ``id`` — consumers that
want the full text must accumulate ``content`` per ``id``.
``chat()`` already does this for you.
Tool calls and tool results are still emitted once per logical
message. ``values`` events continue to carry full state snapshots
after each graph node finishes; AI text already delivered via the
``messages`` stream is **not** re-synthesized from the snapshot to
avoid duplicate deliveries.
Why not reuse Gateway's ``run_agent``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Gateway (``runtime/runs/worker.py``) has a complete streaming
pipeline: ``run_agent`` → ``StreamBridge`` → ``sse_consumer``. It
looks like this client duplicates that work, but the two paths
serve different audiences and **cannot** share execution:
* ``run_agent`` is ``async def`` and uses ``agent.astream()``;
this method is a sync generator using ``agent.stream()`` so
callers can write ``for event in client.stream(...)`` without
touching asyncio. Bridging the two would require spinning up
an event loop + thread per call.
* Gateway events are JSON-serialized by ``serialize()`` for SSE
wire transmission. This client yields in-process stream event
payloads directly as Python data structures (``StreamEvent``
with ``data`` as a plain ``dict``), without the extra
JSON/SSE serialization layer used for HTTP delivery.
* ``StreamBridge`` is an asyncio-queue decoupling producers from
consumers across an HTTP boundary (``Last-Event-ID`` replay,
heartbeats, multi-subscriber fan-out). A single in-process
caller with a direct iterator needs none of that.
So ``DeerFlowClient.stream()`` is a parallel, sync, in-process
consumer of the same ``create_agent()`` factory — not a wrapper
around Gateway. The two paths **should** stay in sync on which
LangGraph stream modes they subscribe to; that invariant is
enforced by ``tests/test_client.py::test_messages_mode_emits_token_deltas``
rather than by a shared constant, because the three layers
(Graph, Platform SDK, HTTP) each use their own naming
(``messages`` vs ``messages-tuple``) and cannot literally share
a string.
Args:
message: User message text.
thread_id: Thread ID for conversation context. Auto-generated if None.
@@ -346,8 +538,8 @@ class DeerFlowClient:
StreamEvent with one of:
- type="values" data={"title": str|None, "messages": [...], "artifacts": [...]}
- type="custom" data={...}
- type="messages-tuple" data={"type": "ai", "content": str, "id": str}
- type="messages-tuple" data={"type": "ai", "content": str, "id": str, "usage_metadata": {...}}
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
@@ -364,13 +556,47 @@ class DeerFlowClient:
context["agent_name"] = self._agent_name
seen_ids: set[str] = set()
# Cross-mode handoff: ids already streamed via LangGraph ``messages``
# mode so the ``values`` path skips re-synthesis of the same message.
streamed_ids: set[str] = set()
# The same message id carries identical cumulative ``usage_metadata``
# in both the final ``messages`` chunk and the values snapshot —
# count it only on whichever arrives first.
counted_usage_ids: set[str] = set()
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
"""Add *usage* to cumulative totals if this id has not been counted.
``usage`` is a ``langchain_core.messages.UsageMetadata`` TypedDict
or ``None``; typed as ``Any`` because TypedDicts are not
structurally assignable to plain ``dict`` under strict type
checking. Returns the normalized usage dict (for attaching
to an event) when we accepted it, otherwise ``None``.
"""
if not usage:
return None
if msg_id and msg_id in counted_usage_ids:
return None
if msg_id:
counted_usage_ids.add(msg_id)
input_tokens = usage.get("input_tokens", 0) or 0
output_tokens = usage.get("output_tokens", 0) or 0
total_tokens = usage.get("total_tokens", 0) or 0
cumulative_usage["input_tokens"] += input_tokens
cumulative_usage["output_tokens"] += output_tokens
cumulative_usage["total_tokens"] += total_tokens
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
}
for item in self._agent.stream(
state,
config=config,
context=context,
stream_mode=["values", "custom"],
stream_mode=["values", "messages", "custom"],
):
if isinstance(item, tuple) and len(item) == 2:
mode, chunk = item
@@ -382,6 +608,36 @@ class DeerFlowClient:
yield StreamEvent(type="custom", data=chunk)
continue
if mode == "messages":
# LangGraph ``messages`` mode emits ``(message_chunk, metadata)``.
if isinstance(chunk, tuple) and len(chunk) == 2:
msg_chunk, _metadata = chunk
else:
msg_chunk = chunk
msg_id = getattr(msg_chunk, "id", None)
if isinstance(msg_chunk, AIMessage):
text = self._extract_text(msg_chunk.content)
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
if text:
if msg_id:
streamed_ids.add(msg_id)
yield self._ai_text_event(msg_id, text, counted_usage)
if msg_chunk.tool_calls:
if msg_id:
streamed_ids.add(msg_id)
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
elif isinstance(msg_chunk, ToolMessage):
if msg_id:
streamed_ids.add(msg_id)
yield self._tool_message_event(msg_chunk)
continue
# mode == "values"
messages = chunk.get("messages", [])
for msg in messages:
@@ -391,47 +647,25 @@ class DeerFlowClient:
if msg_id:
seen_ids.add(msg_id)
# Already streamed via ``messages`` mode; only (defensively)
# capture usage here and skip re-synthesizing the event.
if msg_id and msg_id in streamed_ids:
if isinstance(msg, AIMessage):
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
continue
if isinstance(msg, AIMessage):
# Track token usage from AI messages
usage = getattr(msg, "usage_metadata", None)
if usage:
cumulative_usage["input_tokens"] += usage.get("input_tokens", 0) or 0
cumulative_usage["output_tokens"] += usage.get("output_tokens", 0) or 0
cumulative_usage["total_tokens"] += usage.get("total_tokens", 0) or 0
counted_usage = _account_usage(msg_id, msg.usage_metadata)
if msg.tool_calls:
yield StreamEvent(
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls],
},
)
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
text = self._extract_text(msg.content)
if text:
event_data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage:
event_data["usage_metadata"] = {
"input_tokens": usage.get("input_tokens", 0) or 0,
"output_tokens": usage.get("output_tokens", 0) or 0,
"total_tokens": usage.get("total_tokens", 0) or 0,
}
yield StreamEvent(type="messages-tuple", data=event_data)
yield self._ai_text_event(msg_id, text, counted_usage)
elif isinstance(msg, ToolMessage):
yield StreamEvent(
type="messages-tuple",
data={
"type": "tool",
"content": self._extract_text(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": msg_id,
},
)
yield self._tool_message_event(msg)
# Emit a values event for each state snapshot
yield StreamEvent(
@@ -448,10 +682,12 @@ class DeerFlowClient:
def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str:
"""Send a message and return the final text response.
Convenience wrapper around :meth:`stream` that returns only the
**last** AI text from ``messages-tuple`` events. If the agent emits
multiple text segments in one turn, intermediate segments are
discarded. Use :meth:`stream` directly to capture all events.
Convenience wrapper around :meth:`stream` that accumulates delta
``messages-tuple`` events per ``id`` and returns the text of the
**last** AI message to complete. Intermediate AI messages (e.g.
planner drafts) are discarded — only the final id's accumulated
text is returned. Use :meth:`stream` directly if you need every
delta as it arrives.
Args:
message: User message text.
@@ -459,15 +695,21 @@ class DeerFlowClient:
**kwargs: Override client defaults (same as stream()).
Returns:
The last AI message text, or empty string if no response.
The accumulated text of the last AI message, or empty string
if no AI text was produced.
"""
last_text = ""
# Per-id delta lists joined once at the end — avoids the O(n²) cost
# of repeated ``str + str`` on a growing buffer for long responses.
chunks: dict[str, list[str]] = {}
last_id: str = ""
for event in self.stream(message, thread_id=thread_id, **kwargs):
if event.type == "messages-tuple" and event.data.get("type") == "ai":
content = event.data.get("content", "")
if content:
last_text = content
return last_text
msg_id = event.data.get("id") or ""
delta = event.data.get("content", "")
if delta:
chunks.setdefault(msg_id, []).append(delta)
last_id = msg_id
return "".join(chunks.get(last_id, ()))
# ------------------------------------------------------------------
# Public API — configuration queries
@@ -112,6 +112,9 @@ class AioSandboxProvider(SandboxProvider):
atexit.register(self.shutdown)
self._register_signal_handlers()
# Reconcile orphaned containers from previous process lifecycles
self._reconcile_orphans()
# Start idle checker if enabled
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
self._start_idle_checker()
@@ -175,6 +178,51 @@ class AioSandboxProvider(SandboxProvider):
resolved[key] = str(value)
return resolved
# ── Startup reconciliation ────────────────────────────────────────────
def _reconcile_orphans(self) -> None:
"""Reconcile orphaned containers left by previous process lifecycles.
On startup, enumerate all running containers matching our prefix
and adopt them all into the warm pool. The idle checker will reclaim
containers that nobody re-acquires within ``idle_timeout``.
All containers are adopted unconditionally because we cannot
distinguish "orphaned" from "actively used by another process"
based on age alone — ``idle_timeout`` represents inactivity, not
uptime. Adopting into the warm pool and letting the idle checker
decide avoids destroying containers that a concurrent process may
still be using.
This closes the fundamental gap where in-memory state loss (process
restart, crash, SIGKILL) leaves Docker containers running forever.
"""
try:
running = self._backend.list_running()
except Exception as e:
logger.warning(f"Failed to enumerate running containers during startup reconciliation: {e}")
return
if not running:
return
current_time = time.time()
adopted = 0
for info in running:
age = current_time - info.created_at if info.created_at > 0 else float("inf")
# Single lock acquisition per container: atomic check-and-insert.
# Avoids a TOCTOU window between the "already tracked?" check and
# the warm-pool insert.
with self._lock:
if info.sandbox_id in self._sandboxes or info.sandbox_id in self._warm_pool:
continue
self._warm_pool[info.sandbox_id] = (info, current_time)
adopted += 1
logger.info(f"Adopted container {info.sandbox_id} into warm pool (age: {age:.0f}s)")
logger.info(f"Startup reconciliation complete: {adopted} adopted into warm pool, {len(running)} total found")
# ── Deterministic ID ─────────────────────────────────────────────────
@staticmethod
@@ -316,13 +364,23 @@ class AioSandboxProvider(SandboxProvider):
# ── Signal handling ──────────────────────────────────────────────────
def _register_signal_handlers(self) -> None:
"""Register signal handlers for graceful shutdown."""
"""Register signal handlers for graceful shutdown.
Handles SIGTERM, SIGINT, and SIGHUP (terminal close) to ensure
sandbox containers are cleaned up even when the user closes the terminal.
"""
self._original_sigterm = signal.getsignal(signal.SIGTERM)
self._original_sigint = signal.getsignal(signal.SIGINT)
self._original_sighup = signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None
def signal_handler(signum, frame):
self.shutdown()
original = self._original_sigterm if signum == signal.SIGTERM else self._original_sigint
if signum == signal.SIGTERM:
original = self._original_sigterm
elif hasattr(signal, "SIGHUP") and signum == signal.SIGHUP:
original = self._original_sighup
else:
original = self._original_sigint
if callable(original):
original(signum, frame)
elif original == signal.SIG_DFL:
@@ -332,6 +390,8 @@ class AioSandboxProvider(SandboxProvider):
try:
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, signal_handler)
except ValueError:
logger.debug("Could not register signal handlers (not main thread)")
@@ -96,3 +96,19 @@ class SandboxBackend(ABC):
SandboxInfo if found and healthy, None otherwise.
"""
...
def list_running(self) -> list[SandboxInfo]:
"""Enumerate all running sandboxes managed by this backend.
Used for startup reconciliation: when the process restarts, it needs
to discover containers started by previous processes so they can be
adopted into the warm pool or destroyed if idle too long.
The default implementation returns an empty list, which is correct
for backends that don't manage local containers (e.g., RemoteSandboxBackend
delegates lifecycle to the provisioner which handles its own cleanup).
Returns:
A list of SandboxInfo for all currently running sandboxes.
"""
return []
@@ -6,9 +6,11 @@ Handles container lifecycle, port allocation, and cross-process container discov
from __future__ import annotations
import json
import logging
import os
import subprocess
from datetime import datetime
from deerflow.utils.network import get_free_port, release_port
@@ -18,6 +20,52 @@ from .sandbox_info import SandboxInfo
logger = logging.getLogger(__name__)
def _parse_docker_timestamp(raw: str) -> float:
"""Parse Docker's ISO 8601 timestamp into a Unix epoch float.
Docker returns timestamps with nanosecond precision and a trailing ``Z``
(e.g. ``2026-04-08T01:22:50.123456789Z``). Python's ``fromisoformat``
accepts at most microseconds and (pre-3.11) does not accept ``Z``, so the
string is normalized before parsing. Returns ``0.0`` on empty input or
parse failure so callers can use ``0.0`` as a sentinel for "unknown age".
"""
if not raw:
return 0.0
try:
s = raw.strip()
if "." in s:
dot_pos = s.index(".")
tz_start = dot_pos + 1
while tz_start < len(s) and s[tz_start].isdigit():
tz_start += 1
frac = s[dot_pos + 1 : tz_start][:6] # truncate to microseconds
tz_suffix = s[tz_start:]
s = s[: dot_pos + 1] + frac + tz_suffix
if s.endswith("Z"):
s = s[:-1] + "+00:00"
return datetime.fromisoformat(s).timestamp()
except (ValueError, TypeError) as e:
logger.debug(f"Could not parse docker timestamp {raw!r}: {e}")
return 0.0
def _extract_host_port(inspect_entry: dict, container_port: int) -> int | None:
"""Extract the host port mapped to ``container_port/tcp`` from a docker inspect entry.
Returns None if the container has no port mapping for that port.
"""
try:
ports = (inspect_entry.get("NetworkSettings") or {}).get("Ports") or {}
bindings = ports.get(f"{container_port}/tcp") or []
if bindings:
host_port = bindings[0].get("HostPort")
if host_port:
return int(host_port)
except (ValueError, TypeError, AttributeError):
pass
return None
def _format_container_mount(runtime: str, host_path: str, container_path: str, read_only: bool) -> list[str]:
"""Format a bind-mount argument for the selected runtime.
@@ -172,8 +220,12 @@ class LocalContainerBackend(SandboxBackend):
def destroy(self, info: SandboxInfo) -> None:
"""Stop the container and release its port."""
if info.container_id:
self._stop_container(info.container_id)
# Prefer container_id, fall back to container_name (both accepted by docker stop).
# This ensures containers discovered via list_running() (which only has the name)
# can also be stopped.
stop_target = info.container_id or info.container_name
if stop_target:
self._stop_container(stop_target)
# Extract port from sandbox_url for release
try:
from urllib.parse import urlparse
@@ -222,6 +274,129 @@ class LocalContainerBackend(SandboxBackend):
container_name=container_name,
)
def list_running(self) -> list[SandboxInfo]:
"""Enumerate all running containers matching the configured prefix.
Uses a single ``docker ps`` call to list container names, then a
single batched ``docker inspect`` call to retrieve creation timestamp
and port mapping for all containers at once. Total subprocess calls:
2 (down from 2N+1 in the naive per-container approach).
Note: Docker's ``--filter name=`` performs *substring* matching,
so a secondary ``startswith`` check is applied to ensure only
containers with the exact prefix are included.
Containers without port mappings are still included (with empty
sandbox_url) so that startup reconciliation can adopt orphans
regardless of their port state.
"""
# Step 1: enumerate container names via docker ps
try:
result = subprocess.run(
[
self._runtime,
"ps",
"--filter",
f"name={self._container_prefix}-",
"--format",
"{{.Names}}",
],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode != 0:
stderr = (result.stderr or "").strip()
logger.warning(
"Failed to list running containers with %s ps (returncode=%s, stderr=%s)",
self._runtime,
result.returncode,
stderr or "<empty>",
)
return []
if not result.stdout.strip():
return []
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
logger.warning(f"Failed to list running containers: {e}")
return []
# Filter to names matching our exact prefix (docker filter is substring-based)
container_names = [name.strip() for name in result.stdout.strip().splitlines() if name.strip().startswith(self._container_prefix + "-")]
if not container_names:
return []
# Step 2: batched docker inspect — single subprocess call for all containers
inspections = self._batch_inspect(container_names)
infos: list[SandboxInfo] = []
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
for container_name in container_names:
data = inspections.get(container_name)
if data is None:
# Container disappeared between ps and inspect, or inspect failed
continue
created_at, host_port = data
sandbox_id = container_name[len(self._container_prefix) + 1 :]
sandbox_url = f"http://{sandbox_host}:{host_port}" if host_port else ""
infos.append(
SandboxInfo(
sandbox_id=sandbox_id,
sandbox_url=sandbox_url,
container_name=container_name,
created_at=created_at,
)
)
logger.info(f"Found {len(infos)} running sandbox container(s)")
return infos
def _batch_inspect(self, container_names: list[str]) -> dict[str, tuple[float, int | None]]:
"""Batch-inspect containers in a single subprocess call.
Returns a mapping of ``container_name -> (created_at, host_port)``.
Missing containers or parse failures are silently dropped from the result.
"""
if not container_names:
return {}
try:
result = subprocess.run(
[self._runtime, "inspect", *container_names],
capture_output=True,
text=True,
timeout=15,
)
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
logger.warning(f"Failed to batch-inspect containers: {e}")
return {}
if result.returncode != 0:
stderr = (result.stderr or "").strip()
logger.warning(
"Failed to batch-inspect containers with %s inspect (returncode=%s, stderr=%s)",
self._runtime,
result.returncode,
stderr or "<empty>",
)
return {}
try:
payload = json.loads(result.stdout or "[]")
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse docker inspect output as JSON: {e}")
return {}
out: dict[str, tuple[float, int | None]] = {}
for entry in payload:
# ``Name`` is prefixed with ``/`` in the docker inspect response
name = (entry.get("Name") or "").lstrip("/")
if not name:
continue
created_at = _parse_docker_timestamp(entry.get("Created", ""))
host_port = _extract_host_port(entry, 8080)
out[name] = (created_at, host_port)
return out
# ── Container operations ─────────────────────────────────────────────
def _start_container(
@@ -0,0 +1,79 @@
import json
from exa_py import Exa
from langchain.tools import tool
from deerflow.config import get_app_config
def _get_exa_client(tool_name: str = "web_search") -> Exa:
config = get_app_config().get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
return Exa(api_key=api_key)
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
try:
config = get_app_config().get_tool_config("web_search")
max_results = 5
search_type = "auto"
contents_max_characters = 1000
if config is not None:
max_results = config.model_extra.get("max_results", max_results)
search_type = config.model_extra.get("search_type", search_type)
contents_max_characters = config.model_extra.get("contents_max_characters", contents_max_characters)
client = _get_exa_client()
res = client.search(
query,
type=search_type,
num_results=max_results,
contents={"highlights": {"max_characters": contents_max_characters}},
)
normalized_results = [
{
"title": result.title or "",
"url": result.url or "",
"snippet": "\n".join(result.highlights) if result.highlights else "",
}
for result in res.results
]
json_results = json.dumps(normalized_results, indent=2, ensure_ascii=False)
return json_results
except Exception as e:
return f"Error: {str(e)}"
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
Do NOT add www. to URLs that do NOT have them.
URLs must include the schema: https://example.com is a valid URL while example.com is an invalid URL.
Args:
url: The URL to fetch the contents of.
"""
try:
client = _get_exa_client("web_fetch")
res = client.get_contents([url], text={"max_characters": 4096})
if res.results:
result = res.results[0]
title = result.title or "Untitled"
text = result.text or ""
return f"# {title}\n\n{text[:4096]}"
else:
return "Error: No results found"
except Exception as e:
return f"Error: {str(e)}"
@@ -6,10 +6,10 @@ from langchain.tools import tool
from deerflow.config import get_app_config
def _get_firecrawl_client() -> FirecrawlApp:
config = get_app_config().get_tool_config("web_search")
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
config = get_app_config().get_tool_config(tool_name)
api_key = None
if config is not None:
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
return FirecrawlApp(api_key=api_key) # type: ignore[arg-type]
@@ -27,7 +27,7 @@ def web_search_tool(query: str) -> str:
if config is not None:
max_results = config.model_extra.get("max_results", max_results)
client = _get_firecrawl_client()
client = _get_firecrawl_client("web_search")
result = client.search(query, limit=max_results)
# result.web contains list of SearchResultWeb objects
@@ -58,7 +58,7 @@ def web_fetch_tool(url: str) -> str:
url: The URL to fetch the contents of.
"""
try:
client = _get_firecrawl_client()
client = _get_firecrawl_client("web_fetch")
result = client.scrape(url, formats=["markdown"])
markdown_content = result.markdown or ""
@@ -27,6 +27,10 @@ class ModelConfig(BaseModel):
default_factory=lambda: None,
description="Extra settings to be passed to the model when thinking is enabled",
)
when_thinking_disabled: dict | None = Field(
default_factory=lambda: None,
description="Extra settings to be passed to the model when thinking is disabled",
)
supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs")
thinking: dict | None = Field(
default_factory=lambda: None,
@@ -56,6 +56,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
"supports_thinking",
"supports_reasoning_effort",
"when_thinking_enabled",
"when_thinking_disabled",
"thinking",
"supports_vision",
},
@@ -72,21 +73,24 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
raise ValueError(f"Model {name} does not support thinking. Set `supports_thinking` to true in the `config.yaml` to enable thinking.") from None
if effective_wte:
model_settings_from_config.update(effective_wte)
if not thinking_enabled and has_thinking_settings:
if effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
if not thinking_enabled:
if model_config.when_thinking_disabled is not None:
# User-provided disable settings take full precedence
model_settings_from_config.update(model_config.when_thinking_disabled)
elif has_thinking_settings and effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
# OpenAI-compatible gateway: thinking is nested under extra_body
model_settings_from_config["extra_body"] = _deep_merge_dicts(
model_settings_from_config.get("extra_body"),
{"thinking": {"type": "disabled"}},
)
model_settings_from_config["reasoning_effort"] = "minimal"
elif disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {}):
elif has_thinking_settings and (disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {})):
# vLLM uses chat template kwargs to switch thinking on/off.
model_settings_from_config["extra_body"] = _deep_merge_dicts(
model_settings_from_config.get("extra_body"),
{"chat_template_kwargs": disable_chat_template_kwargs},
)
elif effective_wte.get("thinking", {}).get("type"):
elif has_thinking_settings and effective_wte.get("thinking", {}).get("type"):
# Native langchain_anthropic: thinking is a direct constructor parameter
model_settings_from_config["thinking"] = {"type": "disabled"}
if not model_config.supports_reasoning_effort:
@@ -48,6 +48,10 @@ class CodexChatModel(BaseChatModel):
model_config = {"arbitrary_types_allowed": True}
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def _llm_type(self) -> str:
return "codex-responses"
@@ -216,18 +220,48 @@ class CodexChatModel(BaseChatModel):
def _stream_response(self, headers: dict, payload: dict) -> dict:
"""Stream SSE from Codex API and collect the final response."""
completed_response = None
streamed_output_items: dict[int, dict[str, Any]] = {}
with httpx.Client(timeout=300) as client:
with client.stream("POST", f"{CODEX_BASE_URL}/responses", headers=headers, json=payload) as resp:
resp.raise_for_status()
for line in resp.iter_lines():
data = self._parse_sse_data_line(line)
if data and data.get("type") == "response.completed":
if not data:
continue
event_type = data.get("type")
if event_type == "response.output_item.done":
output_index = data.get("output_index")
output_item = data.get("item")
if isinstance(output_index, int) and isinstance(output_item, dict):
streamed_output_items[output_index] = output_item
elif event_type == "response.completed":
completed_response = data["response"]
if not completed_response:
raise RuntimeError("Codex API stream ended without response.completed event")
# ChatGPT Codex can emit the final assistant content only in stream events.
# When response.completed arrives, response.output may still be empty.
if streamed_output_items:
merged_output = []
response_output = completed_response.get("output")
if isinstance(response_output, list):
merged_output = list(response_output)
max_index = max(max(streamed_output_items), len(merged_output) - 1)
if max_index >= 0 and len(merged_output) <= max_index:
merged_output.extend([None] * (max_index + 1 - len(merged_output)))
for output_index, output_item in streamed_output_items.items():
existing_item = merged_output[output_index]
if not isinstance(existing_item, dict):
merged_output[output_index] = output_item
completed_response = dict(completed_response)
completed_response["output"] = [item for item in merged_output if isinstance(item, dict)]
return completed_response
@staticmethod
@@ -23,6 +23,14 @@ class PatchedChatDeepSeek(ChatDeepSeek):
request payload.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_secrets(self) -> dict[str, str]:
return {"api_key": "DEEPSEEK_API_KEY", "openai_api_key": "DEEPSEEK_API_KEY"}
def _get_request_payload(
self,
input_: LanguageModelInput,
@@ -16,6 +16,8 @@ internal checkpoint callbacks that are not exposed in the Python public API.
from __future__ import annotations
import asyncio
import copy
import inspect
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
@@ -79,6 +81,9 @@ async def run_agent(
run_id = record.run_id
thread_id = record.thread_id
requested_modes: set[str] = set(stream_modes or ["values"])
pre_run_checkpoint_id: str | None = None
pre_run_snapshot: dict[str, Any] | None = None
snapshot_capture_failed = False
# Initialize RunJournal for event capture
journal = None
@@ -120,15 +125,23 @@ async def run_agent(
# 1. Mark running
await run_manager.set_status(run_id, RunStatus.running)
# Record pre-run checkpoint_id to support rollback (Phase 2).
pre_run_checkpoint_id = None
try:
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
if ckpt_tuple is not None:
pre_run_checkpoint_id = getattr(ckpt_tuple, "config", {}).get("configurable", {}).get("checkpoint_id")
except Exception:
logger.debug("Could not get pre-run checkpoint_id for run %s", run_id)
# Snapshot the latest pre-run checkpoint so rollback can restore it.
if checkpointer is not None:
try:
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
if ckpt_tuple is not None:
ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {})
pre_run_checkpoint_id = ckpt_config.get("checkpoint_id")
pre_run_snapshot = {
"checkpoint_ns": ckpt_config.get("checkpoint_ns", ""),
"checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})),
"metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})),
"pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []),
}
except Exception:
snapshot_capture_failed = True
logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True)
# 2. Publish metadata — useStream needs both run_id AND thread_id
await bridge.publish(
@@ -234,17 +247,18 @@ async def run_agent(
action = record.abort_action
if action == "rollback":
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
# TODO(Phase 2): Implement full checkpoint rollback.
# Use pre_run_checkpoint_id to revert the thread's checkpoint
# to the state before this run started. Requires a
# checkpointer.adelete() or equivalent API.
try:
if checkpointer is not None and pre_run_checkpoint_id is not None:
# Phase 2: roll back to pre_run_checkpoint_id
pass
logger.info("Run %s rolled back", run_id)
await _rollback_to_pre_run_checkpoint(
checkpointer=checkpointer,
thread_id=thread_id,
run_id=run_id,
pre_run_checkpoint_id=pre_run_checkpoint_id,
pre_run_snapshot=pre_run_snapshot,
snapshot_capture_failed=snapshot_capture_failed,
)
logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id)
except Exception:
logger.warning("Failed to rollback checkpoint for run %s", run_id)
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
else:
await run_manager.set_status(run_id, RunStatus.interrupted)
else:
@@ -254,7 +268,18 @@ async def run_agent(
action = record.abort_action
if action == "rollback":
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
logger.info("Run %s was cancelled (rollback)", run_id)
try:
await _rollback_to_pre_run_checkpoint(
checkpointer=checkpointer,
thread_id=thread_id,
run_id=run_id,
pre_run_checkpoint_id=pre_run_checkpoint_id,
pre_run_snapshot=pre_run_snapshot,
snapshot_capture_failed=snapshot_capture_failed,
)
logger.info("Run %s was cancelled and rolled back", run_id)
except Exception:
logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True)
else:
await run_manager.set_status(run_id, RunStatus.interrupted)
logger.info("Run %s was cancelled", run_id)
@@ -313,6 +338,104 @@ async def run_agent(
# ---------------------------------------------------------------------------
async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any:
"""Call a checkpointer method, supporting async and sync variants."""
method = getattr(checkpointer, async_name, None) or getattr(checkpointer, sync_name, None)
if method is None:
raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}")
result = method(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return result
async def _rollback_to_pre_run_checkpoint(
*,
checkpointer: Any,
thread_id: str,
run_id: str,
pre_run_checkpoint_id: str | None,
pre_run_snapshot: dict[str, Any] | None,
snapshot_capture_failed: bool,
) -> None:
"""Restore thread state to the checkpoint snapshot captured before run start."""
if checkpointer is None:
logger.info("Run %s rollback requested but no checkpointer is configured", run_id)
return
if snapshot_capture_failed:
logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id)
return
if pre_run_snapshot is None:
await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id)
logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id)
return
checkpoint_to_restore = None
metadata_to_restore: dict[str, Any] = {}
checkpoint_ns = ""
checkpoint = pre_run_snapshot.get("checkpoint")
if not isinstance(checkpoint, dict):
logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id)
return
checkpoint_to_restore = checkpoint
if checkpoint_to_restore.get("id") is None and pre_run_checkpoint_id is not None:
checkpoint_to_restore = {**checkpoint_to_restore, "id": pre_run_checkpoint_id}
if checkpoint_to_restore.get("id") is None:
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
return
metadata = pre_run_snapshot.get("metadata", {})
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
checkpoint_ns = raw_checkpoint_ns if isinstance(raw_checkpoint_ns, str) else ""
channel_versions = checkpoint_to_restore.get("channel_versions")
new_versions = dict(channel_versions) if isinstance(channel_versions, dict) else {}
restore_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
restored_config = await _call_checkpointer_method(
checkpointer,
"aput",
"put",
restore_config,
checkpoint_to_restore,
metadata_to_restore if isinstance(metadata_to_restore, dict) else {},
new_versions,
)
if not isinstance(restored_config, dict):
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict")
restored_configurable = restored_config.get("configurable", {})
if not isinstance(restored_configurable, dict):
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload")
restored_checkpoint_id = restored_configurable.get("checkpoint_id")
if not restored_checkpoint_id:
raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id")
pending_writes = pre_run_snapshot.get("pending_writes", [])
if not pending_writes:
return
writes_by_task: dict[str, list[tuple[str, Any]]] = {}
for item in pending_writes:
if not isinstance(item, (tuple, list)) or len(item) != 3:
raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {item!r}")
task_id, channel, value = item
if not isinstance(channel, str):
raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}")
writes_by_task.setdefault(str(task_id), []).append((channel, value))
for task_id, writes in writes_by_task.items():
await _call_checkpointer_method(
checkpointer,
"aput_writes",
"put_writes",
restored_config,
writes,
task_id=task_id,
)
def _lg_mode_to_sse_event(mode: str) -> str:
"""Map LangGraph internal stream_mode name to SSE event name.
@@ -1,8 +1,12 @@
import threading
import weakref
from deerflow.sandbox.sandbox import Sandbox
_FILE_OPERATION_LOCKS: dict[tuple[str, str], threading.Lock] = {}
# Use WeakValueDictionary to prevent memory leak in long-running processes.
# Locks are automatically removed when no longer referenced by any thread.
_LockKey = tuple[str, str]
_FILE_OPERATION_LOCKS: weakref.WeakValueDictionary[_LockKey, threading.Lock] = weakref.WeakValueDictionary()
_FILE_OPERATION_LOCKS_GUARD = threading.Lock()
@@ -20,7 +20,8 @@ Do NOT use for simple single commands - use bash tool directly instead.""",
- Use parallel execution when commands are independent
- Report both stdout and stderr when relevant
- Handle errors gracefully and explain what went wrong
- Use absolute paths for file operations
- Use workspace-relative paths for files under the default workspace, uploads, and outputs directories
- Use absolute paths only when the task references deployment-configured custom mounts outside the default workspace layout
- Be cautious with destructive operations (rm, overwrite, etc.)
</guidelines>
@@ -38,6 +39,8 @@ You have access to the sandbox environment:
- User workspace: `/mnt/user-data/workspace`
- Output files: `/mnt/user-data/outputs`
- Deployment-configured custom mounts may also be available at other absolute container paths; use them directly when the task references those mounted directories
- Treat `/mnt/user-data/workspace` as the default working directory for file IO
- Prefer relative paths from the workspace, such as `hello.txt`, `../uploads/input.csv`, and `../outputs/result.md`, when composing commands or helper scripts
</working_directory>
""",
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
@@ -39,6 +39,8 @@ You have access to the same sandbox environment as the parent agent:
- User workspace: `/mnt/user-data/workspace`
- Output files: `/mnt/user-data/outputs`
- Deployment-configured custom mounts may also be available at other absolute container paths; use them directly when the task references those mounted directories
- Treat `/mnt/user-data/workspace` as the default working directory for coding and file IO
- Prefer relative paths from the workspace, such as `hello.txt`, `../uploads/input.csv`, and `../outputs/result.md`, when writing scripts or shell commands
</working_directory>
""",
tools=None, # Inherit all tools from parent
@@ -6,7 +6,7 @@ import threading
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
@@ -30,6 +30,7 @@ class SubagentStatus(Enum):
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
TIMED_OUT = "timed_out"
@@ -56,6 +57,7 @@ class SubagentResult:
started_at: datetime | None = None
completed_at: datetime | None = None
ai_messages: list[dict[str, Any]] | None = None
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
def __post_init__(self):
"""Initialize mutable defaults."""
@@ -74,6 +76,9 @@ _scheduler_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent
# Larger pool to avoid blocking when scheduler submits execution tasks
_execution_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-exec-")
# Dedicated pool for sync execute() calls made from an already-running event loop.
_isolated_loop_pool = ThreadPoolExecutor(max_workers=3, thread_name_prefix="subagent-isolated-")
def _filter_tools(
all_tools: list[BaseTool],
@@ -241,7 +246,31 @@ class SubagentExecutor:
# Use stream instead of invoke to get real-time updates
# This allows us to collect AI messages as they are generated
final_state = None
# Pre-check: bail out immediately if already cancelled before streaming starts
if result.cancel_event.is_set():
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled before streaming")
with _background_tasks_lock:
if result.status == SubagentStatus.RUNNING:
result.status = SubagentStatus.CANCELLED
result.error = "Cancelled by user"
result.completed_at = datetime.now()
return result
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
# Cooperative cancellation: check if parent requested stop.
# Note: cancellation is only detected at astream iteration boundaries,
# so long-running tool calls within a single iteration will not be
# interrupted until the next chunk is yielded.
if result.cancel_event.is_set():
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled by parent")
with _background_tasks_lock:
if result.status == SubagentStatus.RUNNING:
result.status = SubagentStatus.CANCELLED
result.error = "Cancelled by user"
result.completed_at = datetime.now()
return result
final_state = chunk
# Extract AI messages from the current state
@@ -348,12 +377,55 @@ class SubagentExecutor:
return result
def _execute_in_isolated_loop(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute the subagent in a completely fresh event loop.
This method is designed to run in a separate thread to ensure complete
isolation from any parent event loop, preventing conflicts with asyncio
primitives that may be bound to the parent loop (e.g., httpx clients).
"""
try:
previous_loop = asyncio.get_event_loop()
except RuntimeError:
previous_loop = None
# Create and set a new event loop for this thread
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return loop.run_until_complete(self._aexecute(task, result_holder))
finally:
try:
pending = asyncio.all_tasks(loop)
if pending:
for task_obj in pending:
task_obj.cancel()
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
except Exception:
logger.debug(
f"[trace={self.trace_id}] Failed while cleaning up isolated event loop for subagent {self.config.name}",
exc_info=True,
)
finally:
try:
loop.close()
finally:
asyncio.set_event_loop(previous_loop)
def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute a task synchronously (wrapper around async execution).
This method runs the async execution in a new event loop, allowing
asynchronous tools (like MCP tools) to be used within the thread pool.
When called from within an already-running event loop (e.g., when the
parent agent is async), this method isolates the subagent execution in
a separate thread to avoid event loop conflicts with shared async
primitives like httpx clients.
Args:
task: The task description for the subagent.
result_holder: Optional pre-created result object to update during execution.
@@ -361,16 +433,18 @@ class SubagentExecutor:
Returns:
SubagentResult with the execution result.
"""
# Run the async execution in a new event loop
# This is necessary because:
# 1. We may have async-only tools (like MCP tools)
# 2. We're running inside a ThreadPoolExecutor which doesn't have an event loop
#
# Note: _aexecute() catches all exceptions internally, so this outer
# try-except only handles asyncio.run() failures (e.g., if called from
# an async context where an event loop already exists). Subagent execution
# errors are handled within _aexecute() and returned as FAILED status.
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
logger.debug(f"[trace={self.trace_id}] Subagent {self.config.name} detected running event loop, using isolated thread")
future = _isolated_loop_pool.submit(self._execute_in_isolated_loop, task, result_holder)
return future.result()
# Standard path: no running event loop, use asyncio.run
return asyncio.run(self._aexecute(task, result_holder))
except Exception as e:
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} execution failed")
@@ -437,10 +511,12 @@ class SubagentExecutor:
except FuturesTimeoutError:
logger.error(f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s")
with _background_tasks_lock:
_background_tasks[task_id].status = SubagentStatus.TIMED_OUT
_background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds"
_background_tasks[task_id].completed_at = datetime.now()
# Cancel the future (best effort - may not stop the actual execution)
if _background_tasks[task_id].status == SubagentStatus.RUNNING:
_background_tasks[task_id].status = SubagentStatus.TIMED_OUT
_background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds"
_background_tasks[task_id].completed_at = datetime.now()
# Signal cooperative cancellation and cancel the future
result_holder.cancel_event.set()
execution_future.cancel()
except Exception as e:
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
@@ -456,6 +532,24 @@ class SubagentExecutor:
MAX_CONCURRENT_SUBAGENTS = 3
def request_cancel_background_task(task_id: str) -> None:
"""Signal a running background task to stop.
Sets the cancel_event on the task, which is checked cooperatively
by ``_aexecute`` during ``agent.astream()`` iteration. This allows
subagent threads — which cannot be force-killed via ``Future.cancel()``
— to stop at the next iteration boundary.
Args:
task_id: The task ID to cancel.
"""
with _background_tasks_lock:
result = _background_tasks.get(task_id)
if result is not None:
result.cancel_event.set()
logger.info("Requested cancellation for background task %s", task_id)
def get_background_task_result(task_id: str) -> SubagentResult | None:
"""Get the result of a background task.
@@ -503,6 +597,7 @@ def cleanup_background_task(task_id: str) -> None:
is_terminal_status = result.status in {
SubagentStatus.COMPLETED,
SubagentStatus.FAILED,
SubagentStatus.CANCELLED,
SubagentStatus.TIMED_OUT,
}
if is_terminal_status or result.completed_at is not None:
@@ -14,7 +14,7 @@ from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
from deerflow.agents.thread_state import ThreadState
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
logger = logging.getLogger(__name__)
@@ -182,6 +182,11 @@ async def task_tool(
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.CANCELLED:
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
cleanup_background_task(task_id)
return "Task cancelled by user."
elif result.status == SubagentStatus.TIMED_OUT:
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
@@ -204,6 +209,11 @@ async def task_tool(
writer({"type": "task_timed_out", "task_id": task_id})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError:
# Signal the background subagent thread to stop cooperatively.
# Without this, the thread (running in ThreadPoolExecutor with its
# own event loop via asyncio.run) would continue executing even
# after the parent task is cancelled.
request_cancel_background_task(task_id)
async def cleanup_when_done() -> None:
max_cleanup_polls = max_poll_count
@@ -214,7 +224,7 @@ async def task_tool(
if result is None:
return
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
cleanup_background_task(task_id)
return
@@ -11,7 +11,11 @@ from weakref import WeakValueDictionary
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
<<<<<<< HEAD
from deerflow.agents.lead_agent.prompt import clear_skills_system_prompt_cache
=======
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
>>>>>>> main
from deerflow.agents.thread_state import ThreadState
from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.manager import (
@@ -115,7 +119,11 @@ async def _skill_manage_impl(
name,
_history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
)
<<<<<<< HEAD
clear_skills_system_prompt_cache()
=======
await refresh_skills_system_prompt_cache_async()
>>>>>>> main
return f"Created custom skill '{name}'."
if action == "edit":
@@ -132,7 +140,11 @@ async def _skill_manage_impl(
name,
_history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
)
<<<<<<< HEAD
clear_skills_system_prompt_cache()
=======
await refresh_skills_system_prompt_cache_async()
>>>>>>> main
return f"Updated custom skill '{name}'."
if action == "patch":
@@ -156,7 +168,11 @@ async def _skill_manage_impl(
name,
_history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
)
<<<<<<< HEAD
clear_skills_system_prompt_cache()
=======
await refresh_skills_system_prompt_cache_async()
>>>>>>> main
return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)."
if action == "delete":
@@ -169,7 +185,11 @@ async def _skill_manage_impl(
_history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
)
await _to_thread(shutil.rmtree, skill_dir)
<<<<<<< HEAD
clear_skills_system_prompt_cache()
=======
await refresh_skills_system_prompt_cache_async()
>>>>>>> main
return f"Deleted custom skill '{name}'."
if action == "write_file":
+2
View File
@@ -7,6 +7,7 @@ dependencies = [
"agent-client-protocol>=0.4.0",
"agent-sandbox>=0.0.19",
"dotenv>=0.9.9",
"exa-py>=1.0.0",
"httpx>=0.28.0",
"kubernetes>=30.0.0",
"langchain>=1.2.3",
@@ -44,6 +45,7 @@ postgres = [
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
ollama = ["langchain-ollama>=0.3.0"]
pymupdf = ["pymupdf4llm>=0.0.17"]
[build-system]