Merge branch 'main' into rayhpeng/persistence-scaffold

# Conflicts:
#	.env.example
#	backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
This commit is contained in:
rayhpeng
2026-04-04 21:28:07 +08:00
180 changed files with 10945 additions and 787 deletions
@@ -345,6 +345,8 @@ def make_lead_agent(config: RunnableConfig):
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
),
state_schema=ThreadState,
)
@@ -8,6 +8,14 @@ from deerflow.subagents import get_available_subagent_names
logger = logging.getLogger(__name__)
def _get_enabled_skills():
try:
return list(load_skills(enabled_only=True))
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
return []
def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
@@ -386,7 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
Returns the <skill_system>...</skill_system> block listing all enabled skills,
suitable for injection into any agent's system prompt.
"""
skills = load_skills(enabled_only=True)
skills = _get_enabled_skills()
try:
from deerflow.config import get_app_config
@@ -402,6 +410,10 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
if available_skills is not None:
skills = [skill for skill in skills if skill.name in available_skills]
# Check again after filtering
if not skills:
return ""
skill_items = "\n".join(
f" <skill>\n <name>{skill.name}</name>\n <description>{skill.description}</description>\n <location>{skill.get_container_file_path(container_base_path)}</location>\n </skill>" for skill in skills
)
@@ -446,7 +458,7 @@ def get_deferred_tools_prompt_section() -> str:
if not get_app_config().tool_search.enabled:
return ""
except FileNotFoundError:
except Exception:
return ""
registry = get_deferred_registry()
@@ -29,6 +29,17 @@ Instructions:
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
3. Update the memory sections as needed following the detailed length guidelines below
Before extracting facts, perform a structured reflection on the conversation:
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
If yes, record them as facts with the most appropriate category and confidence.
{correction_hint}
Memory Section Guidelines:
**User Context** (Current state - concise summaries):
@@ -62,6 +73,7 @@ Memory Section Guidelines:
* context: Background facts (job title, projects, locations, languages)
* behavior: Working patterns, communication habits, problem-solving approaches
* goal: Stated objectives, learning targets, project ambitions
* correction: Explicit agent mistakes or user corrections, including the correct approach
- Confidence levels:
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
* 0.7-0.8: Strongly implied from actions/discussions
@@ -94,7 +106,7 @@ Output Format (JSON):
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
}},
"newFacts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
],
"factsToRemove": ["fact_id_1", "fact_id_2"]
}}
@@ -104,6 +116,8 @@ Important Rules:
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
- Include specific metrics, version numbers, and proper nouns in facts
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
- Remove facts that are contradicted by new information
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
Keep 3-5 concurrent focus themes that are still active and relevant
@@ -126,7 +140,7 @@ Message:
Extract facts in this JSON format:
{{
"facts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
]
}}
@@ -136,6 +150,7 @@ Categories:
- context: Background context (location, job, projects)
- behavior: Behavioral patterns
- goal: User's goals or objectives
- correction: Explicit corrections or mistakes to avoid repeating
Rules:
- Only extract clear, specific facts
@@ -231,6 +246,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
if earlier.get("summary"):
history_sections.append(f"Earlier: {earlier['summary']}")
background = history_data.get("longTermBackground", {})
if background.get("summary"):
history_sections.append(f"Background: {background['summary']}")
if history_sections:
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
@@ -262,7 +281,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
continue
category = str(fact.get("category", "context")).strip() or "context"
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
line = f"- [{category} | {confidence:.2f}] {content}"
source_error = fact.get("sourceError")
if category == "correction" and isinstance(source_error, str) and source_error.strip():
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
else:
line = f"- [{category} | {confidence:.2f}] {content}"
# Each additional line is preceded by a newline (except the first).
line_text = ("\n" + line) if fact_lines else line
@@ -20,6 +20,7 @@ class ConversationContext:
messages: list[Any]
timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None
correction_detected: bool = False
class MemoryUpdateQueue:
@@ -37,25 +38,38 @@ class MemoryUpdateQueue:
self._timer: threading.Timer | None = None
self._processing = False
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
def add(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
Args:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
"""
config = get_memory_config()
if not config.enabled:
return
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
)
with self._lock:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
)
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
@@ -115,6 +129,7 @@ class MemoryUpdateQueue:
messages=context.messages,
thread_id=context.thread_id,
agent_name=context.agent_name,
correction_detected=context.correction_detected,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)
@@ -266,13 +266,20 @@ class MemoryUpdater:
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
def update_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
Returns:
True if update was successful, False otherwise.
@@ -295,9 +302,19 @@ class MemoryUpdater:
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
# Call LLM
@@ -383,6 +400,8 @@ class MemoryUpdater:
confidence = fact.get("confidence", 0.5)
if confidence >= config.fact_confidence_threshold:
raw_content = fact.get("content", "")
if not isinstance(raw_content, str):
continue
normalized_content = raw_content.strip()
fact_key = _fact_content_key(normalized_content)
if fact_key is not None and fact_key in existing_fact_keys:
@@ -396,6 +415,11 @@ class MemoryUpdater:
"createdAt": now,
"source": thread_id or "unknown",
}
source_error = fact.get("sourceError")
if isinstance(source_error, str):
normalized_source_error = source_error.strip()
if normalized_source_error:
fact_entry["sourceError"] = normalized_source_error
current_memory["facts"].append(fact_entry)
if fact_key is not None:
existing_fact_keys.add(fact_key)
@@ -412,16 +436,22 @@ class MemoryUpdater:
return current_memory
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
def update_memory_from_conversation(
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name)
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
@@ -0,0 +1,275 @@
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import (
ModelCallResult,
ModelRequest,
ModelResponse,
)
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
_BUSY_PATTERNS = (
"server busy",
"temporarily unavailable",
"try again later",
"please retry",
"please try again",
"overloaded",
"high demand",
"rate limit",
"负载较高",
"服务繁忙",
"稍后重试",
"请稍后重试",
)
_QUOTA_PATTERNS = (
"insufficient_quota",
"quota",
"billing",
"credit",
"payment",
"余额不足",
"超出限额",
"额度不足",
"欠费",
)
_AUTH_PATTERNS = (
"authentication",
"unauthorized",
"invalid api key",
"invalid_api_key",
"permission",
"forbidden",
"access denied",
"无权",
"未授权",
)
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Retry transient LLM errors and surface graceful assistant messages."""
retry_max_attempts: int = 3
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
error_code = _extract_error_code(exc)
status_code = _extract_status_code(exc)
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
return False, "quota"
if _matches_any(lowered, _AUTH_PATTERNS):
return False, "auth"
exc_name = exc.__class__.__name__
if exc_name in {
"APITimeoutError",
"APIConnectionError",
"InternalServerError",
}:
return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES:
return True, "transient"
if _matches_any(lowered, _BUSY_PATTERNS):
return True, "busy"
return False, "generic"
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
retry_after = _extract_retry_after_ms(exc)
if retry_after is not None:
return retry_after
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
return min(backoff, self.retry_cap_delay_ms)
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
seconds = max(1, round(wait_ms / 1000))
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
if reason == "auth":
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
if reason in {"busy", "transient"}:
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}"
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try:
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer(
{
"type": "llm_retry",
"attempt": attempt,
"max_attempts": self.retry_max_attempts,
"wait_ms": wait_ms,
"reason": reason,
"message": self._build_retry_message(attempt, wait_ms, reason),
}
)
except Exception:
logger.debug("Failed to emit llm_retry event", exc_info=True)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
attempt = 1
while True:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
time.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
attempt = 1
while True:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
await asyncio.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
return any(pattern in detail for pattern in patterns)
def _extract_error_code(exc: BaseException) -> Any:
for attr in ("code", "error_code"):
value = getattr(exc, attr, None)
if value not in (None, ""):
return value
body = getattr(exc, "body", None)
if isinstance(body, dict):
error = body.get("error")
if isinstance(error, dict):
for key in ("code", "type"):
value = error.get(key)
if value not in (None, ""):
return value
return None
def _extract_status_code(exc: BaseException) -> int | None:
for attr in ("status_code", "status"):
value = getattr(exc, attr, None)
if isinstance(value, int):
return value
response = getattr(exc, "response", None)
status = getattr(response, "status_code", None)
return status if isinstance(status, int) else None
def _extract_retry_after_ms(exc: BaseException) -> int | None:
response = getattr(exc, "response", None)
headers = getattr(response, "headers", None)
if headers is None:
return None
raw = None
header_name = ""
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
header_name = key
if hasattr(headers, "get"):
raw = headers.get(key)
if raw:
break
if not raw:
return None
try:
multiplier = 1 if "ms" in header_name.lower() else 1000
return max(0, int(float(raw) * multiplier))
except (TypeError, ValueError):
try:
target = parsedate_to_datetime(str(raw))
delta = target.timestamp() - time.time()
return max(0, int(delta * 1000))
except (TypeError, ValueError, OverflowError):
return None
def _extract_error_detail(exc: BaseException) -> str:
detail = str(exc).strip()
if detail:
return detail
message = getattr(exc, "message", None)
if isinstance(message, str) and message.strip():
return message.strip()
return exc.__class__.__name__
@@ -182,6 +182,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None, False
@staticmethod
def _append_text(content: str | list | None, text: str) -> str | list:
"""Append *text* to AIMessage content, handling str, list, and None.
When content is a list of content blocks (e.g. Anthropic thinking mode),
we append a new ``{"type": "text", ...}`` block instead of concatenating
a string to a list, which would raise ``TypeError``.
"""
if content is None:
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
# Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}"
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
@@ -192,7 +209,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
stripped_msg = last_msg.model_copy(
update={
"tool_calls": [],
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
}
)
return {"messages": [stripped_msg]}
@@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState):
pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
@@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content = getattr(msg, "content", "")
if isinstance(content, list):
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
content_str = str(content)
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
@@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
queue = get_memory_queue()
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
)
return None
@@ -116,44 +116,33 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None."""
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = model.invoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
except Exception:
logger.exception("Failed to generate title (sync)")
title = self._fallback_title(user_msg)
return {"title": title}
_, user_msg = self._build_title_prompt(state)
return {"title": self._fallback_title(user_msg)}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Asynchronously generate a title. Returns state update or None."""
"""Generate a title asynchronously and fall back locally on failure."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
prompt, user_msg = self._build_title_prompt(state)
try:
response = await model.ainvoke(prompt, config=self._get_runnable_config())
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt)
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
if title:
return {"title": title}
except Exception:
logger.exception("Failed to generate title (async)")
title = self._fallback_title(user_msg)
return {"title": title}
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg)}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
@@ -72,6 +72,7 @@ def _build_runtime_middlewares(
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Build shared base middlewares for agent execution."""
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
@@ -90,6 +91,8 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
@@ -135,6 +138,6 @@ def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentM
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
include_uploads=False,
include_dangling_tool_call_patch=False,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
@@ -10,10 +10,52 @@ from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__)
_OUTLINE_PREVIEW_LINES = 5
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
"""Return the document outline and fallback preview for *file_path*.
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
pipeline.
Returns:
(outline, preview) where:
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
Empty when no headings are found or no .md exists.
- preview: first few non-empty lines of the .md, used as a content
anchor when outline is empty so the agent has some context.
Empty when outline is non-empty (no fallback needed).
"""
md_path = file_path.with_suffix(".md")
if not md_path.is_file():
return [], []
outline = extract_outline(md_path)
if outline:
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
return outline, []
# outline is empty — read the first few non-empty lines as a content preview
preview: list[str] = []
try:
with md_path.open(encoding="utf-8") as f:
for line in f:
stripped = line.strip()
if stripped:
preview.append(stripped)
if len(preview) >= _OUTLINE_PREVIEW_LINES:
break
except Exception:
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
return [], preview
class UploadsMiddlewareState(AgentState):
"""State schema for uploads middleware."""
@@ -39,12 +81,38 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
"""Append a single file entry (name, size, path, optional outline) to lines."""
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
outline = file.get("outline") or []
if outline:
truncated = outline[-1].get("truncated", False)
visible = [e for e in outline if not e.get("truncated")]
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
for entry in visible:
lines.append(f" L{entry['line']}: {entry['title']}")
if truncated:
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
else:
preview = file.get("outline_preview") or []
if preview:
lines.append(" No structural headings detected. Document begins with:")
for text in preview:
lines.append(f" > {text}")
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
lines.append("")
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
"""Create a formatted message listing uploaded files.
Args:
new_files: Files uploaded in the current message.
historical_files: Files uploaded in previous messages.
Each file dict may contain an optional ``outline`` key — a list of
``{title, line}`` dicts extracted from the converted Markdown file.
Returns:
Formatted string inside <uploaded_files> tags.
@@ -55,25 +123,24 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
lines.append("")
if new_files:
for file in new_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
self._format_file_entry(file, lines)
else:
lines.append("(empty)")
lines.append("")
if historical_files:
lines.append("The following files were uploaded in previous messages and are still available:")
lines.append("")
for file in historical_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
self._format_file_entry(file, lines)
lines.append("You can read these files using the `read_file` tool with the paths shown above.")
lines.append("To work with these files:")
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
lines.append("- Use `glob` to find files by name pattern")
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
lines.append("</uploaded_files>")
return "\n".join(lines)
@@ -147,6 +214,13 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
@@ -159,15 +233,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
for file_path in sorted(uploads_dir.iterdir()):
if file_path.is_file() and file_path.name not in new_filenames:
stat = file_path.stat()
outline, preview = _extract_outline_for_file(file_path)
historical_files.append(
{
"filename": file_path.name,
"size": stat.st_size,
"path": f"/mnt/user-data/uploads/{file_path.name}",
"extension": file_path.suffix,
"outline": outline,
"outline_preview": preview,
}
)
# Attach outlines to new files as well
if uploads_dir:
for file in new_files:
phys_path = uploads_dir / file["filename"]
outline, preview = _extract_outline_for_file(phys_path)
file["outline"] = outline
file["outline_preview"] = preview
if not new_files and not historical_files:
return None
@@ -117,6 +117,7 @@ class DeerFlowClient:
subagent_enabled: bool = False,
plan_mode: bool = False,
agent_name: str | None = None,
available_skills: set[str] | None = None,
middlewares: Sequence[AgentMiddleware] | None = None,
):
"""Initialize the client.
@@ -133,6 +134,7 @@ class DeerFlowClient:
subagent_enabled: Enable subagent delegation.
plan_mode: Enable TodoList middleware for plan mode.
agent_name: Name of the agent to use.
available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
middlewares: Optional list of custom middlewares to inject into the agent.
"""
if config_path is not None:
@@ -148,6 +150,7 @@ class DeerFlowClient:
self._subagent_enabled = subagent_enabled
self._plan_mode = plan_mode
self._agent_name = agent_name
self._available_skills = set(available_skills) if available_skills is not None else None
self._middlewares = list(middlewares) if middlewares else []
# Lazy agent — created on first call, recreated when config changes.
@@ -208,6 +211,8 @@ class DeerFlowClient:
cfg.get("thinking_enabled"),
cfg.get("is_plan_mode"),
cfg.get("subagent_enabled"),
self._agent_name,
frozenset(self._available_skills) if self._available_skills is not None else None,
)
if self._agent is not None and self._agent_config_key == key:
@@ -226,6 +231,7 @@ class DeerFlowClient:
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
available_skills=self._available_skills,
),
"state_schema": ThreadState,
}
@@ -7,6 +7,7 @@ import uuid
from agent_sandbox import Sandbox as AioSandboxClient
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
logger = logging.getLogger(__name__)
@@ -124,16 +125,96 @@ class AioSandbox(Sandbox):
content: The text content to write to the file.
append: Whether to append the content to the file.
"""
try:
if append:
# Read existing content first and append
existing = self.read_file(path)
if not existing.startswith("Error:"):
content = existing + content
self._client.file.write_file(file=path, content=content)
except Exception as e:
logger.error(f"Failed to write file in sandbox: {e}")
raise
with self._lock:
try:
if append:
existing = self.read_file(path)
if not existing.startswith("Error:"):
content = existing + content
self._client.file.write_file(file=path, content=content)
except Exception as e:
logger.error(f"Failed to write file in sandbox: {e}")
raise
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
if not include_dirs:
result = self._client.file.find_files(path=path, glob=pattern)
files = result.data.files if result.data and result.data.files else []
filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
truncated = len(filtered) > max_results
return filtered[:max_results], truncated
result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
entries = result.data.files if result.data and result.data.files else []
matches: list[str] = []
root_path = path.rstrip("/") or "/"
root_prefix = root_path if root_path == "/" else f"{root_path}/"
for entry in entries:
if entry.path != root_path and not entry.path.startswith(root_prefix):
continue
if should_ignore_path(entry.path):
continue
rel_path = entry.path[len(root_path) :].lstrip("/")
if path_matches(pattern, rel_path):
matches.append(entry.path)
if len(matches) >= max_results:
return matches, True
return matches, False
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
import re as _re
regex_source = _re.escape(pattern) if literal else pattern
# Validate the pattern locally so an invalid regex raises re.error
# (caught by grep_tool's except re.error handler) rather than a
# generic remote API error.
_re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
regex = regex_source if case_sensitive else f"(?i){regex_source}"
if glob is not None:
find_result = self._client.file.find_files(path=path, glob=glob)
candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
else:
list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
entries = list_result.data.files if list_result.data and list_result.data.files else []
candidate_paths = [entry.path for entry in entries if not entry.is_directory]
matches: list[GrepMatch] = []
truncated = False
for file_path in candidate_paths:
if should_ignore_path(file_path):
continue
search_result = self._client.file.search_in_file(file=file_path, regex=regex)
data = search_result.data
if data is None:
continue
line_numbers = data.line_numbers or []
matched_lines = data.matches or []
for line_number, line in zip(line_numbers, matched_lines):
matches.append(
GrepMatch(
path=file_path,
line_number=line_number if isinstance(line_number, int) else 0,
line=truncate_line(line),
)
)
if len(matches) >= max_results:
truncated = True
return matches, truncated
return matches, truncated
def update_file(self, path: str, content: bytes) -> None:
"""Update a file with binary content in the sandbox.
@@ -142,9 +223,10 @@ class AioSandbox(Sandbox):
path: The absolute path of the file to update.
content: The binary content to write to the file.
"""
try:
base64_content = base64.b64encode(content).decode("utf-8")
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
except Exception as e:
logger.error(f"Failed to update file in sandbox: {e}")
raise
with self._lock:
try:
base64_content = base64.b64encode(content).decode("utf-8")
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
except Exception as e:
logger.error(f"Failed to update file in sandbox: {e}")
raise
@@ -1,13 +1,16 @@
import logging
import os
import requests
import httpx
logger = logging.getLogger(__name__)
_api_key_warned = False
class JinaClient:
def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
global _api_key_warned
headers = {
"Content-Type": "application/json",
"X-Return-Format": return_format,
@@ -15,11 +18,13 @@ class JinaClient:
}
if os.getenv("JINA_API_KEY"):
headers["Authorization"] = f"Bearer {os.getenv('JINA_API_KEY')}"
else:
elif not _api_key_warned:
_api_key_warned = True
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
data = {"url": url}
try:
response = requests.post("https://r.jina.ai/", headers=headers, json=data)
async with httpx.AsyncClient() as client:
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
if response.status_code != 200:
error_message = f"Jina API returned status {response.status_code}: {response.text}"
@@ -34,5 +39,5 @@ class JinaClient:
return response.text
except Exception as e:
error_message = f"Request to Jina API failed: {str(e)}"
logger.error(error_message)
logger.exception(error_message)
return f"Error: {error_message}"
@@ -8,7 +8,7 @@ readability_extractor = ReadabilityExtractor()
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
async def web_fetch_tool(url: str) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -23,6 +23,8 @@ def web_fetch_tool(url: str) -> str:
config = get_app_config().get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout")
html_content = jina_client.crawl(url, return_format="html", timeout=timeout)
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content
article = readability_extractor.extract_article(html_content)
return article.to_markdown()[:4096]
@@ -3,7 +3,13 @@ from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
from .skills_config import SkillsConfig
from .tracing_config import get_tracing_config, is_tracing_enabled
from .tracing_config import (
get_enabled_tracing_providers,
get_explicitly_enabled_tracing_providers,
get_tracing_config,
is_tracing_enabled,
validate_enabled_tracing_providers,
)
__all__ = [
"get_app_config",
@@ -15,5 +21,8 @@ __all__ = [
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
"get_explicitly_enabled_tracing_providers",
"get_enabled_tracing_providers",
"is_tracing_enabled",
"validate_enabled_tracing_providers",
]
@@ -22,6 +22,11 @@ class AgentConfig(BaseModel):
description: str = ""
model: str | None = None
tool_groups: list[str] | None = None
# skills controls which skills are loaded into the agent's prompt:
# - None (or omitted): load all enabled skills (default fallback behavior)
# - [] (explicit empty list): disable all skills
# - ["skill1", "skill2"]: load only the specified skills
skills: list[str] | None = None
def load_agent_config(name: str | None) -> AgentConfig | None:
@@ -1,5 +1,6 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
@@ -11,16 +12,16 @@ from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import load_guardrails_config_from_dict
from deerflow.config.memory_config import load_memory_config_from_dict
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import load_subagents_config_from_dict
from deerflow.config.summarization_config import load_summarization_config_from_dict
from deerflow.config.title_config import load_title_config_from_dict
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
@@ -30,6 +31,13 @@ load_dotenv()
logger = logging.getLogger(__name__)
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
return (backend_dir / "config.yaml", repo_root / "config.yaml")
class AppConfig(BaseModel):
"""Config for the DeerFlow application"""
@@ -42,6 +50,11 @@ class AppConfig(BaseModel):
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
model_config = ConfigDict(extra="allow", frozen=False)
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
@@ -55,7 +68,7 @@ class AppConfig(BaseModel):
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory.
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
"""
if config_path:
path = Path(config_path)
@@ -68,14 +81,10 @@ class AppConfig(BaseModel):
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
return path
else:
# Check if the config.yaml is in the current directory
path = Path(os.getcwd()) / "config.yaml"
if not path.exists():
# Check if the config.yaml is in the parent directory of CWD
path = Path(os.getcwd()).parent / "config.yaml"
if not path.exists():
raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory")
return path
for path in _default_config_candidates():
if path.exists():
return path
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
@classmethod
def from_file(cls, config_path: str | None = None) -> Self:
@@ -248,6 +257,8 @@ _app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
@@ -280,6 +291,10 @@ def get_app_config() -> AppConfig:
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
@@ -341,3 +356,26 @@ def set_app_config(config: AppConfig) -> None:
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
@@ -80,6 +80,12 @@ class ExtensionsConfig(BaseModel):
Args:
config_path: Optional path to extensions config file.
Resolution order:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
3. Otherwise, search backend/repository-root defaults for
`extensions_config.json`, then legacy `mcp_config.json`.
Returns:
Path to the extensions config file if found, otherwise None.
"""
@@ -94,24 +100,16 @@ class ExtensionsConfig(BaseModel):
raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
return path
else:
# Check if the extensions_config.json is in the current directory
path = Path(os.getcwd()) / "extensions_config.json"
if path.exists():
return path
# Check if the extensions_config.json is in the parent directory of CWD
path = Path(os.getcwd()).parent / "extensions_config.json"
if path.exists():
return path
# Backward compatibility: check for mcp_config.json
path = Path(os.getcwd()) / "mcp_config.json"
if path.exists():
return path
path = Path(os.getcwd()).parent / "mcp_config.json"
if path.exists():
return path
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
for path in (
backend_dir / "extensions_config.json",
repo_root / "extensions_config.json",
backend_dir / "mcp_config.json",
repo_root / "mcp_config.json",
):
if path.exists():
return path
# Extensions are optional, so return None if not found
return None
@@ -9,6 +9,12 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
def _default_local_base_dir() -> Path:
"""Return the repo-local DeerFlow state directory without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
return backend_dir / ".deer-flow"
def _validate_thread_id(thread_id: str) -> str:
"""Validate a thread ID before using it in filesystem paths."""
if not _SAFE_THREAD_ID_RE.match(thread_id):
@@ -67,8 +73,7 @@ class Paths:
BaseDir resolution (in priority order):
1. Constructor argument `base_dir`
2. DEER_FLOW_HOME environment variable
3. Local dev fallback: cwd/.deer-flow (when cwd is the backend/ dir)
4. Default: $HOME/.deer-flow
3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
"""
def __init__(self, base_dir: str | Path | None = None) -> None:
@@ -104,11 +109,7 @@ class Paths:
if env_home := os.getenv("DEER_FLOW_HOME"):
return Path(env_home).resolve()
cwd = Path.cwd()
if cwd.name == "backend" or (cwd / "pyproject.toml").exists():
return cwd / ".deer-flow"
return Path.home() / ".deer-flow"
return _default_local_base_dir()
@property
def memory_file(self) -> Path:
@@ -64,4 +64,15 @@ class SandboxConfig(BaseModel):
description="Environment variables to inject into the sandbox container. Values starting with $ will be resolved from host environment variables.",
)
bash_output_max_chars: int = Field(
default=20000,
ge=0,
description="Maximum characters to keep from bash tool output. Output exceeding this limit is middle-truncated (head + tail), preserving the first and last half. Set to 0 to disable truncation.",
)
read_file_output_max_chars: int = Field(
default=50000,
ge=0,
description="Maximum characters to keep from read_file tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
model_config = ConfigDict(extra="allow")
@@ -3,6 +3,11 @@ from pathlib import Path
from pydantic import BaseModel, Field
def _default_repo_root() -> Path:
"""Resolve the repo root without relying on the current working directory."""
return Path(__file__).resolve().parents[5]
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
@@ -26,8 +31,8 @@ class SkillsConfig(BaseModel):
# Use configured path (can be absolute or relative)
path = Path(self.path)
if not path.is_absolute():
# If relative, resolve from current working directory
path = Path.cwd() / path
# If relative, resolve from the repo root for deterministic behavior.
path = _default_repo_root() / path
return path.resolve()
else:
# Default: ../skills relative to backend directory
@@ -1,14 +1,12 @@
import logging
import os
import threading
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
_config_lock = threading.Lock()
class TracingConfig(BaseModel):
class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing."""
enabled: bool = Field(...)
@@ -18,9 +16,69 @@ class TracingConfig(BaseModel):
@property
def is_configured(self) -> bool:
"""Check if tracing is fully configured (enabled and has API key)."""
return self.enabled and bool(self.api_key)
def validate(self) -> None:
if self.enabled and not self.api_key:
raise ValueError("LangSmith tracing is enabled but LANGSMITH_API_KEY (or LANGCHAIN_API_KEY) is not set.")
class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing."""
enabled: bool = Field(...)
public_key: str | None = Field(...)
secret_key: str | None = Field(...)
host: str = Field(...)
@property
def is_configured(self) -> bool:
return self.enabled and bool(self.public_key) and bool(self.secret_key)
def validate(self) -> None:
if not self.enabled:
return
missing: list[str] = []
if not self.public_key:
missing.append("LANGFUSE_PUBLIC_KEY")
if not self.secret_key:
missing.append("LANGFUSE_SECRET_KEY")
if missing:
raise ValueError(f"Langfuse tracing is enabled but required settings are missing: {', '.join(missing)}")
class TracingConfig(BaseModel):
"""Tracing configuration for supported providers."""
langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...)
@property
def is_configured(self) -> bool:
return bool(self.enabled_providers)
@property
def explicitly_enabled_providers(self) -> list[str]:
enabled: list[str] = []
if self.langsmith.enabled:
enabled.append("langsmith")
if self.langfuse.enabled:
enabled.append("langfuse")
return enabled
@property
def enabled_providers(self) -> list[str]:
enabled: list[str] = []
if self.langsmith.is_configured:
enabled.append("langsmith")
if self.langfuse.is_configured:
enabled.append("langfuse")
return enabled
def validate_enabled(self) -> None:
self.langsmith.validate()
self.langfuse.validate()
_tracing_config: TracingConfig | None = None
@@ -29,12 +87,7 @@ _TRUTHY_VALUES = {"1", "true", "yes", "on"}
def _env_flag_preferred(*names: str) -> bool:
"""Return the boolean value of the first env var that is present and non-empty.
Accepted truthy values (case-insensitive): ``1``, ``true``, ``yes``, ``on``.
Any other non-empty value is treated as falsy. If none of the named
variables is set, returns ``False``.
"""
"""Return the boolean value of the first env var that is present and non-empty."""
for name in names:
value = os.environ.get(name)
if value is not None and value.strip():
@@ -52,43 +105,45 @@ def _first_env_value(*names: str) -> str | None:
def get_tracing_config() -> TracingConfig:
"""Get the current tracing configuration from environment variables.
``LANGSMITH_*`` variables take precedence over their legacy ``LANGCHAIN_*``
counterparts. For boolean flags (``enabled``), the *first* variable that is
present and non-empty in the priority list is the sole authority its value
is parsed and returned without consulting the remaining candidates. Accepted
truthy values are ``1``, ``true``, ``yes``, and ``on`` (case-insensitive);
any other non-empty value is treated as falsy.
Priority order:
enabled : LANGSMITH_TRACING > LANGCHAIN_TRACING_V2 > LANGCHAIN_TRACING
api_key : LANGSMITH_API_KEY > LANGCHAIN_API_KEY
project : LANGSMITH_PROJECT > LANGCHAIN_PROJECT (default: "deer-flow")
endpoint : LANGSMITH_ENDPOINT > LANGCHAIN_ENDPOINT (default: https://api.smith.langchain.com)
Returns:
TracingConfig with current settings.
"""
"""Get the current tracing configuration from environment variables."""
global _tracing_config
if _tracing_config is not None:
return _tracing_config
with _config_lock:
if _tracing_config is not None: # Double-check after acquiring lock
if _tracing_config is not None:
return _tracing_config
_tracing_config = TracingConfig(
# Keep compatibility with both legacy LANGCHAIN_* and newer LANGSMITH_* variables.
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
langsmith=LangSmithTracingConfig(
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
),
langfuse=LangfuseTracingConfig(
enabled=_env_flag_preferred("LANGFUSE_TRACING"),
public_key=_first_env_value("LANGFUSE_PUBLIC_KEY"),
secret_key=_first_env_value("LANGFUSE_SECRET_KEY"),
host=_first_env_value("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com",
),
)
return _tracing_config
def get_enabled_tracing_providers() -> list[str]:
"""Return the configured tracing providers that are enabled and complete."""
return get_tracing_config().enabled_providers
def get_explicitly_enabled_tracing_providers() -> list[str]:
"""Return tracing providers explicitly enabled by config, even if incomplete."""
return get_tracing_config().explicitly_enabled_providers
def validate_enabled_tracing_providers() -> None:
"""Validate that any explicitly enabled providers are fully configured."""
get_tracing_config().validate_enabled()
def is_tracing_enabled() -> bool:
"""Check if LangSmith tracing is enabled and configured.
Returns:
True if tracing is enabled and has an API key.
"""
"""Check if any tracing provider is enabled and fully configured."""
return get_tracing_config().is_configured
@@ -2,8 +2,9 @@ import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config, get_tracing_config, is_tracing_enabled
from deerflow.config import get_app_config
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
logger = logging.getLogger(__name__)
@@ -88,17 +89,9 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
model_instance = model_class(**kwargs, **model_settings_from_config)
if is_tracing_enabled():
try:
from langchain_core.tracers.langchain import LangChainTracer
tracing_config = get_tracing_config()
tracer = LangChainTracer(
project_name=tracing_config.project,
)
existing_callbacks = model_instance.callbacks or []
model_instance.callbacks = [*existing_callbacks, tracer]
logger.debug(f"LangSmith tracing attached to model '{name}' (project='{tracing_config.project}')")
except Exception as e:
logger.warning(f"Failed to attach LangSmith tracing to model '{name}': {e}")
callbacks = build_tracing_callbacks()
if callbacks:
existing_callbacks = model_instance.callbacks or []
model_instance.callbacks = [*existing_callbacks, *callbacks]
logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}")
return model_instance
@@ -123,6 +123,11 @@ async def run_agent(
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a LangChain callback handler.
@@ -25,6 +25,7 @@ class MemoryStreamBridge(StreamBridge):
self._maxsize = queue_maxsize
self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
self._counters: dict[str, int] = {}
self._dropped_counts: dict[str, int] = {}
# -- helpers ---------------------------------------------------------------
@@ -32,6 +33,7 @@ class MemoryStreamBridge(StreamBridge):
if run_id not in self._queues:
self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
self._counters[run_id] = 0
self._dropped_counts[run_id] = 0
return self._queues[run_id]
def _next_id(self, run_id: str) -> str:
@@ -48,14 +50,41 @@ class MemoryStreamBridge(StreamBridge):
try:
await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
except TimeoutError:
logger.warning("Stream bridge queue full for run %s — dropping event %s", run_id, event)
self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
logger.warning(
"Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
run_id,
event,
self._dropped_counts[run_id],
)
async def publish_end(self, run_id: str) -> None:
queue = self._get_or_create_queue(run_id)
try:
await asyncio.wait_for(queue.put(END_SENTINEL), timeout=_PUBLISH_TIMEOUT)
except TimeoutError:
logger.warning("Stream bridge queue full for run %s — dropping END sentinel", run_id)
# END sentinel is critical — it is the only signal that allows
# subscribers to terminate. If the queue is full we evict the
# oldest *regular* events to make room rather than dropping END,
# which would cause the SSE connection to hang forever and leak
# the queue/counter resources for this run_id.
if queue.full():
evicted = 0
while queue.full():
try:
queue.get_nowait()
evicted += 1
except asyncio.QueueEmpty:
break # pragma: no cover defensive
if evicted:
logger.warning(
"Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
run_id,
evicted,
)
# After eviction the queue is guaranteed to have space, so a
# simple non-blocking put is safe. We still use put() (which
# blocks until space is available) as a defensive measure.
await queue.put(END_SENTINEL)
async def subscribe(
self,
@@ -84,7 +113,18 @@ class MemoryStreamBridge(StreamBridge):
await asyncio.sleep(delay)
self._queues.pop(run_id, None)
self._counters.pop(run_id, None)
self._dropped_counts.pop(run_id, None)
async def close(self) -> None:
self._queues.clear()
self._counters.clear()
self._dropped_counts.clear()
def dropped_count(self, run_id: str) -> int:
"""Return the number of events dropped for *run_id*."""
return self._dropped_counts.get(run_id, 0)
@property
def dropped_total(self) -> int:
"""Return the total number of events dropped across all runs."""
return sum(self._dropped_counts.values())
@@ -0,0 +1,23 @@
import threading
from deerflow.sandbox.sandbox import Sandbox
_FILE_OPERATION_LOCKS: dict[tuple[str, str], threading.Lock] = {}
_FILE_OPERATION_LOCKS_GUARD = threading.Lock()
def get_file_operation_lock_key(sandbox: Sandbox, path: str) -> tuple[str, str]:
sandbox_id = getattr(sandbox, "id", None)
if not sandbox_id:
sandbox_id = f"instance:{id(sandbox)}"
return sandbox_id, path
def get_file_operation_lock(sandbox: Sandbox, path: str) -> threading.Lock:
lock_key = get_file_operation_lock_key(sandbox, path)
with _FILE_OPERATION_LOCKS_GUARD:
lock = _FILE_OPERATION_LOCKS.get(lock_key)
if lock is None:
lock = threading.Lock()
_FILE_OPERATION_LOCKS[lock_key] = lock
return lock
@@ -1,72 +1,6 @@
import fnmatch
from pathlib import Path
IGNORE_PATTERNS = [
# Version Control
".git",
".svn",
".hg",
".bzr",
# Dependencies
"node_modules",
"__pycache__",
".venv",
"venv",
".env",
"env",
".tox",
".nox",
".eggs",
"*.egg-info",
"site-packages",
# Build outputs
"dist",
"build",
".next",
".nuxt",
".output",
".turbo",
"target",
"out",
# IDE & Editor
".idea",
".vscode",
"*.swp",
"*.swo",
"*~",
".project",
".classpath",
".settings",
# OS generated
".DS_Store",
"Thumbs.db",
"desktop.ini",
"*.lnk",
# Logs & temp files
"*.log",
"*.tmp",
"*.temp",
"*.bak",
"*.cache",
".cache",
"logs",
# Coverage & test artifacts
".coverage",
"coverage",
".nyc_output",
"htmlcov",
".pytest_cache",
".mypy_cache",
".ruff_cache",
]
def _should_ignore(name: str) -> bool:
"""Check if a file/directory name matches any ignore pattern."""
for pattern in IGNORE_PATTERNS:
if fnmatch.fnmatch(name, pattern):
return True
return False
from deerflow.sandbox.search import should_ignore_name
def list_dir(path: str, max_depth: int = 2) -> list[str]:
@@ -95,7 +29,7 @@ def list_dir(path: str, max_depth: int = 2) -> list[str]:
try:
for item in current_path.iterdir():
if _should_ignore(item.name):
if should_ignore_name(item.name):
continue
post_fix = "/" if item.is_dir() else ""
@@ -1,11 +1,23 @@
import errno
import ntpath
import os
import shutil
import subprocess
from dataclasses import dataclass
from pathlib import Path
from deerflow.sandbox.local.list_dir import list_dir
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
@dataclass(frozen=True)
class PathMapping:
"""A path mapping from a container path to a local path with optional read-only flag."""
container_path: str
local_path: str
read_only: bool = False
class LocalSandbox(Sandbox):
@@ -39,17 +51,42 @@ class LocalSandbox(Sandbox):
return None
def __init__(self, id: str, path_mappings: dict[str, str] | None = None):
def __init__(self, id: str, path_mappings: list[PathMapping] | None = None):
"""
Initialize local sandbox with optional path mappings.
Args:
id: Sandbox identifier
path_mappings: Dictionary mapping container paths to local paths
Example: {"/mnt/skills": "/absolute/path/to/skills"}
path_mappings: List of path mappings with optional read-only flag.
Skills directory is read-only by default.
"""
super().__init__(id)
self.path_mappings = path_mappings or {}
self.path_mappings = path_mappings or []
def _is_read_only_path(self, resolved_path: str) -> bool:
"""Check if a resolved path is under a read-only mount.
When multiple mappings match (nested mounts), prefer the most specific
mapping (i.e. the one whose local_path is the longest prefix of the
resolved path), similar to how ``_resolve_path`` handles container paths.
"""
resolved = str(Path(resolved_path).resolve())
best_mapping: PathMapping | None = None
best_prefix_len = -1
for mapping in self.path_mappings:
local_resolved = str(Path(mapping.local_path).resolve())
if resolved == local_resolved or resolved.startswith(local_resolved + os.sep):
prefix_len = len(local_resolved)
if prefix_len > best_prefix_len:
best_prefix_len = prefix_len
best_mapping = mapping
if best_mapping is None:
return False
return best_mapping.read_only
def _resolve_path(self, path: str) -> str:
"""
@@ -64,7 +101,9 @@ class LocalSandbox(Sandbox):
path_str = str(path)
# Try each mapping (longest prefix first for more specific matches)
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True):
for mapping in sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True):
container_path = mapping.container_path
local_path = mapping.local_path
if path_str == container_path or path_str.startswith(container_path + "/"):
# Replace the container path prefix with local path
relative = path_str[len(container_path) :].lstrip("/")
@@ -84,15 +123,16 @@ class LocalSandbox(Sandbox):
Returns:
Container path if mapping exists, otherwise original path
"""
path_str = str(Path(path).resolve())
normalized_path = path.replace("\\", "/")
path_str = str(Path(normalized_path).resolve())
# Try each mapping (longest local path first for more specific matches)
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True):
local_path_resolved = str(Path(local_path).resolve())
if path_str.startswith(local_path_resolved):
for mapping in sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True):
local_path_resolved = str(Path(mapping.local_path).resolve())
if path_str == local_path_resolved or path_str.startswith(local_path_resolved + "/"):
# Replace the local path prefix with container path
relative = path_str[len(local_path_resolved) :].lstrip("/")
resolved = f"{container_path}/{relative}" if relative else container_path
resolved = f"{mapping.container_path}/{relative}" if relative else mapping.container_path
return resolved
# No mapping found, return original path
@@ -111,7 +151,7 @@ class LocalSandbox(Sandbox):
import re
# Sort mappings by local path length (longest first) for correct prefix matching
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True)
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True)
if not sorted_mappings:
return output
@@ -119,12 +159,11 @@ class LocalSandbox(Sandbox):
# Create pattern that matches absolute paths
# Match paths like /Users/... or other absolute paths
result = output
for container_path, local_path in sorted_mappings:
local_path_resolved = str(Path(local_path).resolve())
for mapping in sorted_mappings:
# Escape the local path for use in regex
escaped_local = re.escape(local_path_resolved)
# Match the local path followed by optional path components
pattern = re.compile(escaped_local + r"(?:/[^\s\"';&|<>()]*)?")
escaped_local = re.escape(str(Path(mapping.local_path).resolve()))
# Match the local path followed by optional path components with either separator
pattern = re.compile(escaped_local + r"(?:[/\\][^\s\"';&|<>()]*)?")
def replace_match(match: re.Match) -> str:
matched_path = match.group(0)
@@ -147,7 +186,7 @@ class LocalSandbox(Sandbox):
import re
# Sort mappings by length (longest first) for correct prefix matching
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True)
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True)
# Build regex pattern to match all container paths
# Match container path followed by optional path components
@@ -157,7 +196,7 @@ class LocalSandbox(Sandbox):
# Create pattern that matches any of the container paths.
# The lookahead (?=/|$|...) ensures we only match at a path-segment boundary,
# preventing /mnt/skills from matching inside /mnt/skills-extra.
patterns = [re.escape(container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for container_path, _ in sorted_mappings]
patterns = [re.escape(m.container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for m in sorted_mappings]
pattern = re.compile("|".join(f"({p})" for p in patterns))
def replace_match(match: re.Match) -> str:
@@ -248,6 +287,8 @@ class LocalSandbox(Sandbox):
def write_file(self, path: str, content: str, append: bool = False) -> None:
resolved_path = self._resolve_path(path)
if self._is_read_only_path(resolved_path):
raise OSError(errno.EROFS, "Read-only file system", path)
try:
dir_path = os.path.dirname(resolved_path)
if dir_path:
@@ -259,8 +300,43 @@ class LocalSandbox(Sandbox):
# Re-raise with the original path for clearer error messages, hiding internal resolved paths
raise type(e)(e.errno, e.strerror, path) from None
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
resolved_path = Path(self._resolve_path(path))
matches, truncated = find_glob_matches(resolved_path, pattern, include_dirs=include_dirs, max_results=max_results)
return [self._reverse_resolve_path(match) for match in matches], truncated
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
resolved_path = Path(self._resolve_path(path))
matches, truncated = find_grep_matches(
resolved_path,
pattern,
glob_pattern=glob,
literal=literal,
case_sensitive=case_sensitive,
max_results=max_results,
)
return [
GrepMatch(
path=self._reverse_resolve_path(match.path),
line_number=match.line_number,
line=match.line,
)
for match in matches
], truncated
def update_file(self, path: str, content: bytes) -> None:
resolved_path = self._resolve_path(path)
if self._is_read_only_path(resolved_path):
raise OSError(errno.EROFS, "Read-only file system", path)
try:
dir_path = os.path.dirname(resolved_path)
if dir_path:
@@ -1,6 +1,7 @@
import logging
from pathlib import Path
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
@@ -14,16 +15,17 @@ class LocalSandboxProvider(SandboxProvider):
"""Initialize the local sandbox provider with path mappings."""
self._path_mappings = self._setup_path_mappings()
def _setup_path_mappings(self) -> dict[str, str]:
def _setup_path_mappings(self) -> list[PathMapping]:
"""
Setup path mappings for local sandbox.
Maps container paths to actual local paths, including skills directory.
Maps container paths to actual local paths, including skills directory
and any custom mounts configured in config.yaml.
Returns:
Dictionary of path mappings
List of path mappings
"""
mappings = {}
mappings: list[PathMapping] = []
# Map skills container path to local skills directory
try:
@@ -35,10 +37,63 @@ class LocalSandboxProvider(SandboxProvider):
# Only add mapping if skills directory exists
if skills_path.exists():
mappings[container_path] = str(skills_path)
mappings.append(
PathMapping(
container_path=container_path,
local_path=str(skills_path),
read_only=True, # Skills directory is always read-only
)
)
# Map custom mounts from sandbox config
_RESERVED_CONTAINER_PREFIXES = [container_path, "/mnt/acp-workspace", "/mnt/user-data"]
sandbox_config = config.sandbox
if sandbox_config and sandbox_config.mounts:
for mount in sandbox_config.mounts:
host_path = Path(mount.host_path)
container_path = mount.container_path.rstrip("/") or "/"
if not host_path.is_absolute():
logger.warning(
"Mount host_path must be absolute, skipping: %s -> %s",
mount.host_path,
mount.container_path,
)
continue
if not container_path.startswith("/"):
logger.warning(
"Mount container_path must be absolute, skipping: %s -> %s",
mount.host_path,
mount.container_path,
)
continue
# Reject mounts that conflict with reserved container paths
if any(container_path == p or container_path.startswith(p + "/") for p in _RESERVED_CONTAINER_PREFIXES):
logger.warning(
"Mount container_path conflicts with reserved prefix, skipping: %s",
mount.container_path,
)
continue
# Ensure the host path exists before adding mapping
if host_path.exists():
mappings.append(
PathMapping(
container_path=container_path,
local_path=str(host_path.resolve()),
read_only=mount.read_only,
)
)
else:
logger.warning(
"Mount host_path does not exist, skipping: %s -> %s",
mount.host_path,
mount.container_path,
)
except Exception as e:
# Log but don't fail if config loading fails
logger.warning("Could not setup skills path mapping: %s", e, exc_info=True)
logger.warning("Could not setup path mappings: %s", e, exc_info=True)
return mappings
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from deerflow.sandbox.search import GrepMatch
class Sandbox(ABC):
"""Abstract base class for sandbox environments"""
@@ -61,6 +63,25 @@ class Sandbox(ABC):
"""
pass
@abstractmethod
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
"""Find paths that match a glob pattern under a root directory."""
pass
@abstractmethod
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
"""Search for matches inside text files under a directory."""
pass
@abstractmethod
def update_file(self, path: str, content: bytes) -> None:
"""Update a file with binary content.
@@ -0,0 +1,210 @@
import fnmatch
import os
import re
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
IGNORE_PATTERNS = [
".git",
".svn",
".hg",
".bzr",
"node_modules",
"__pycache__",
".venv",
"venv",
".env",
"env",
".tox",
".nox",
".eggs",
"*.egg-info",
"site-packages",
"dist",
"build",
".next",
".nuxt",
".output",
".turbo",
"target",
"out",
".idea",
".vscode",
"*.swp",
"*.swo",
"*~",
".project",
".classpath",
".settings",
".DS_Store",
"Thumbs.db",
"desktop.ini",
"*.lnk",
"*.log",
"*.tmp",
"*.temp",
"*.bak",
"*.cache",
".cache",
"logs",
".coverage",
"coverage",
".nyc_output",
"htmlcov",
".pytest_cache",
".mypy_cache",
".ruff_cache",
]
DEFAULT_MAX_FILE_SIZE_BYTES = 1_000_000
DEFAULT_LINE_SUMMARY_LENGTH = 200
@dataclass(frozen=True)
class GrepMatch:
path: str
line_number: int
line: str
def should_ignore_name(name: str) -> bool:
for pattern in IGNORE_PATTERNS:
if fnmatch.fnmatch(name, pattern):
return True
return False
def should_ignore_path(path: str) -> bool:
return any(should_ignore_name(segment) for segment in path.replace("\\", "/").split("/") if segment)
def path_matches(pattern: str, rel_path: str) -> bool:
path = PurePosixPath(rel_path)
if path.match(pattern):
return True
if pattern.startswith("**/"):
return path.match(pattern[3:])
return False
def truncate_line(line: str, max_chars: int = DEFAULT_LINE_SUMMARY_LENGTH) -> str:
line = line.rstrip("\n\r")
if len(line) <= max_chars:
return line
return line[: max_chars - 3] + "..."
def is_binary_file(path: Path, sample_size: int = 8192) -> bool:
try:
with path.open("rb") as handle:
return b"\0" in handle.read(sample_size)
except OSError:
return True
def find_glob_matches(root: Path, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
matches: list[str] = []
truncated = False
root = root.resolve()
if not root.exists():
raise FileNotFoundError(root)
if not root.is_dir():
raise NotADirectoryError(root)
for current_root, dirs, files in os.walk(root):
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
# root is already resolved; os.walk builds current_root by joining under root,
# so relative_to() works without an extra stat()/resolve() per directory.
rel_dir = Path(current_root).relative_to(root)
if include_dirs:
for name in dirs:
rel_path = (rel_dir / name).as_posix()
if path_matches(pattern, rel_path):
matches.append(str(Path(current_root) / name))
if len(matches) >= max_results:
truncated = True
return matches, truncated
for name in files:
if should_ignore_name(name):
continue
rel_path = (rel_dir / name).as_posix()
if path_matches(pattern, rel_path):
matches.append(str(Path(current_root) / name))
if len(matches) >= max_results:
truncated = True
return matches, truncated
return matches, truncated
def find_grep_matches(
root: Path,
pattern: str,
*,
glob_pattern: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
max_file_size: int = DEFAULT_MAX_FILE_SIZE_BYTES,
line_summary_length: int = DEFAULT_LINE_SUMMARY_LENGTH,
) -> tuple[list[GrepMatch], bool]:
matches: list[GrepMatch] = []
truncated = False
root = root.resolve()
if not root.exists():
raise FileNotFoundError(root)
if not root.is_dir():
raise NotADirectoryError(root)
regex_source = re.escape(pattern) if literal else pattern
flags = 0 if case_sensitive else re.IGNORECASE
regex = re.compile(regex_source, flags)
# Skip lines longer than this to prevent ReDoS on minified / no-newline files.
_max_line_chars = line_summary_length * 10
for current_root, dirs, files in os.walk(root):
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
rel_dir = Path(current_root).relative_to(root)
for name in files:
if should_ignore_name(name):
continue
candidate_path = Path(current_root) / name
rel_path = (rel_dir / name).as_posix()
if glob_pattern is not None and not path_matches(glob_pattern, rel_path):
continue
try:
if candidate_path.is_symlink():
continue
file_path = candidate_path.resolve()
if not file_path.is_relative_to(root):
continue
if file_path.stat().st_size > max_file_size or is_binary_file(file_path):
continue
with file_path.open(encoding="utf-8", errors="replace") as handle:
for line_number, line in enumerate(handle, start=1):
if len(line) > _max_line_chars:
continue
if regex.search(line):
matches.append(
GrepMatch(
path=str(file_path),
line_number=line_number,
line=truncate_line(line, line_summary_length),
)
)
if len(matches) >= max_results:
truncated = True
return matches, truncated
except OSError:
continue
return matches, truncated
@@ -7,17 +7,21 @@ from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
SandboxError,
SandboxNotFoundError,
SandboxRuntimeError,
)
from deerflow.sandbox.file_operation_lock import get_file_operation_lock
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])/(?:[^\s\"'`;&|<>()]+)")
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
_LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
"/bin/",
"/usr/bin/",
@@ -29,6 +33,10 @@ _LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
_DEFAULT_SKILLS_CONTAINER_PATH = "/mnt/skills"
_ACP_WORKSPACE_VIRTUAL_PATH = "/mnt/acp-workspace"
_DEFAULT_GLOB_MAX_RESULTS = 200
_MAX_GLOB_MAX_RESULTS = 1000
_DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500
def _get_skills_container_path() -> str:
@@ -111,6 +119,54 @@ def _is_acp_workspace_path(path: str) -> bool:
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
def _get_custom_mounts():
"""Get custom volume mounts from sandbox config.
Result is cached after the first successful config load. If config loading
fails an empty list is returned *without* caching so that a later call can
pick up the real value once the config is available.
"""
cached = getattr(_get_custom_mounts, "_cached", None)
if cached is not None:
return cached
try:
from pathlib import Path
from deerflow.config import get_app_config
config = get_app_config()
mounts = []
if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with
# LocalSandboxProvider._setup_path_mappings() which also filters
# by host_path.exists().
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
return mounts
except Exception:
# If config loading fails, return an empty list without caching so that
# a later call can retry once the config is available.
return []
def _is_custom_mount_path(path: str) -> bool:
"""Check if path is under a custom mount container_path."""
for mount in _get_custom_mounts():
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
return True
return False
def _get_custom_mount_for_path(path: str):
"""Get the mount config matching this path (longest prefix first)."""
best = None
for mount in _get_custom_mounts():
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
if best is None or len(mount.container_path) > len(best.container_path):
best = mount
return best
def _extract_thread_id_from_thread_data(thread_data: "ThreadDataState | None") -> str | None:
"""Extract thread_id from thread_data by inspecting workspace_path.
@@ -243,6 +299,69 @@ def _get_mcp_allowed_paths() -> list[str]:
return allowed_paths
def _get_tool_config_int(name: str, key: str, default: int) -> int:
try:
tool_config = get_app_config().get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
except Exception:
pass
return default
def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
if value <= 0:
return default
return min(value, upper_bound)
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
configured_max_results = _clamp_max_results(
_get_tool_config_int(name, "max_results", default),
default=default,
upper_bound=upper_bound,
)
return min(requested_max_results, configured_max_results)
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
return _resolve_skills_path(path)
if _is_acp_workspace_path(path):
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
return _resolve_and_validate_user_data_path(path, thread_data)
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
if not matches:
return f"No files matched under {root_path}"
lines = [f"Found {len(matches)} paths under {root_path}"]
if truncated:
lines[0] += f" (showing first {len(matches)})"
lines.extend(f"{index}. {path}" for index, path in enumerate(matches, start=1))
if truncated:
lines.append("Results truncated. Narrow the path or pattern to see fewer matches.")
return "\n".join(lines)
def _format_grep_results(root_path: str, matches: list[GrepMatch], truncated: bool) -> str:
if not matches:
return f"No matches found under {root_path}"
lines = [f"Found {len(matches)} matches under {root_path}"]
if truncated:
lines[0] += f" (showing first {len(matches)})"
lines.extend(f"{match.path}:{match.line_number}: {match.line}" for match in matches)
if truncated:
lines.append("Results truncated. Narrow the path or add a glob filter.")
return "\n".join(lines)
def _path_variants(path: str) -> set[str]:
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
@@ -377,6 +496,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
result = pattern.sub(replace_acp, result)
# Custom mount host paths are masked by LocalSandbox._reverse_resolve_paths_in_output()
# Mask user-data host paths
if thread_data is None:
return result
@@ -425,6 +546,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
- ``/mnt/user-data/*`` — always allowed (read + write)
- ``/mnt/skills/*`` — allowed only when *read_only* is True
- ``/mnt/acp-workspace/*`` — allowed only when *read_only* is True
- Custom mount paths (from config.yaml) — respects per-mount ``read_only`` flag
Args:
path: The virtual path to validate.
@@ -456,7 +578,14 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
if path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
return
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, or {_ACP_WORKSPACE_VIRTUAL_PATH}/ are allowed")
# Custom mount paths — respect read_only config
if _is_custom_mount_path(path):
mount = _get_custom_mount_for_path(path)
if mount and mount.read_only and not read_only:
raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
return
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
@@ -506,15 +635,21 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
boundary and must not be treated as isolation from the host filesystem.
In local mode, commands must use virtual paths under /mnt/user-data for
user data access. Skills paths under /mnt/skills and ACP workspace paths
under /mnt/acp-workspace are allowed (path-traversal checks only; write
prevention for bash commands is not enforced here).
user data access. Skills paths under /mnt/skills, ACP workspace paths
under /mnt/acp-workspace, and custom mount container paths (configured in
config.yaml) are allowed (path-traversal checks only; write prevention
for bash commands is not enforced here).
A small allowlist of common system path prefixes is kept for executable
and device references (e.g. /bin/sh, /dev/null).
"""
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
# Block file:// URLs which bypass the absolute-path regex but allow local file exfiltration
file_url_match = _FILE_URL_PATTERN.search(command)
if file_url_match:
raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}")
unsafe_paths: list[str] = []
allowed_paths = _get_mcp_allowed_paths()
@@ -538,6 +673,11 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
_reject_path_traversal(absolute_path)
continue
# Allow custom mount container paths
if _is_custom_mount_path(absolute_path):
_reject_path_traversal(absolute_path)
continue
if any(absolute_path == prefix.rstrip("/") or absolute_path.startswith(prefix) for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES):
continue
@@ -582,6 +722,8 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
result = acp_pattern.sub(replace_acp_match, result)
# Custom mount paths are resolved by LocalSandbox._resolve_paths_in_command()
# Replace user-data paths
if VIRTUAL_PATH_PREFIX in result and thread_data is not None:
pattern = re.compile(rf"{re.escape(VIRTUAL_PATH_PREFIX)}(/[^\s\"';&|<>()]*)?")
@@ -757,6 +899,59 @@ def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState]
runtime.state["thread_directories_created"] = True
def _truncate_bash_output(output: str, max_chars: int) -> str:
"""Middle-truncate bash output, preserving head and tail (50/50 split).
bash output may have errors at either end (stderr/stdout ordering is
non-deterministic), so both ends are preserved equally.
The returned string (including the truncation marker) is guaranteed to be
no longer than max_chars characters. Pass max_chars=0 to disable truncation
and return the full output unchanged.
"""
if max_chars == 0:
return output
if len(output) <= max_chars:
return output
total_len = len(output)
# Compute the exact worst-case marker length: skipped chars is at most
# total_len, so this is a tight upper bound.
marker_max_len = len(f"\n... [middle truncated: {total_len} chars skipped] ...\n")
kept = max(0, max_chars - marker_max_len)
if kept == 0:
return output[:max_chars]
head_len = kept // 2
tail_len = kept - head_len
skipped = total_len - kept
marker = f"\n... [middle truncated: {skipped} chars skipped] ...\n"
return f"{output[:head_len]}{marker}{output[-tail_len:] if tail_len > 0 else ''}"
def _truncate_read_file_output(output: str, max_chars: int) -> str:
"""Head-truncate read_file output, preserving the beginning of the file.
Source code and documents are read top-to-bottom; the head contains the
most context (imports, class definitions, function signatures).
The returned string (including the truncation marker) is guaranteed to be
no longer than max_chars characters. Pass max_chars=0 to disable truncation
and return the full output unchanged.
"""
if max_chars == 0:
return output
if len(output) <= max_chars:
return output
total = len(output)
# Compute the exact worst-case marker length: both numeric fields are at
# their maximum (total chars), so this is a tight upper bound.
marker_max_len = len(f"\n... [truncated: showing first {total} of {total} chars. Use start_line/end_line to read a specific range] ...")
kept = max(0, max_chars - marker_max_len)
if kept == 0:
return output[:max_chars]
marker = f"\n... [truncated: showing first {kept} of {total} chars. Use start_line/end_line to read a specific range] ..."
return f"{output[:kept]}{marker}"
@tool("bash", parse_docstring=True)
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
"""Execute a bash command in a Linux environment.
@@ -781,9 +976,23 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
command = replace_virtual_paths_in_command(command, thread_data)
command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command)
return mask_local_paths_in_output(output, thread_data)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
ensure_thread_directories_exist(runtime)
return sandbox.execute_command(command)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
except SandboxError as e:
return f"Error: {e}"
except PermissionError as e:
@@ -811,8 +1020,9 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
path = _resolve_skills_path(path)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
else:
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
children = sandbox.list_dir(path)
if not children:
return "(empty)"
@@ -827,6 +1037,126 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
@tool("glob", parse_docstring=True)
def glob_tool(
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
include_dirs: bool = False,
max_results: int = _DEFAULT_GLOB_MAX_RESULTS,
) -> str:
"""Find files or directories that match a glob pattern under a root directory.
Args:
description: Explain why you are searching for these paths in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
pattern: The glob pattern to match relative to the root path, for example `**/*.py`.
path: The **absolute** root directory to search under.
include_dirs: Whether matching directories should also be returned. Default is False.
max_results: Maximum number of paths to return. Default is 200.
"""
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
"glob",
max_results,
default=_DEFAULT_GLOB_MAX_RESULTS,
upper_bound=_MAX_GLOB_MAX_RESULTS,
)
thread_data = None
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
if thread_data is not None:
matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
return _format_glob_results(requested_path, matches, truncated)
except SandboxError as e:
return f"Error: {e}"
except FileNotFoundError:
return f"Error: Directory not found: {requested_path}"
except NotADirectoryError:
return f"Error: Path is not a directory: {requested_path}"
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
@tool("grep", parse_docstring=True)
def grep_tool(
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = _DEFAULT_GREP_MAX_RESULTS,
) -> str:
"""Search for matching lines inside text files under a root directory.
Args:
description: Explain why you are searching file contents in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
pattern: The string or regex pattern to search for.
path: The **absolute** root directory to search under.
glob: Optional glob filter for candidate files, for example `**/*.py`.
literal: Whether to treat `pattern` as a plain string. Default is False.
case_sensitive: Whether matching is case-sensitive. Default is False.
max_results: Maximum number of matching lines to return. Default is 100.
"""
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
"grep",
max_results,
default=_DEFAULT_GREP_MAX_RESULTS,
upper_bound=_MAX_GREP_MAX_RESULTS,
)
thread_data = None
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
matches, truncated = sandbox.grep(
path,
pattern,
glob=glob,
literal=literal,
case_sensitive=case_sensitive,
max_results=effective_max_results,
)
if thread_data is not None:
matches = [
GrepMatch(
path=mask_local_paths_in_output(match.path, thread_data),
line_number=match.line_number,
line=match.line,
)
for match in matches
]
return _format_grep_results(requested_path, matches, truncated)
except SandboxError as e:
return f"Error: {e}"
except FileNotFoundError:
return f"Error: Directory not found: {requested_path}"
except NotADirectoryError:
return f"Error: Path is not a directory: {requested_path}"
except re.error as e:
return f"Error: Invalid regex pattern: {e}"
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
@tool("read_file", parse_docstring=True)
def read_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
@@ -854,14 +1184,22 @@ def read_file_tool(
path = _resolve_skills_path(path)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
else:
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
content = sandbox.read_file(path)
if not content:
return "(empty)"
if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
return content
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception:
max_chars = 50000
return _truncate_read_file_output(content, max_chars)
except SandboxError as e:
return f"Error: {e}"
except FileNotFoundError:
@@ -896,8 +1234,11 @@ def write_file_tool(
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
path = _resolve_and_validate_user_data_path(path, thread_data)
sandbox.write_file(path, content, append)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
sandbox.write_file(path, content, append)
return "OK"
except SandboxError as e:
return f"Error: {e}"
@@ -937,17 +1278,20 @@ def str_replace_tool(
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
path = _resolve_and_validate_user_data_path(path, thread_data)
content = sandbox.read_file(path)
if not content:
return "OK"
if old_str not in content:
return f"Error: String to replace not found in file: {requested_path}"
if replace_all:
content = content.replace(old_str, new_str)
else:
content = content.replace(old_str, new_str, 1)
sandbox.write_file(path, content)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
content = sandbox.read_file(path)
if not content:
return "OK"
if old_str not in content:
return f"Error: String to replace not found in file: {requested_path}"
if replace_all:
content = content.replace(old_str, new_str)
else:
content = content.replace(old_str, new_str, 1)
sandbox.write_file(path, content)
return "OK"
except SandboxError as e:
return f"Error: {e}"
@@ -33,15 +33,72 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
front_matter = front_matter_match.group(1)
# Parse YAML front matter (simple key-value parsing)
# Parse YAML front matter with basic multiline string support
metadata = {}
for line in front_matter.split("\n"):
line = line.strip()
if not line:
lines = front_matter.split("\n")
current_key = None
current_value = []
is_multiline = False
multiline_style = None
indent_level = None
for line in lines:
if is_multiline:
if not line.strip():
current_value.append("")
continue
current_indent = len(line) - len(line.lstrip())
if indent_level is None:
if current_indent > 0:
indent_level = current_indent
current_value.append(line[indent_level:])
continue
elif current_indent >= indent_level:
current_value.append(line[indent_level:])
continue
# If we reach here, it's either a new key or the end of multiline
if current_key and is_multiline:
if multiline_style == "|":
metadata[current_key] = "\n".join(current_value).rstrip()
else:
text = "\n".join(current_value).rstrip()
# Replace single newlines with spaces for folded blocks
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
current_key = None
current_value = []
is_multiline = False
multiline_style = None
indent_level = None
if not line.strip():
continue
if ":" in line:
# Handle nested dicts simply by ignoring indentation for now,
# or just extracting top-level keys
key, value = line.split(":", 1)
metadata[key.strip()] = value.strip()
key = key.strip()
value = value.strip()
if value in (">", "|"):
current_key = key
is_multiline = True
multiline_style = value
current_value = []
indent_level = None
else:
metadata[key] = value
if current_key and is_multiline:
if multiline_style == "|":
metadata[current_key] = "\n".join(current_value).rstrip()
else:
text = "\n".join(current_value).rstrip()
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
# Extract required fields
name = metadata.get("name")
@@ -57,6 +57,42 @@ def _build_mcp_servers() -> dict[str, dict[str, Any]]:
return build_servers_config(ExtensionsConfig.from_file())
def _build_acp_mcp_servers() -> list[dict[str, Any]]:
"""Build ACP ``mcpServers`` payload for ``new_session``.
The ACP client expects a list of server objects, while DeerFlow's MCP helper
returns a name -> config mapping for the LangChain MCP adapter. This helper
converts the enabled servers into the ACP wire format.
"""
from deerflow.config.extensions_config import ExtensionsConfig
extensions_config = ExtensionsConfig.from_file()
enabled_servers = extensions_config.get_enabled_mcp_servers()
mcp_servers: list[dict[str, Any]] = []
for name, server_config in enabled_servers.items():
transport_type = server_config.type or "stdio"
payload: dict[str, Any] = {"name": name, "type": transport_type}
if transport_type == "stdio":
if not server_config.command:
raise ValueError(f"MCP server '{name}' with stdio transport requires 'command' field")
payload["command"] = server_config.command
payload["args"] = server_config.args
payload["env"] = [{"name": key, "value": value} for key, value in server_config.env.items()]
elif transport_type in ("http", "sse"):
if not server_config.url:
raise ValueError(f"MCP server '{name}' with {transport_type} transport requires 'url' field")
payload["url"] = server_config.url
payload["headers"] = [{"name": key, "value": value} for key, value in server_config.headers.items()]
else:
raise ValueError(f"MCP server '{name}' has unsupported transport type: {transport_type}")
mcp_servers.append(payload)
return mcp_servers
def _build_permission_response(options: list[Any], *, auto_approve: bool) -> Any:
"""Build an ACP permission response.
@@ -173,7 +209,15 @@ def build_invoke_acp_agent_tool(agents: dict) -> BaseTool:
cmd = agent_config.command
args = agent_config.args or []
physical_cwd = _get_work_dir(thread_id)
mcp_servers = _build_mcp_servers()
try:
mcp_servers = _build_acp_mcp_servers()
except ValueError as exc:
logger.warning(
"Invalid MCP server configuration for ACP agent '%s'; continuing without MCP servers: %s",
agent,
exc,
)
mcp_servers = []
agent_env: dict[str, str] | None = None
if agent_config.env:
agent_env = {k: (os.environ.get(v[1:], "") if v.startswith("$") else v) for k, v in agent_config.env.items()}
@@ -0,0 +1,3 @@
from .factory import build_tracing_callbacks
__all__ = ["build_tracing_callbacks"]
@@ -0,0 +1,54 @@
from __future__ import annotations
from typing import Any
from deerflow.config import (
get_enabled_tracing_providers,
get_tracing_config,
validate_enabled_tracing_providers,
)
def _create_langsmith_tracer(config) -> Any:
from langchain_core.tracers.langchain import LangChainTracer
return LangChainTracer(project_name=config.project)
def _create_langfuse_handler(config) -> Any:
from langfuse import Langfuse
from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
# langfuse>=4 initializes project-specific credentials through the client
# singleton; the LangChain callback then attaches to that configured client.
Langfuse(
secret_key=config.secret_key,
public_key=config.public_key,
host=config.host,
)
return LangfuseCallbackHandler(public_key=config.public_key)
def build_tracing_callbacks() -> list[Any]:
"""Build callbacks for all explicitly enabled tracing providers."""
validate_enabled_tracing_providers()
enabled_providers = get_enabled_tracing_providers()
if not enabled_providers:
return []
tracing_config = get_tracing_config()
callbacks: list[Any] = []
for provider in enabled_providers:
if provider == "langsmith":
try:
callbacks.append(_create_langsmith_tracer(tracing_config.langsmith))
except Exception as exc: # pragma: no cover - exercised via tests with monkeypatch
raise RuntimeError(f"LangSmith tracing initialization failed: {exc}") from exc
elif provider == "langfuse":
try:
callbacks.append(_create_langfuse_handler(tracing_config.langfuse))
except Exception as exc: # pragma: no cover - exercised via tests with monkeypatch
raise RuntimeError(f"Langfuse tracing initialization failed: {exc}") from exc
return callbacks
@@ -1,10 +1,22 @@
"""File conversion utilities.
Converts document files (PDF, PPT, Excel, Word) to Markdown using markitdown.
Converts document files (PDF, PPT, Excel, Word) to Markdown.
PDF conversion strategy (auto mode):
1. Try pymupdf4llm if installed — better heading detection, faster on most files.
2. If output is suspiciously short (< _MIN_CHARS_PER_PAGE chars/page, or < 200 chars
total when page count is unavailable), treat as image-based and fall back to MarkItDown.
3. If pymupdf4llm is not installed, use MarkItDown directly (existing behaviour).
Large files (> ASYNC_THRESHOLD_BYTES) are converted in a thread pool via
asyncio.to_thread() to avoid blocking the event loop (fixes #1569).
No FastAPI or HTTP dependencies — pure utility functions.
"""
import asyncio
import logging
import re
from pathlib import Path
logger = logging.getLogger(__name__)
@@ -20,28 +32,278 @@ CONVERTIBLE_EXTENSIONS = {
".docx",
}
# Files larger than this threshold are converted in a background thread.
# Small files complete in < 1s synchronously; spawning a thread adds unnecessary
# scheduling overhead for them.
_ASYNC_THRESHOLD_BYTES = 1 * 1024 * 1024 # 1 MB
# If pymupdf4llm produces fewer characters *per page* than this threshold,
# the PDF is likely image-based or encrypted — fall back to MarkItDown.
# Rationale: normal text PDFs yield 200-2000 chars/page; image-based PDFs
# yield close to 0. 50 chars/page gives a wide safety margin.
# Falls back to absolute 200-char check when page count is unavailable.
_MIN_CHARS_PER_PAGE = 50
def _pymupdf_output_too_sparse(text: str, file_path: Path) -> bool:
"""Return True if pymupdf4llm output is suspiciously short (image-based PDF).
Uses chars-per-page rather than an absolute threshold so that both short
documents (few pages, few chars) and long documents (many pages, many chars)
are handled correctly.
"""
chars = len(text.strip())
doc = None
pages: int | None = None
try:
import pymupdf
doc = pymupdf.open(str(file_path))
pages = len(doc)
except Exception:
pass
finally:
if doc is not None:
try:
doc.close()
except Exception:
pass
if pages is not None and pages > 0:
return (chars / pages) < _MIN_CHARS_PER_PAGE
# Fallback: absolute threshold when page count is unavailable
return chars < 200
def _convert_pdf_with_pymupdf4llm(file_path: Path) -> str | None:
"""Attempt PDF conversion with pymupdf4llm.
Returns the markdown text, or None if pymupdf4llm is not installed or
if conversion fails (e.g. encrypted/corrupt PDF).
"""
try:
import pymupdf4llm
except ImportError:
return None
try:
return pymupdf4llm.to_markdown(str(file_path))
except Exception:
logger.exception("pymupdf4llm failed to convert %s; falling back to MarkItDown", file_path.name)
return None
def _convert_with_markitdown(file_path: Path) -> str:
"""Convert any supported file to markdown text using MarkItDown."""
from markitdown import MarkItDown
md = MarkItDown()
return md.convert(str(file_path)).text_content
def _do_convert(file_path: Path, pdf_converter: str) -> str:
"""Synchronous conversion — called directly or via asyncio.to_thread.
Args:
file_path: Path to the file.
pdf_converter: "auto" | "pymupdf4llm" | "markitdown"
"""
is_pdf = file_path.suffix.lower() == ".pdf"
if is_pdf and pdf_converter != "markitdown":
# Try pymupdf4llm first (auto or explicit)
pymupdf_text = _convert_pdf_with_pymupdf4llm(file_path)
if pymupdf_text is not None:
# pymupdf4llm is installed
if pdf_converter == "pymupdf4llm":
# Explicit — use as-is regardless of output length
return pymupdf_text
# auto mode: fall back if output looks like a failed parse.
# Use chars-per-page to distinguish image-based PDFs (near 0) from
# legitimately short documents.
if not _pymupdf_output_too_sparse(pymupdf_text, file_path):
return pymupdf_text
logger.warning(
"pymupdf4llm produced only %d chars for %s (likely image-based PDF); falling back to MarkItDown",
len(pymupdf_text.strip()),
file_path.name,
)
# pymupdf4llm not installed or fallback triggered → use MarkItDown
return _convert_with_markitdown(file_path)
async def convert_file_to_markdown(file_path: Path) -> Path | None:
"""Convert a file to markdown using markitdown.
"""Convert a supported document file to Markdown.
PDF files are handled with a two-converter strategy (see module docstring).
Large files (> 1 MB) are offloaded to a thread pool to avoid blocking the
event loop.
Args:
file_path: Path to the file to convert.
Returns:
Path to the markdown file if conversion was successful, None otherwise.
Path to the generated .md file, or None if conversion failed.
"""
try:
from markitdown import MarkItDown
pdf_converter = _get_pdf_converter()
file_size = file_path.stat().st_size
md = MarkItDown()
result = md.convert(str(file_path))
if file_size > _ASYNC_THRESHOLD_BYTES:
text = await asyncio.to_thread(_do_convert, file_path, pdf_converter)
else:
text = _do_convert(file_path, pdf_converter)
# Save as .md file with same name
md_path = file_path.with_suffix(".md")
md_path.write_text(result.text_content, encoding="utf-8")
md_path.write_text(text, encoding="utf-8")
logger.info(f"Converted {file_path.name} to markdown: {md_path.name}")
logger.info("Converted %s to markdown: %s (%d chars)", file_path.name, md_path.name, len(text))
return md_path
except Exception as e:
logger.error(f"Failed to convert {file_path.name} to markdown: {e}")
logger.error("Failed to convert %s to markdown: %s", file_path.name, e)
return None
# Regex for bold-only lines that look like section headings.
# Targets SEC filing structural headings that pymupdf4llm renders as **bold**
# rather than # Markdown headings (because they use same font size as body text,
# distinguished only by bold+caps formatting).
#
# Pattern requires ALL of:
# 1. Entire line is a single **...** block (no surrounding prose)
# 2. Starts with a recognised structural keyword:
# - ITEM / PART / SECTION (with optional number/letter after)
# - SCHEDULE, EXHIBIT, APPENDIX, ANNEX, CHAPTER
# All-caps addresses, boilerplate ("CURRENT REPORT", "SIGNATURES",
# "WASHINGTON, DC 20549") do NOT start with these keywords and are excluded.
#
# Chinese headings (第三节...) are already captured as standard # headings
# by pymupdf4llm, so they don't need this pattern.
_BOLD_HEADING_RE = re.compile(r"^\*\*((ITEM|PART|SECTION|SCHEDULE|EXHIBIT|APPENDIX|ANNEX|CHAPTER)\b[A-Z0-9 .,\-]*)\*\*\s*$")
# Regex for split-bold headings produced by pymupdf4llm when a heading spans
# multiple text spans in the PDF (e.g. section number and title are separate spans).
# Matches lines like: **1** **Introduction** or **3.2** **Multi-Head Attention**
# Requirements:
# 1. Entire line consists only of **...** blocks separated by whitespace (no prose)
# 2. First block is a section number (digits and dots, e.g. "1", "3.2", "A.1")
# 3. Second block must not be purely numeric/punctuation — excludes financial table
# headers like **2023** **2022** **2021** while allowing non-ASCII titles such as
# **1** **概述** or accented words (negative lookahead instead of [A-Za-z])
# 4. At most two additional blocks (four total) with [^*]+ (no * inside) to keep
# the regex linear and avoid ReDoS on attacker-controlled content
_SPLIT_BOLD_HEADING_RE = re.compile(r"^\*\*[\dA-Z][\d\.]*\*\*\s+\*\*(?!\d[\d\s.,\-–—/:()%]*\*\*)[^*]+\*\*(?:\s+\*\*[^*]+\*\*){0,2}\s*$")
# Maximum number of outline entries injected into the agent context.
# Keeps prompt size bounded even for very long documents.
MAX_OUTLINE_ENTRIES = 50
_ALLOWED_PDF_CONVERTERS = {"auto", "pymupdf4llm", "markitdown"}
def _clean_bold_title(raw: str) -> str:
"""Normalise a title string that may contain pymupdf4llm bold artefacts.
pymupdf4llm sometimes emits adjacent bold spans as ``**A** **B**`` instead
of a single ``**A B**`` block. This helper merges those fragments and then
strips the outermost ``**...**`` wrapper so the caller gets plain text.
Examples::
"**Overview**""Overview"
"**UNITED STATES** **SECURITIES**""UNITED STATES SECURITIES"
"plain text""plain text" (unchanged)
"""
# Merge adjacent bold spans: "** **" → " "
merged = re.sub(r"\*\*\s*\*\*", " ", raw).strip()
# Strip outermost **...** if the whole string is wrapped
if m := re.fullmatch(r"\*\*(.+?)\*\*", merged, re.DOTALL):
return m.group(1).strip()
return merged
def extract_outline(md_path: Path) -> list[dict]:
"""Extract document outline (headings) from a Markdown file.
Recognises three heading styles produced by pymupdf4llm:
1. Standard Markdown headings: lines starting with one or more '#'.
Inline ``**...**`` wrappers and adjacent bold spans (``** **``) are
cleaned so the title is plain text.
2. Bold-only structural headings: ``**ITEM 1. BUSINESS**``, ``**PART II**``,
etc. SEC filings use bold+caps for section headings with the same font
size as body text, so pymupdf4llm cannot promote them to # headings.
3. Split-bold headings: ``**1** **Introduction**``, ``**3.2** **Attention**``.
pymupdf4llm emits these when the section number and title text are
separate spans in the underlying PDF (common in academic papers).
Args:
md_path: Path to the .md file.
Returns:
List of dicts with keys: title (str), line (int, 1-based).
When the outline is truncated at MAX_OUTLINE_ENTRIES, a sentinel entry
``{"truncated": True}`` is appended as the last element so callers can
render a "showing first N headings" hint without re-scanning the file.
Returns an empty list if the file cannot be read or has no headings.
"""
outline: list[dict] = []
try:
with md_path.open(encoding="utf-8") as f:
for lineno, line in enumerate(f, 1):
stripped = line.strip()
if not stripped:
continue
# Style 1: standard Markdown heading
if stripped.startswith("#"):
title = _clean_bold_title(stripped.lstrip("#").strip())
if title:
outline.append({"title": title, "line": lineno})
# Style 2: single bold block with SEC structural keyword
elif m := _BOLD_HEADING_RE.match(stripped):
title = m.group(1).strip()
if title:
outline.append({"title": title, "line": lineno})
# Style 3: split-bold heading — **<num>** **<title>**
# Regex already enforces max 4 blocks and non-numeric second block.
elif _SPLIT_BOLD_HEADING_RE.match(stripped):
title = " ".join(re.findall(r"\*\*([^*]+)\*\*", stripped))
if title:
outline.append({"title": title, "line": lineno})
if len(outline) >= MAX_OUTLINE_ENTRIES:
outline.append({"truncated": True})
break
except Exception:
return []
return outline
def _get_pdf_converter() -> str:
"""Read pdf_converter setting from app config, defaulting to 'auto'.
Normalizes the value to lowercase and validates it against the allowed set
so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
fall through to unexpected behaviour.
"""
try:
from deerflow.config.app_config import get_app_config
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if uploads_cfg is not None:
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
return "auto"
return raw
except Exception:
pass
return "auto"
+4
View File
@@ -14,6 +14,7 @@ dependencies = [
"langchain-deepseek>=1.0.1",
"langchain-mcp-adapters>=0.1.0",
"langchain-openai>=1.1.7",
"langfuse>=3.4.1",
"langgraph>=1.0.6,<1.0.10",
"langgraph-api>=0.7.0,<0.8.0",
"langgraph-cli>=0.4.14",
@@ -44,6 +45,9 @@ postgres = [
"psycopg-pool>=3.3.0",
]
[project.optional-dependencies]
pymupdf = ["pymupdf4llm>=0.0.17"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"