mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
Merge branch 'main' into rayhpeng/persistence-scaffold
# Conflicts: # .env.example # backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
This commit is contained in:
@@ -345,6 +345,8 @@ def make_lead_agent(config: RunnableConfig):
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||
),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,14 @@ from deerflow.subagents import get_available_subagent_names
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
try:
|
||||
return list(load_skills(enabled_only=True))
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
return []
|
||||
|
||||
|
||||
def _build_subagent_section(max_concurrent: int) -> str:
|
||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||
|
||||
@@ -386,7 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
Returns the <skill_system>...</skill_system> block listing all enabled skills,
|
||||
suitable for injection into any agent's system prompt.
|
||||
"""
|
||||
skills = load_skills(enabled_only=True)
|
||||
skills = _get_enabled_skills()
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
@@ -402,6 +410,10 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
if available_skills is not None:
|
||||
skills = [skill for skill in skills if skill.name in available_skills]
|
||||
|
||||
# Check again after filtering
|
||||
if not skills:
|
||||
return ""
|
||||
|
||||
skill_items = "\n".join(
|
||||
f" <skill>\n <name>{skill.name}</name>\n <description>{skill.description}</description>\n <location>{skill.get_container_file_path(container_base_path)}</location>\n </skill>" for skill in skills
|
||||
)
|
||||
@@ -446,7 +458,7 @@ def get_deferred_tools_prompt_section() -> str:
|
||||
|
||||
if not get_app_config().tool_search.enabled:
|
||||
return ""
|
||||
except FileNotFoundError:
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
registry = get_deferred_registry()
|
||||
|
||||
@@ -29,6 +29,17 @@ Instructions:
|
||||
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
|
||||
3. Update the memory sections as needed following the detailed length guidelines below
|
||||
|
||||
Before extracting facts, perform a structured reflection on the conversation:
|
||||
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
|
||||
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
|
||||
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
|
||||
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
|
||||
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
|
||||
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
|
||||
If yes, record them as facts with the most appropriate category and confidence.
|
||||
|
||||
{correction_hint}
|
||||
|
||||
Memory Section Guidelines:
|
||||
|
||||
**User Context** (Current state - concise summaries):
|
||||
@@ -62,6 +73,7 @@ Memory Section Guidelines:
|
||||
* context: Background facts (job title, projects, locations, languages)
|
||||
* behavior: Working patterns, communication habits, problem-solving approaches
|
||||
* goal: Stated objectives, learning targets, project ambitions
|
||||
* correction: Explicit agent mistakes or user corrections, including the correct approach
|
||||
- Confidence levels:
|
||||
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
|
||||
* 0.7-0.8: Strongly implied from actions/discussions
|
||||
@@ -94,7 +106,7 @@ Output Format (JSON):
|
||||
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"newFacts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
],
|
||||
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
||||
}}
|
||||
@@ -104,6 +116,8 @@ Important Rules:
|
||||
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
||||
- Include specific metrics, version numbers, and proper nouns in facts
|
||||
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
||||
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
|
||||
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
|
||||
- Remove facts that are contradicted by new information
|
||||
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
||||
Keep 3-5 concurrent focus themes that are still active and relevant
|
||||
@@ -126,7 +140,7 @@ Message:
|
||||
Extract facts in this JSON format:
|
||||
{{
|
||||
"facts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
]
|
||||
}}
|
||||
|
||||
@@ -136,6 +150,7 @@ Categories:
|
||||
- context: Background context (location, job, projects)
|
||||
- behavior: Behavioral patterns
|
||||
- goal: User's goals or objectives
|
||||
- correction: Explicit corrections or mistakes to avoid repeating
|
||||
|
||||
Rules:
|
||||
- Only extract clear, specific facts
|
||||
@@ -231,6 +246,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
if earlier.get("summary"):
|
||||
history_sections.append(f"Earlier: {earlier['summary']}")
|
||||
|
||||
background = history_data.get("longTermBackground", {})
|
||||
if background.get("summary"):
|
||||
history_sections.append(f"Background: {background['summary']}")
|
||||
|
||||
if history_sections:
|
||||
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
|
||||
|
||||
@@ -262,7 +281,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
continue
|
||||
category = str(fact.get("category", "context")).strip() or "context"
|
||||
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
source_error = fact.get("sourceError")
|
||||
if category == "correction" and isinstance(source_error, str) and source_error.strip():
|
||||
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
|
||||
else:
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
|
||||
@@ -20,6 +20,7 @@ class ConversationContext:
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
@@ -37,25 +38,38 @@ class MemoryUpdateQueue:
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
|
||||
def add(
|
||||
self,
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
# If so, replace it with the newer one
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
@@ -115,6 +129,7 @@ class MemoryUpdateQueue:
|
||||
messages=context.messages,
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
||||
@@ -266,13 +266,20 @@ class MemoryUpdater:
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory(
|
||||
self,
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@@ -295,9 +302,19 @@ class MemoryUpdater:
|
||||
return False
|
||||
|
||||
# Build prompt
|
||||
correction_hint = ""
|
||||
if correction_detected:
|
||||
correction_hint = (
|
||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
conversation=conversation_text,
|
||||
correction_hint=correction_hint,
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
@@ -383,6 +400,8 @@ class MemoryUpdater:
|
||||
confidence = fact.get("confidence", 0.5)
|
||||
if confidence >= config.fact_confidence_threshold:
|
||||
raw_content = fact.get("content", "")
|
||||
if not isinstance(raw_content, str):
|
||||
continue
|
||||
normalized_content = raw_content.strip()
|
||||
fact_key = _fact_content_key(normalized_content)
|
||||
if fact_key is not None and fact_key in existing_fact_keys:
|
||||
@@ -396,6 +415,11 @@ class MemoryUpdater:
|
||||
"createdAt": now,
|
||||
"source": thread_id or "unknown",
|
||||
}
|
||||
source_error = fact.get("sourceError")
|
||||
if isinstance(source_error, str):
|
||||
normalized_source_error = source_error.strip()
|
||||
if normalized_source_error:
|
||||
fact_entry["sourceError"] = normalized_source_error
|
||||
current_memory["facts"].append(fact_entry)
|
||||
if fact_key is not None:
|
||||
existing_fact_keys.add(fact_key)
|
||||
@@ -412,16 +436,22 @@ class MemoryUpdater:
|
||||
return current_memory
|
||||
|
||||
|
||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory_from_conversation(
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
|
||||
|
||||
+275
@@ -0,0 +1,275 @@
|
||||
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
_BUSY_PATTERNS = (
|
||||
"server busy",
|
||||
"temporarily unavailable",
|
||||
"try again later",
|
||||
"please retry",
|
||||
"please try again",
|
||||
"overloaded",
|
||||
"high demand",
|
||||
"rate limit",
|
||||
"负载较高",
|
||||
"服务繁忙",
|
||||
"稍后重试",
|
||||
"请稍后重试",
|
||||
)
|
||||
_QUOTA_PATTERNS = (
|
||||
"insufficient_quota",
|
||||
"quota",
|
||||
"billing",
|
||||
"credit",
|
||||
"payment",
|
||||
"余额不足",
|
||||
"超出限额",
|
||||
"额度不足",
|
||||
"欠费",
|
||||
)
|
||||
_AUTH_PATTERNS = (
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"permission",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"无权",
|
||||
"未授权",
|
||||
)
|
||||
|
||||
|
||||
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Retry transient LLM errors and surface graceful assistant messages."""
|
||||
|
||||
retry_max_attempts: int = 3
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
error_code = _extract_error_code(exc)
|
||||
status_code = _extract_status_code(exc)
|
||||
|
||||
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
|
||||
return False, "quota"
|
||||
if _matches_any(lowered, _AUTH_PATTERNS):
|
||||
return False, "auth"
|
||||
|
||||
exc_name = exc.__class__.__name__
|
||||
if exc_name in {
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
"InternalServerError",
|
||||
}:
|
||||
return True, "transient"
|
||||
if status_code in _RETRIABLE_STATUS_CODES:
|
||||
return True, "transient"
|
||||
if _matches_any(lowered, _BUSY_PATTERNS):
|
||||
return True, "busy"
|
||||
|
||||
return False, "generic"
|
||||
|
||||
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
|
||||
retry_after = _extract_retry_after_ms(exc)
|
||||
if retry_after is not None:
|
||||
return retry_after
|
||||
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
|
||||
return min(backoff, self.retry_cap_delay_ms)
|
||||
|
||||
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
|
||||
seconds = max(1, round(wait_ms / 1000))
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
|
||||
if reason == "auth":
|
||||
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
||||
if reason in {"busy", "transient"}:
|
||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||
return f"LLM request failed: {detail}"
|
||||
|
||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
writer(
|
||||
{
|
||||
"type": "llm_retry",
|
||||
"attempt": attempt,
|
||||
"max_attempts": self.retry_max_attempts,
|
||||
"wait_ms": wait_ms,
|
||||
"reason": reason,
|
||||
"message": self._build_retry_message(attempt, wait_ms, reason),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to emit llm_retry event", exc_info=True)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
time.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||
return any(pattern in detail for pattern in patterns)
|
||||
|
||||
|
||||
def _extract_error_code(exc: BaseException) -> Any:
|
||||
for attr in ("code", "error_code"):
|
||||
value = getattr(exc, attr, None)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
|
||||
body = getattr(exc, "body", None)
|
||||
if isinstance(body, dict):
|
||||
error = body.get("error")
|
||||
if isinstance(error, dict):
|
||||
for key in ("code", "type"):
|
||||
value = error.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_status_code(exc: BaseException) -> int | None:
|
||||
for attr in ("status_code", "status"):
|
||||
value = getattr(exc, attr, None)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
response = getattr(exc, "response", None)
|
||||
status = getattr(response, "status_code", None)
|
||||
return status if isinstance(status, int) else None
|
||||
|
||||
|
||||
def _extract_retry_after_ms(exc: BaseException) -> int | None:
|
||||
response = getattr(exc, "response", None)
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
raw = None
|
||||
header_name = ""
|
||||
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
|
||||
header_name = key
|
||||
if hasattr(headers, "get"):
|
||||
raw = headers.get(key)
|
||||
if raw:
|
||||
break
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
try:
|
||||
multiplier = 1 if "ms" in header_name.lower() else 1000
|
||||
return max(0, int(float(raw) * multiplier))
|
||||
except (TypeError, ValueError):
|
||||
try:
|
||||
target = parsedate_to_datetime(str(raw))
|
||||
delta = target.timestamp() - time.time()
|
||||
return max(0, int(delta * 1000))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_detail(exc: BaseException) -> str:
|
||||
detail = str(exc).strip()
|
||||
if detail:
|
||||
return detail
|
||||
message = getattr(exc, "message", None)
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
return exc.__class__.__name__
|
||||
@@ -182,6 +182,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
return None, False
|
||||
|
||||
@staticmethod
|
||||
def _append_text(content: str | list | None, text: str) -> str | list:
|
||||
"""Append *text* to AIMessage content, handling str, list, and None.
|
||||
|
||||
When content is a list of content blocks (e.g. Anthropic thinking mode),
|
||||
we append a new ``{"type": "text", ...}`` block instead of concatenating
|
||||
a string to a list, which would raise ``TypeError``.
|
||||
"""
|
||||
if content is None:
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
@@ -192,7 +209,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
stripped_msg = last_msg.model_copy(
|
||||
update={
|
||||
"tool_calls": [],
|
||||
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
|
||||
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
|
||||
}
|
||||
)
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
@@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
@@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState):
|
||||
pass
|
||||
|
||||
|
||||
def _extract_message_text(message: Any) -> str:
|
||||
"""Extract plain text from message content for filtering and signal detection."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
elif isinstance(part, dict):
|
||||
text_val = part.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
"""Filter messages to keep only user inputs and final assistant responses.
|
||||
|
||||
@@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
Returns:
|
||||
Filtered list containing only user inputs and final assistant responses.
|
||||
"""
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
|
||||
filtered = []
|
||||
skip_next_ai = False
|
||||
for msg in messages:
|
||||
msg_type = getattr(msg, "type", None)
|
||||
|
||||
if msg_type == "human":
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, list):
|
||||
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
|
||||
content_str = str(content)
|
||||
content_str = _extract_message_text(msg)
|
||||
if "<uploaded_files>" in content_str:
|
||||
# Strip the ephemeral upload block; keep the user's real question.
|
||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||
@@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
return filtered
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
|
||||
"""Detect explicit user corrections in recent conversation turns.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale corrections from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
@@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
return None
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -116,44 +116,33 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
return config
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Synchronously generate a title. Returns state update or None."""
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
config = get_title_config()
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
|
||||
try:
|
||||
response = model.invoke(prompt, config=self._get_runnable_config())
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate title (sync)")
|
||||
title = self._fallback_title(user_msg)
|
||||
|
||||
return {"title": title}
|
||||
_, user_msg = self._build_title_prompt(state)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Asynchronously generate a title. Returns state update or None."""
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
config = get_title_config()
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
|
||||
try:
|
||||
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
response = await model.ainvoke(prompt)
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
if title:
|
||||
return {"title": title}
|
||||
except Exception:
|
||||
logger.exception("Failed to generate title (async)")
|
||||
title = self._fallback_title(user_msg)
|
||||
|
||||
return {"title": title}
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
|
||||
+4
-1
@@ -72,6 +72,7 @@ def _build_runtime_middlewares(
|
||||
lazy_init: bool = True,
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
@@ -90,6 +91,8 @@ def _build_runtime_middlewares(
|
||||
|
||||
middlewares.append(DanglingToolCallMiddleware())
|
||||
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.guardrails_config import get_guardrails_config
|
||||
|
||||
@@ -135,6 +138,6 @@ def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentM
|
||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=False,
|
||||
include_dangling_tool_call_patch=False,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
|
||||
@@ -10,10 +10,52 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_OUTLINE_PREVIEW_LINES = 5
|
||||
|
||||
|
||||
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
|
||||
"""Return the document outline and fallback preview for *file_path*.
|
||||
|
||||
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
|
||||
pipeline.
|
||||
|
||||
Returns:
|
||||
(outline, preview) where:
|
||||
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
|
||||
Empty when no headings are found or no .md exists.
|
||||
- preview: first few non-empty lines of the .md, used as a content
|
||||
anchor when outline is empty so the agent has some context.
|
||||
Empty when outline is non-empty (no fallback needed).
|
||||
"""
|
||||
md_path = file_path.with_suffix(".md")
|
||||
if not md_path.is_file():
|
||||
return [], []
|
||||
|
||||
outline = extract_outline(md_path)
|
||||
if outline:
|
||||
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
|
||||
return outline, []
|
||||
|
||||
# outline is empty — read the first few non-empty lines as a content preview
|
||||
preview: list[str] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
preview.append(stripped)
|
||||
if len(preview) >= _OUTLINE_PREVIEW_LINES:
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
|
||||
return [], preview
|
||||
|
||||
|
||||
class UploadsMiddlewareState(AgentState):
|
||||
"""State schema for uploads middleware."""
|
||||
|
||||
@@ -39,12 +81,38 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
super().__init__()
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
|
||||
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
|
||||
"""Append a single file entry (name, size, path, optional outline) to lines."""
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
outline = file.get("outline") or []
|
||||
if outline:
|
||||
truncated = outline[-1].get("truncated", False)
|
||||
visible = [e for e in outline if not e.get("truncated")]
|
||||
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
|
||||
for entry in visible:
|
||||
lines.append(f" L{entry['line']}: {entry['title']}")
|
||||
if truncated:
|
||||
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
|
||||
else:
|
||||
preview = file.get("outline_preview") or []
|
||||
if preview:
|
||||
lines.append(" No structural headings detected. Document begins with:")
|
||||
for text in preview:
|
||||
lines.append(f" > {text}")
|
||||
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("")
|
||||
|
||||
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
|
||||
"""Create a formatted message listing uploaded files.
|
||||
|
||||
Args:
|
||||
new_files: Files uploaded in the current message.
|
||||
historical_files: Files uploaded in previous messages.
|
||||
Each file dict may contain an optional ``outline`` key — a list of
|
||||
``{title, line}`` dicts extracted from the converted Markdown file.
|
||||
|
||||
Returns:
|
||||
Formatted string inside <uploaded_files> tags.
|
||||
@@ -55,25 +123,24 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
lines.append("")
|
||||
if new_files:
|
||||
for file in new_files:
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
lines.append("")
|
||||
self._format_file_entry(file, lines)
|
||||
else:
|
||||
lines.append("(empty)")
|
||||
lines.append("")
|
||||
|
||||
if historical_files:
|
||||
lines.append("The following files were uploaded in previous messages and are still available:")
|
||||
lines.append("")
|
||||
for file in historical_files:
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
lines.append("")
|
||||
self._format_file_entry(file, lines)
|
||||
|
||||
lines.append("You can read these files using the `read_file` tool with the paths shown above.")
|
||||
lines.append("To work with these files:")
|
||||
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
|
||||
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
|
||||
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Use `glob` to find files by name pattern")
|
||||
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
|
||||
lines.append("</uploaded_files>")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -147,6 +214,13 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
|
||||
# Resolve uploads directory for existence checks
|
||||
thread_id = (runtime.context or {}).get("thread_id")
|
||||
if thread_id is None:
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
@@ -159,15 +233,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
for file_path in sorted(uploads_dir.iterdir()):
|
||||
if file_path.is_file() and file_path.name not in new_filenames:
|
||||
stat = file_path.stat()
|
||||
outline, preview = _extract_outline_for_file(file_path)
|
||||
historical_files.append(
|
||||
{
|
||||
"filename": file_path.name,
|
||||
"size": stat.st_size,
|
||||
"path": f"/mnt/user-data/uploads/{file_path.name}",
|
||||
"extension": file_path.suffix,
|
||||
"outline": outline,
|
||||
"outline_preview": preview,
|
||||
}
|
||||
)
|
||||
|
||||
# Attach outlines to new files as well
|
||||
if uploads_dir:
|
||||
for file in new_files:
|
||||
phys_path = uploads_dir / file["filename"]
|
||||
outline, preview = _extract_outline_for_file(phys_path)
|
||||
file["outline"] = outline
|
||||
file["outline_preview"] = preview
|
||||
|
||||
if not new_files and not historical_files:
|
||||
return None
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ class DeerFlowClient:
|
||||
subagent_enabled: bool = False,
|
||||
plan_mode: bool = False,
|
||||
agent_name: str | None = None,
|
||||
available_skills: set[str] | None = None,
|
||||
middlewares: Sequence[AgentMiddleware] | None = None,
|
||||
):
|
||||
"""Initialize the client.
|
||||
@@ -133,6 +134,7 @@ class DeerFlowClient:
|
||||
subagent_enabled: Enable subagent delegation.
|
||||
plan_mode: Enable TodoList middleware for plan mode.
|
||||
agent_name: Name of the agent to use.
|
||||
available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
|
||||
middlewares: Optional list of custom middlewares to inject into the agent.
|
||||
"""
|
||||
if config_path is not None:
|
||||
@@ -148,6 +150,7 @@ class DeerFlowClient:
|
||||
self._subagent_enabled = subagent_enabled
|
||||
self._plan_mode = plan_mode
|
||||
self._agent_name = agent_name
|
||||
self._available_skills = set(available_skills) if available_skills is not None else None
|
||||
self._middlewares = list(middlewares) if middlewares else []
|
||||
|
||||
# Lazy agent — created on first call, recreated when config changes.
|
||||
@@ -208,6 +211,8 @@ class DeerFlowClient:
|
||||
cfg.get("thinking_enabled"),
|
||||
cfg.get("is_plan_mode"),
|
||||
cfg.get("subagent_enabled"),
|
||||
self._agent_name,
|
||||
frozenset(self._available_skills) if self._available_skills is not None else None,
|
||||
)
|
||||
|
||||
if self._agent is not None and self._agent_config_key == key:
|
||||
@@ -226,6 +231,7 @@ class DeerFlowClient:
|
||||
subagent_enabled=subagent_enabled,
|
||||
max_concurrent_subagents=max_concurrent_subagents,
|
||||
agent_name=self._agent_name,
|
||||
available_skills=self._available_skills,
|
||||
),
|
||||
"state_schema": ThreadState,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import uuid
|
||||
from agent_sandbox import Sandbox as AioSandboxClient
|
||||
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -124,16 +125,96 @@ class AioSandbox(Sandbox):
|
||||
content: The text content to write to the file.
|
||||
append: Whether to append the content to the file.
|
||||
"""
|
||||
try:
|
||||
if append:
|
||||
# Read existing content first and append
|
||||
existing = self.read_file(path)
|
||||
if not existing.startswith("Error:"):
|
||||
content = existing + content
|
||||
self._client.file.write_file(file=path, content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write file in sandbox: {e}")
|
||||
raise
|
||||
with self._lock:
|
||||
try:
|
||||
if append:
|
||||
existing = self.read_file(path)
|
||||
if not existing.startswith("Error:"):
|
||||
content = existing + content
|
||||
self._client.file.write_file(file=path, content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write file in sandbox: {e}")
|
||||
raise
|
||||
|
||||
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
if not include_dirs:
|
||||
result = self._client.file.find_files(path=path, glob=pattern)
|
||||
files = result.data.files if result.data and result.data.files else []
|
||||
filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
|
||||
truncated = len(filtered) > max_results
|
||||
return filtered[:max_results], truncated
|
||||
|
||||
result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
|
||||
entries = result.data.files if result.data and result.data.files else []
|
||||
matches: list[str] = []
|
||||
root_path = path.rstrip("/") or "/"
|
||||
root_prefix = root_path if root_path == "/" else f"{root_path}/"
|
||||
for entry in entries:
|
||||
if entry.path != root_path and not entry.path.startswith(root_prefix):
|
||||
continue
|
||||
if should_ignore_path(entry.path):
|
||||
continue
|
||||
rel_path = entry.path[len(root_path) :].lstrip("/")
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(entry.path)
|
||||
if len(matches) >= max_results:
|
||||
return matches, True
|
||||
return matches, False
|
||||
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
import re as _re
|
||||
|
||||
regex_source = _re.escape(pattern) if literal else pattern
|
||||
# Validate the pattern locally so an invalid regex raises re.error
|
||||
# (caught by grep_tool's except re.error handler) rather than a
|
||||
# generic remote API error.
|
||||
_re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
|
||||
regex = regex_source if case_sensitive else f"(?i){regex_source}"
|
||||
|
||||
if glob is not None:
|
||||
find_result = self._client.file.find_files(path=path, glob=glob)
|
||||
candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
|
||||
else:
|
||||
list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
|
||||
entries = list_result.data.files if list_result.data and list_result.data.files else []
|
||||
candidate_paths = [entry.path for entry in entries if not entry.is_directory]
|
||||
|
||||
matches: list[GrepMatch] = []
|
||||
truncated = False
|
||||
|
||||
for file_path in candidate_paths:
|
||||
if should_ignore_path(file_path):
|
||||
continue
|
||||
|
||||
search_result = self._client.file.search_in_file(file=file_path, regex=regex)
|
||||
data = search_result.data
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
line_numbers = data.line_numbers or []
|
||||
matched_lines = data.matches or []
|
||||
for line_number, line in zip(line_numbers, matched_lines):
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=file_path,
|
||||
line_number=line_number if isinstance(line_number, int) else 0,
|
||||
line=truncate_line(line),
|
||||
)
|
||||
)
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
return matches, truncated
|
||||
|
||||
def update_file(self, path: str, content: bytes) -> None:
|
||||
"""Update a file with binary content in the sandbox.
|
||||
@@ -142,9 +223,10 @@ class AioSandbox(Sandbox):
|
||||
path: The absolute path of the file to update.
|
||||
content: The binary content to write to the file.
|
||||
"""
|
||||
try:
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file in sandbox: {e}")
|
||||
raise
|
||||
with self._lock:
|
||||
try:
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file in sandbox: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_api_key_warned = False
|
||||
|
||||
|
||||
class JinaClient:
|
||||
def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||
global _api_key_warned
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Return-Format": return_format,
|
||||
@@ -15,11 +18,13 @@ class JinaClient:
|
||||
}
|
||||
if os.getenv("JINA_API_KEY"):
|
||||
headers["Authorization"] = f"Bearer {os.getenv('JINA_API_KEY')}"
|
||||
else:
|
||||
elif not _api_key_warned:
|
||||
_api_key_warned = True
|
||||
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
||||
data = {"url": url}
|
||||
try:
|
||||
response = requests.post("https://r.jina.ai/", headers=headers, json=data)
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"Jina API returned status {response.status_code}: {response.text}"
|
||||
@@ -34,5 +39,5 @@ class JinaClient:
|
||||
return response.text
|
||||
except Exception as e:
|
||||
error_message = f"Request to Jina API failed: {str(e)}"
|
||||
logger.error(error_message)
|
||||
logger.exception(error_message)
|
||||
return f"Error: {error_message}"
|
||||
|
||||
@@ -8,7 +8,7 @@ readability_extractor = ReadabilityExtractor()
|
||||
|
||||
|
||||
@tool("web_fetch", parse_docstring=True)
|
||||
def web_fetch_tool(url: str) -> str:
|
||||
async def web_fetch_tool(url: str) -> str:
|
||||
"""Fetch the contents of a web page at a given URL.
|
||||
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
||||
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
||||
@@ -23,6 +23,8 @@ def web_fetch_tool(url: str) -> str:
|
||||
config = get_app_config().get_tool_config("web_fetch")
|
||||
if config is not None and "timeout" in config.model_extra:
|
||||
timeout = config.model_extra.get("timeout")
|
||||
html_content = jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
||||
return html_content
|
||||
article = readability_extractor.extract_article(html_content)
|
||||
return article.to_markdown()[:4096]
|
||||
|
||||
@@ -3,7 +3,13 @@ from .extensions_config import ExtensionsConfig, get_extensions_config
|
||||
from .memory_config import MemoryConfig, get_memory_config
|
||||
from .paths import Paths, get_paths
|
||||
from .skills_config import SkillsConfig
|
||||
from .tracing_config import get_tracing_config, is_tracing_enabled
|
||||
from .tracing_config import (
|
||||
get_enabled_tracing_providers,
|
||||
get_explicitly_enabled_tracing_providers,
|
||||
get_tracing_config,
|
||||
is_tracing_enabled,
|
||||
validate_enabled_tracing_providers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_app_config",
|
||||
@@ -15,5 +21,8 @@ __all__ = [
|
||||
"MemoryConfig",
|
||||
"get_memory_config",
|
||||
"get_tracing_config",
|
||||
"get_explicitly_enabled_tracing_providers",
|
||||
"get_enabled_tracing_providers",
|
||||
"is_tracing_enabled",
|
||||
"validate_enabled_tracing_providers",
|
||||
]
|
||||
|
||||
@@ -22,6 +22,11 @@ class AgentConfig(BaseModel):
|
||||
description: str = ""
|
||||
model: str | None = None
|
||||
tool_groups: list[str] | None = None
|
||||
# skills controls which skills are loaded into the agent's prompt:
|
||||
# - None (or omitted): load all enabled skills (default fallback behavior)
|
||||
# - [] (explicit empty list): disable all skills
|
||||
# - ["skill1", "skill2"]: load only the specified skills
|
||||
skills: list[str] | None = None
|
||||
|
||||
|
||||
def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
|
||||
@@ -11,16 +12,16 @@ from deerflow.config.acp_config import load_acp_config_from_dict
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.guardrails_config import load_guardrails_config_from_dict
|
||||
from deerflow.config.memory_config import load_memory_config_from_dict
|
||||
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.run_events_config import RunEventsConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.skills_config import SkillsConfig
|
||||
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
|
||||
from deerflow.config.subagents_config import load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import load_title_config_from_dict
|
||||
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
|
||||
from deerflow.config.token_usage_config import TokenUsageConfig
|
||||
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
@@ -30,6 +31,13 @@ load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_config_candidates() -> tuple[Path, ...]:
|
||||
"""Return deterministic config.yaml locations without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
return (backend_dir / "config.yaml", repo_root / "config.yaml")
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Config for the DeerFlow application"""
|
||||
|
||||
@@ -42,6 +50,11 @@ class AppConfig(BaseModel):
|
||||
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
||||
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
|
||||
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
|
||||
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
|
||||
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
|
||||
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
|
||||
@@ -55,7 +68,7 @@ class AppConfig(BaseModel):
|
||||
Priority:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory.
|
||||
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
|
||||
"""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
@@ -68,14 +81,10 @@ class AppConfig(BaseModel):
|
||||
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
# Check if the config.yaml is in the current directory
|
||||
path = Path(os.getcwd()) / "config.yaml"
|
||||
if not path.exists():
|
||||
# Check if the config.yaml is in the parent directory of CWD
|
||||
path = Path(os.getcwd()).parent / "config.yaml"
|
||||
if not path.exists():
|
||||
raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory")
|
||||
return path
|
||||
for path in _default_config_candidates():
|
||||
if path.exists():
|
||||
return path
|
||||
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str | None = None) -> Self:
|
||||
@@ -248,6 +257,8 @@ _app_config: AppConfig | None = None
|
||||
_app_config_path: Path | None = None
|
||||
_app_config_mtime: float | None = None
|
||||
_app_config_is_custom = False
|
||||
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
|
||||
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
|
||||
|
||||
|
||||
def _get_config_mtime(config_path: Path) -> float | None:
|
||||
@@ -280,6 +291,10 @@ def get_app_config() -> AppConfig:
|
||||
"""
|
||||
global _app_config, _app_config_path, _app_config_mtime
|
||||
|
||||
runtime_override = _current_app_config.get()
|
||||
if runtime_override is not None:
|
||||
return runtime_override
|
||||
|
||||
if _app_config is not None and _app_config_is_custom:
|
||||
return _app_config
|
||||
|
||||
@@ -341,3 +356,26 @@ def set_app_config(config: AppConfig) -> None:
|
||||
_app_config_path = None
|
||||
_app_config_mtime = None
|
||||
_app_config_is_custom = True
|
||||
|
||||
|
||||
def peek_current_app_config() -> AppConfig | None:
|
||||
"""Return the runtime-scoped AppConfig override, if one is active."""
|
||||
return _current_app_config.get()
|
||||
|
||||
|
||||
def push_current_app_config(config: AppConfig) -> None:
|
||||
"""Push a runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
_current_app_config_stack.set(stack + (_current_app_config.get(),))
|
||||
_current_app_config.set(config)
|
||||
|
||||
|
||||
def pop_current_app_config() -> None:
|
||||
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
if not stack:
|
||||
_current_app_config.set(None)
|
||||
return
|
||||
previous = stack[-1]
|
||||
_current_app_config_stack.set(stack[:-1])
|
||||
_current_app_config.set(previous)
|
||||
|
||||
@@ -80,6 +80,12 @@ class ExtensionsConfig(BaseModel):
|
||||
Args:
|
||||
config_path: Optional path to extensions config file.
|
||||
|
||||
Resolution order:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, search backend/repository-root defaults for
|
||||
`extensions_config.json`, then legacy `mcp_config.json`.
|
||||
|
||||
Returns:
|
||||
Path to the extensions config file if found, otherwise None.
|
||||
"""
|
||||
@@ -94,24 +100,16 @@ class ExtensionsConfig(BaseModel):
|
||||
raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
# Check if the extensions_config.json is in the current directory
|
||||
path = Path(os.getcwd()) / "extensions_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
# Check if the extensions_config.json is in the parent directory of CWD
|
||||
path = Path(os.getcwd()).parent / "extensions_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
# Backward compatibility: check for mcp_config.json
|
||||
path = Path(os.getcwd()) / "mcp_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
path = Path(os.getcwd()).parent / "mcp_config.json"
|
||||
if path.exists():
|
||||
return path
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
for path in (
|
||||
backend_dir / "extensions_config.json",
|
||||
repo_root / "extensions_config.json",
|
||||
backend_dir / "mcp_config.json",
|
||||
repo_root / "mcp_config.json",
|
||||
):
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
# Extensions are optional, so return None if not found
|
||||
return None
|
||||
|
||||
@@ -9,6 +9,12 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
|
||||
def _default_local_base_dir() -> Path:
|
||||
"""Return the repo-local DeerFlow state directory without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
return backend_dir / ".deer-flow"
|
||||
|
||||
|
||||
def _validate_thread_id(thread_id: str) -> str:
|
||||
"""Validate a thread ID before using it in filesystem paths."""
|
||||
if not _SAFE_THREAD_ID_RE.match(thread_id):
|
||||
@@ -67,8 +73,7 @@ class Paths:
|
||||
BaseDir resolution (in priority order):
|
||||
1. Constructor argument `base_dir`
|
||||
2. DEER_FLOW_HOME environment variable
|
||||
3. Local dev fallback: cwd/.deer-flow (when cwd is the backend/ dir)
|
||||
4. Default: $HOME/.deer-flow
|
||||
3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
|
||||
"""
|
||||
|
||||
def __init__(self, base_dir: str | Path | None = None) -> None:
|
||||
@@ -104,11 +109,7 @@ class Paths:
|
||||
if env_home := os.getenv("DEER_FLOW_HOME"):
|
||||
return Path(env_home).resolve()
|
||||
|
||||
cwd = Path.cwd()
|
||||
if cwd.name == "backend" or (cwd / "pyproject.toml").exists():
|
||||
return cwd / ".deer-flow"
|
||||
|
||||
return Path.home() / ".deer-flow"
|
||||
return _default_local_base_dir()
|
||||
|
||||
@property
|
||||
def memory_file(self) -> Path:
|
||||
|
||||
@@ -64,4 +64,15 @@ class SandboxConfig(BaseModel):
|
||||
description="Environment variables to inject into the sandbox container. Values starting with $ will be resolved from host environment variables.",
|
||||
)
|
||||
|
||||
bash_output_max_chars: int = Field(
|
||||
default=20000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from bash tool output. Output exceeding this limit is middle-truncated (head + tail), preserving the first and last half. Set to 0 to disable truncation.",
|
||||
)
|
||||
read_file_output_max_chars: int = Field(
|
||||
default=50000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from read_file tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -3,6 +3,11 @@ from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def _default_repo_root() -> Path:
|
||||
"""Resolve the repo root without relying on the current working directory."""
|
||||
return Path(__file__).resolve().parents[5]
|
||||
|
||||
|
||||
class SkillsConfig(BaseModel):
|
||||
"""Configuration for skills system"""
|
||||
|
||||
@@ -26,8 +31,8 @@ class SkillsConfig(BaseModel):
|
||||
# Use configured path (can be absolute or relative)
|
||||
path = Path(self.path)
|
||||
if not path.is_absolute():
|
||||
# If relative, resolve from current working directory
|
||||
path = Path.cwd() / path
|
||||
# If relative, resolve from the repo root for deterministic behavior.
|
||||
path = _default_repo_root() / path
|
||||
return path.resolve()
|
||||
else:
|
||||
# Default: ../skills relative to backend directory
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_config_lock = threading.Lock()
|
||||
|
||||
|
||||
class TracingConfig(BaseModel):
|
||||
class LangSmithTracingConfig(BaseModel):
|
||||
"""Configuration for LangSmith tracing."""
|
||||
|
||||
enabled: bool = Field(...)
|
||||
@@ -18,9 +16,69 @@ class TracingConfig(BaseModel):
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if tracing is fully configured (enabled and has API key)."""
|
||||
return self.enabled and bool(self.api_key)
|
||||
|
||||
def validate(self) -> None:
|
||||
if self.enabled and not self.api_key:
|
||||
raise ValueError("LangSmith tracing is enabled but LANGSMITH_API_KEY (or LANGCHAIN_API_KEY) is not set.")
|
||||
|
||||
|
||||
class LangfuseTracingConfig(BaseModel):
|
||||
"""Configuration for Langfuse tracing."""
|
||||
|
||||
enabled: bool = Field(...)
|
||||
public_key: str | None = Field(...)
|
||||
secret_key: str | None = Field(...)
|
||||
host: str = Field(...)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return self.enabled and bool(self.public_key) and bool(self.secret_key)
|
||||
|
||||
def validate(self) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
missing: list[str] = []
|
||||
if not self.public_key:
|
||||
missing.append("LANGFUSE_PUBLIC_KEY")
|
||||
if not self.secret_key:
|
||||
missing.append("LANGFUSE_SECRET_KEY")
|
||||
if missing:
|
||||
raise ValueError(f"Langfuse tracing is enabled but required settings are missing: {', '.join(missing)}")
|
||||
|
||||
|
||||
class TracingConfig(BaseModel):
|
||||
"""Tracing configuration for supported providers."""
|
||||
|
||||
langsmith: LangSmithTracingConfig = Field(...)
|
||||
langfuse: LangfuseTracingConfig = Field(...)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.enabled_providers)
|
||||
|
||||
@property
|
||||
def explicitly_enabled_providers(self) -> list[str]:
|
||||
enabled: list[str] = []
|
||||
if self.langsmith.enabled:
|
||||
enabled.append("langsmith")
|
||||
if self.langfuse.enabled:
|
||||
enabled.append("langfuse")
|
||||
return enabled
|
||||
|
||||
@property
|
||||
def enabled_providers(self) -> list[str]:
|
||||
enabled: list[str] = []
|
||||
if self.langsmith.is_configured:
|
||||
enabled.append("langsmith")
|
||||
if self.langfuse.is_configured:
|
||||
enabled.append("langfuse")
|
||||
return enabled
|
||||
|
||||
def validate_enabled(self) -> None:
|
||||
self.langsmith.validate()
|
||||
self.langfuse.validate()
|
||||
|
||||
|
||||
_tracing_config: TracingConfig | None = None
|
||||
|
||||
@@ -29,12 +87,7 @@ _TRUTHY_VALUES = {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _env_flag_preferred(*names: str) -> bool:
|
||||
"""Return the boolean value of the first env var that is present and non-empty.
|
||||
|
||||
Accepted truthy values (case-insensitive): ``1``, ``true``, ``yes``, ``on``.
|
||||
Any other non-empty value is treated as falsy. If none of the named
|
||||
variables is set, returns ``False``.
|
||||
"""
|
||||
"""Return the boolean value of the first env var that is present and non-empty."""
|
||||
for name in names:
|
||||
value = os.environ.get(name)
|
||||
if value is not None and value.strip():
|
||||
@@ -52,43 +105,45 @@ def _first_env_value(*names: str) -> str | None:
|
||||
|
||||
|
||||
def get_tracing_config() -> TracingConfig:
|
||||
"""Get the current tracing configuration from environment variables.
|
||||
|
||||
``LANGSMITH_*`` variables take precedence over their legacy ``LANGCHAIN_*``
|
||||
counterparts. For boolean flags (``enabled``), the *first* variable that is
|
||||
present and non-empty in the priority list is the sole authority – its value
|
||||
is parsed and returned without consulting the remaining candidates. Accepted
|
||||
truthy values are ``1``, ``true``, ``yes``, and ``on`` (case-insensitive);
|
||||
any other non-empty value is treated as falsy.
|
||||
|
||||
Priority order:
|
||||
enabled : LANGSMITH_TRACING > LANGCHAIN_TRACING_V2 > LANGCHAIN_TRACING
|
||||
api_key : LANGSMITH_API_KEY > LANGCHAIN_API_KEY
|
||||
project : LANGSMITH_PROJECT > LANGCHAIN_PROJECT (default: "deer-flow")
|
||||
endpoint : LANGSMITH_ENDPOINT > LANGCHAIN_ENDPOINT (default: https://api.smith.langchain.com)
|
||||
|
||||
Returns:
|
||||
TracingConfig with current settings.
|
||||
"""
|
||||
"""Get the current tracing configuration from environment variables."""
|
||||
global _tracing_config
|
||||
if _tracing_config is not None:
|
||||
return _tracing_config
|
||||
with _config_lock:
|
||||
if _tracing_config is not None: # Double-check after acquiring lock
|
||||
if _tracing_config is not None:
|
||||
return _tracing_config
|
||||
_tracing_config = TracingConfig(
|
||||
# Keep compatibility with both legacy LANGCHAIN_* and newer LANGSMITH_* variables.
|
||||
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
|
||||
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
|
||||
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
|
||||
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
|
||||
langsmith=LangSmithTracingConfig(
|
||||
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
|
||||
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
|
||||
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
|
||||
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
|
||||
),
|
||||
langfuse=LangfuseTracingConfig(
|
||||
enabled=_env_flag_preferred("LANGFUSE_TRACING"),
|
||||
public_key=_first_env_value("LANGFUSE_PUBLIC_KEY"),
|
||||
secret_key=_first_env_value("LANGFUSE_SECRET_KEY"),
|
||||
host=_first_env_value("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com",
|
||||
),
|
||||
)
|
||||
return _tracing_config
|
||||
|
||||
|
||||
def get_enabled_tracing_providers() -> list[str]:
|
||||
"""Return the configured tracing providers that are enabled and complete."""
|
||||
return get_tracing_config().enabled_providers
|
||||
|
||||
|
||||
def get_explicitly_enabled_tracing_providers() -> list[str]:
|
||||
"""Return tracing providers explicitly enabled by config, even if incomplete."""
|
||||
return get_tracing_config().explicitly_enabled_providers
|
||||
|
||||
|
||||
def validate_enabled_tracing_providers() -> None:
|
||||
"""Validate that any explicitly enabled providers are fully configured."""
|
||||
get_tracing_config().validate_enabled()
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
"""Check if LangSmith tracing is enabled and configured.
|
||||
Returns:
|
||||
True if tracing is enabled and has an API key.
|
||||
"""
|
||||
"""Check if any tracing provider is enabled and fully configured."""
|
||||
return get_tracing_config().is_configured
|
||||
|
||||
@@ -2,8 +2,9 @@ import logging
|
||||
|
||||
from langchain.chat_models import BaseChatModel
|
||||
|
||||
from deerflow.config import get_app_config, get_tracing_config, is_tracing_enabled
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.reflection import resolve_class
|
||||
from deerflow.tracing import build_tracing_callbacks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -88,17 +89,9 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
|
||||
model_instance = model_class(**kwargs, **model_settings_from_config)
|
||||
|
||||
if is_tracing_enabled():
|
||||
try:
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
tracing_config = get_tracing_config()
|
||||
tracer = LangChainTracer(
|
||||
project_name=tracing_config.project,
|
||||
)
|
||||
existing_callbacks = model_instance.callbacks or []
|
||||
model_instance.callbacks = [*existing_callbacks, tracer]
|
||||
logger.debug(f"LangSmith tracing attached to model '{name}' (project='{tracing_config.project}')")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to attach LangSmith tracing to model '{name}': {e}")
|
||||
callbacks = build_tracing_callbacks()
|
||||
if callbacks:
|
||||
existing_callbacks = model_instance.callbacks or []
|
||||
model_instance.callbacks = [*existing_callbacks, *callbacks]
|
||||
logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}")
|
||||
return model_instance
|
||||
|
||||
@@ -123,6 +123,11 @@ async def run_agent(
|
||||
# Inject runtime context so middlewares can access thread_id
|
||||
# (langgraph-cli does this automatically; we must do it manually)
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||
# prefers it over ``configurable`` for thread-level data), make
|
||||
# sure ``thread_id`` is available there too.
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
# Inject RunJournal as a LangChain callback handler.
|
||||
|
||||
@@ -25,6 +25,7 @@ class MemoryStreamBridge(StreamBridge):
|
||||
self._maxsize = queue_maxsize
|
||||
self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
|
||||
self._counters: dict[str, int] = {}
|
||||
self._dropped_counts: dict[str, int] = {}
|
||||
|
||||
# -- helpers ---------------------------------------------------------------
|
||||
|
||||
@@ -32,6 +33,7 @@ class MemoryStreamBridge(StreamBridge):
|
||||
if run_id not in self._queues:
|
||||
self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
|
||||
self._counters[run_id] = 0
|
||||
self._dropped_counts[run_id] = 0
|
||||
return self._queues[run_id]
|
||||
|
||||
def _next_id(self, run_id: str) -> str:
|
||||
@@ -48,14 +50,41 @@ class MemoryStreamBridge(StreamBridge):
|
||||
try:
|
||||
await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
|
||||
except TimeoutError:
|
||||
logger.warning("Stream bridge queue full for run %s — dropping event %s", run_id, event)
|
||||
self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
|
||||
logger.warning(
|
||||
"Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
|
||||
run_id,
|
||||
event,
|
||||
self._dropped_counts[run_id],
|
||||
)
|
||||
|
||||
async def publish_end(self, run_id: str) -> None:
|
||||
queue = self._get_or_create_queue(run_id)
|
||||
try:
|
||||
await asyncio.wait_for(queue.put(END_SENTINEL), timeout=_PUBLISH_TIMEOUT)
|
||||
except TimeoutError:
|
||||
logger.warning("Stream bridge queue full for run %s — dropping END sentinel", run_id)
|
||||
|
||||
# END sentinel is critical — it is the only signal that allows
|
||||
# subscribers to terminate. If the queue is full we evict the
|
||||
# oldest *regular* events to make room rather than dropping END,
|
||||
# which would cause the SSE connection to hang forever and leak
|
||||
# the queue/counter resources for this run_id.
|
||||
if queue.full():
|
||||
evicted = 0
|
||||
while queue.full():
|
||||
try:
|
||||
queue.get_nowait()
|
||||
evicted += 1
|
||||
except asyncio.QueueEmpty:
|
||||
break # pragma: no cover – defensive
|
||||
if evicted:
|
||||
logger.warning(
|
||||
"Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
|
||||
run_id,
|
||||
evicted,
|
||||
)
|
||||
|
||||
# After eviction the queue is guaranteed to have space, so a
|
||||
# simple non-blocking put is safe. We still use put() (which
|
||||
# blocks until space is available) as a defensive measure.
|
||||
await queue.put(END_SENTINEL)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
@@ -84,7 +113,18 @@ class MemoryStreamBridge(StreamBridge):
|
||||
await asyncio.sleep(delay)
|
||||
self._queues.pop(run_id, None)
|
||||
self._counters.pop(run_id, None)
|
||||
self._dropped_counts.pop(run_id, None)
|
||||
|
||||
async def close(self) -> None:
|
||||
self._queues.clear()
|
||||
self._counters.clear()
|
||||
self._dropped_counts.clear()
|
||||
|
||||
def dropped_count(self, run_id: str) -> int:
|
||||
"""Return the number of events dropped for *run_id*."""
|
||||
return self._dropped_counts.get(run_id, 0)
|
||||
|
||||
@property
|
||||
def dropped_total(self) -> int:
|
||||
"""Return the total number of events dropped across all runs."""
|
||||
return sum(self._dropped_counts.values())
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import threading
|
||||
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
|
||||
_FILE_OPERATION_LOCKS: dict[tuple[str, str], threading.Lock] = {}
|
||||
_FILE_OPERATION_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def get_file_operation_lock_key(sandbox: Sandbox, path: str) -> tuple[str, str]:
|
||||
sandbox_id = getattr(sandbox, "id", None)
|
||||
if not sandbox_id:
|
||||
sandbox_id = f"instance:{id(sandbox)}"
|
||||
return sandbox_id, path
|
||||
|
||||
|
||||
def get_file_operation_lock(sandbox: Sandbox, path: str) -> threading.Lock:
|
||||
lock_key = get_file_operation_lock_key(sandbox, path)
|
||||
with _FILE_OPERATION_LOCKS_GUARD:
|
||||
lock = _FILE_OPERATION_LOCKS.get(lock_key)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
_FILE_OPERATION_LOCKS[lock_key] = lock
|
||||
return lock
|
||||
@@ -1,72 +1,6 @@
|
||||
import fnmatch
|
||||
from pathlib import Path
|
||||
|
||||
IGNORE_PATTERNS = [
|
||||
# Version Control
|
||||
".git",
|
||||
".svn",
|
||||
".hg",
|
||||
".bzr",
|
||||
# Dependencies
|
||||
"node_modules",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".env",
|
||||
"env",
|
||||
".tox",
|
||||
".nox",
|
||||
".eggs",
|
||||
"*.egg-info",
|
||||
"site-packages",
|
||||
# Build outputs
|
||||
"dist",
|
||||
"build",
|
||||
".next",
|
||||
".nuxt",
|
||||
".output",
|
||||
".turbo",
|
||||
"target",
|
||||
"out",
|
||||
# IDE & Editor
|
||||
".idea",
|
||||
".vscode",
|
||||
"*.swp",
|
||||
"*.swo",
|
||||
"*~",
|
||||
".project",
|
||||
".classpath",
|
||||
".settings",
|
||||
# OS generated
|
||||
".DS_Store",
|
||||
"Thumbs.db",
|
||||
"desktop.ini",
|
||||
"*.lnk",
|
||||
# Logs & temp files
|
||||
"*.log",
|
||||
"*.tmp",
|
||||
"*.temp",
|
||||
"*.bak",
|
||||
"*.cache",
|
||||
".cache",
|
||||
"logs",
|
||||
# Coverage & test artifacts
|
||||
".coverage",
|
||||
"coverage",
|
||||
".nyc_output",
|
||||
"htmlcov",
|
||||
".pytest_cache",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
]
|
||||
|
||||
|
||||
def _should_ignore(name: str) -> bool:
|
||||
"""Check if a file/directory name matches any ignore pattern."""
|
||||
for pattern in IGNORE_PATTERNS:
|
||||
if fnmatch.fnmatch(name, pattern):
|
||||
return True
|
||||
return False
|
||||
from deerflow.sandbox.search import should_ignore_name
|
||||
|
||||
|
||||
def list_dir(path: str, max_depth: int = 2) -> list[str]:
|
||||
@@ -95,7 +29,7 @@ def list_dir(path: str, max_depth: int = 2) -> list[str]:
|
||||
|
||||
try:
|
||||
for item in current_path.iterdir():
|
||||
if _should_ignore(item.name):
|
||||
if should_ignore_name(item.name):
|
||||
continue
|
||||
|
||||
post_fix = "/" if item.is_dir() else ""
|
||||
|
||||
@@ -1,11 +1,23 @@
|
||||
import errno
|
||||
import ntpath
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.sandbox.local.list_dir import list_dir
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PathMapping:
|
||||
"""A path mapping from a container path to a local path with optional read-only flag."""
|
||||
|
||||
container_path: str
|
||||
local_path: str
|
||||
read_only: bool = False
|
||||
|
||||
|
||||
class LocalSandbox(Sandbox):
|
||||
@@ -39,17 +51,42 @@ class LocalSandbox(Sandbox):
|
||||
|
||||
return None
|
||||
|
||||
def __init__(self, id: str, path_mappings: dict[str, str] | None = None):
|
||||
def __init__(self, id: str, path_mappings: list[PathMapping] | None = None):
|
||||
"""
|
||||
Initialize local sandbox with optional path mappings.
|
||||
|
||||
Args:
|
||||
id: Sandbox identifier
|
||||
path_mappings: Dictionary mapping container paths to local paths
|
||||
Example: {"/mnt/skills": "/absolute/path/to/skills"}
|
||||
path_mappings: List of path mappings with optional read-only flag.
|
||||
Skills directory is read-only by default.
|
||||
"""
|
||||
super().__init__(id)
|
||||
self.path_mappings = path_mappings or {}
|
||||
self.path_mappings = path_mappings or []
|
||||
|
||||
def _is_read_only_path(self, resolved_path: str) -> bool:
|
||||
"""Check if a resolved path is under a read-only mount.
|
||||
|
||||
When multiple mappings match (nested mounts), prefer the most specific
|
||||
mapping (i.e. the one whose local_path is the longest prefix of the
|
||||
resolved path), similar to how ``_resolve_path`` handles container paths.
|
||||
"""
|
||||
resolved = str(Path(resolved_path).resolve())
|
||||
|
||||
best_mapping: PathMapping | None = None
|
||||
best_prefix_len = -1
|
||||
|
||||
for mapping in self.path_mappings:
|
||||
local_resolved = str(Path(mapping.local_path).resolve())
|
||||
if resolved == local_resolved or resolved.startswith(local_resolved + os.sep):
|
||||
prefix_len = len(local_resolved)
|
||||
if prefix_len > best_prefix_len:
|
||||
best_prefix_len = prefix_len
|
||||
best_mapping = mapping
|
||||
|
||||
if best_mapping is None:
|
||||
return False
|
||||
|
||||
return best_mapping.read_only
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
@@ -64,7 +101,9 @@ class LocalSandbox(Sandbox):
|
||||
path_str = str(path)
|
||||
|
||||
# Try each mapping (longest prefix first for more specific matches)
|
||||
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True):
|
||||
for mapping in sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True):
|
||||
container_path = mapping.container_path
|
||||
local_path = mapping.local_path
|
||||
if path_str == container_path or path_str.startswith(container_path + "/"):
|
||||
# Replace the container path prefix with local path
|
||||
relative = path_str[len(container_path) :].lstrip("/")
|
||||
@@ -84,15 +123,16 @@ class LocalSandbox(Sandbox):
|
||||
Returns:
|
||||
Container path if mapping exists, otherwise original path
|
||||
"""
|
||||
path_str = str(Path(path).resolve())
|
||||
normalized_path = path.replace("\\", "/")
|
||||
path_str = str(Path(normalized_path).resolve())
|
||||
|
||||
# Try each mapping (longest local path first for more specific matches)
|
||||
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True):
|
||||
local_path_resolved = str(Path(local_path).resolve())
|
||||
if path_str.startswith(local_path_resolved):
|
||||
for mapping in sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True):
|
||||
local_path_resolved = str(Path(mapping.local_path).resolve())
|
||||
if path_str == local_path_resolved or path_str.startswith(local_path_resolved + "/"):
|
||||
# Replace the local path prefix with container path
|
||||
relative = path_str[len(local_path_resolved) :].lstrip("/")
|
||||
resolved = f"{container_path}/{relative}" if relative else container_path
|
||||
resolved = f"{mapping.container_path}/{relative}" if relative else mapping.container_path
|
||||
return resolved
|
||||
|
||||
# No mapping found, return original path
|
||||
@@ -111,7 +151,7 @@ class LocalSandbox(Sandbox):
|
||||
import re
|
||||
|
||||
# Sort mappings by local path length (longest first) for correct prefix matching
|
||||
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True)
|
||||
|
||||
if not sorted_mappings:
|
||||
return output
|
||||
@@ -119,12 +159,11 @@ class LocalSandbox(Sandbox):
|
||||
# Create pattern that matches absolute paths
|
||||
# Match paths like /Users/... or other absolute paths
|
||||
result = output
|
||||
for container_path, local_path in sorted_mappings:
|
||||
local_path_resolved = str(Path(local_path).resolve())
|
||||
for mapping in sorted_mappings:
|
||||
# Escape the local path for use in regex
|
||||
escaped_local = re.escape(local_path_resolved)
|
||||
# Match the local path followed by optional path components
|
||||
pattern = re.compile(escaped_local + r"(?:/[^\s\"';&|<>()]*)?")
|
||||
escaped_local = re.escape(str(Path(mapping.local_path).resolve()))
|
||||
# Match the local path followed by optional path components with either separator
|
||||
pattern = re.compile(escaped_local + r"(?:[/\\][^\s\"';&|<>()]*)?")
|
||||
|
||||
def replace_match(match: re.Match) -> str:
|
||||
matched_path = match.group(0)
|
||||
@@ -147,7 +186,7 @@ class LocalSandbox(Sandbox):
|
||||
import re
|
||||
|
||||
# Sort mappings by length (longest first) for correct prefix matching
|
||||
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True)
|
||||
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True)
|
||||
|
||||
# Build regex pattern to match all container paths
|
||||
# Match container path followed by optional path components
|
||||
@@ -157,7 +196,7 @@ class LocalSandbox(Sandbox):
|
||||
# Create pattern that matches any of the container paths.
|
||||
# The lookahead (?=/|$|...) ensures we only match at a path-segment boundary,
|
||||
# preventing /mnt/skills from matching inside /mnt/skills-extra.
|
||||
patterns = [re.escape(container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for container_path, _ in sorted_mappings]
|
||||
patterns = [re.escape(m.container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for m in sorted_mappings]
|
||||
pattern = re.compile("|".join(f"({p})" for p in patterns))
|
||||
|
||||
def replace_match(match: re.Match) -> str:
|
||||
@@ -248,6 +287,8 @@ class LocalSandbox(Sandbox):
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
resolved_path = self._resolve_path(path)
|
||||
if self._is_read_only_path(resolved_path):
|
||||
raise OSError(errno.EROFS, "Read-only file system", path)
|
||||
try:
|
||||
dir_path = os.path.dirname(resolved_path)
|
||||
if dir_path:
|
||||
@@ -259,8 +300,43 @@ class LocalSandbox(Sandbox):
|
||||
# Re-raise with the original path for clearer error messages, hiding internal resolved paths
|
||||
raise type(e)(e.errno, e.strerror, path) from None
|
||||
|
||||
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
resolved_path = Path(self._resolve_path(path))
|
||||
matches, truncated = find_glob_matches(resolved_path, pattern, include_dirs=include_dirs, max_results=max_results)
|
||||
return [self._reverse_resolve_path(match) for match in matches], truncated
|
||||
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
resolved_path = Path(self._resolve_path(path))
|
||||
matches, truncated = find_grep_matches(
|
||||
resolved_path,
|
||||
pattern,
|
||||
glob_pattern=glob,
|
||||
literal=literal,
|
||||
case_sensitive=case_sensitive,
|
||||
max_results=max_results,
|
||||
)
|
||||
return [
|
||||
GrepMatch(
|
||||
path=self._reverse_resolve_path(match.path),
|
||||
line_number=match.line_number,
|
||||
line=match.line,
|
||||
)
|
||||
for match in matches
|
||||
], truncated
|
||||
|
||||
def update_file(self, path: str, content: bytes) -> None:
|
||||
resolved_path = self._resolve_path(path)
|
||||
if self._is_read_only_path(resolved_path):
|
||||
raise OSError(errno.EROFS, "Read-only file system", path)
|
||||
try:
|
||||
dir_path = os.path.dirname(resolved_path)
|
||||
if dir_path:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.sandbox.local.local_sandbox import LocalSandbox
|
||||
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
@@ -14,16 +15,17 @@ class LocalSandboxProvider(SandboxProvider):
|
||||
"""Initialize the local sandbox provider with path mappings."""
|
||||
self._path_mappings = self._setup_path_mappings()
|
||||
|
||||
def _setup_path_mappings(self) -> dict[str, str]:
|
||||
def _setup_path_mappings(self) -> list[PathMapping]:
|
||||
"""
|
||||
Setup path mappings for local sandbox.
|
||||
|
||||
Maps container paths to actual local paths, including skills directory.
|
||||
Maps container paths to actual local paths, including skills directory
|
||||
and any custom mounts configured in config.yaml.
|
||||
|
||||
Returns:
|
||||
Dictionary of path mappings
|
||||
List of path mappings
|
||||
"""
|
||||
mappings = {}
|
||||
mappings: list[PathMapping] = []
|
||||
|
||||
# Map skills container path to local skills directory
|
||||
try:
|
||||
@@ -35,10 +37,63 @@ class LocalSandboxProvider(SandboxProvider):
|
||||
|
||||
# Only add mapping if skills directory exists
|
||||
if skills_path.exists():
|
||||
mappings[container_path] = str(skills_path)
|
||||
mappings.append(
|
||||
PathMapping(
|
||||
container_path=container_path,
|
||||
local_path=str(skills_path),
|
||||
read_only=True, # Skills directory is always read-only
|
||||
)
|
||||
)
|
||||
|
||||
# Map custom mounts from sandbox config
|
||||
_RESERVED_CONTAINER_PREFIXES = [container_path, "/mnt/acp-workspace", "/mnt/user-data"]
|
||||
sandbox_config = config.sandbox
|
||||
if sandbox_config and sandbox_config.mounts:
|
||||
for mount in sandbox_config.mounts:
|
||||
host_path = Path(mount.host_path)
|
||||
container_path = mount.container_path.rstrip("/") or "/"
|
||||
|
||||
if not host_path.is_absolute():
|
||||
logger.warning(
|
||||
"Mount host_path must be absolute, skipping: %s -> %s",
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
)
|
||||
continue
|
||||
|
||||
if not container_path.startswith("/"):
|
||||
logger.warning(
|
||||
"Mount container_path must be absolute, skipping: %s -> %s",
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
)
|
||||
continue
|
||||
|
||||
# Reject mounts that conflict with reserved container paths
|
||||
if any(container_path == p or container_path.startswith(p + "/") for p in _RESERVED_CONTAINER_PREFIXES):
|
||||
logger.warning(
|
||||
"Mount container_path conflicts with reserved prefix, skipping: %s",
|
||||
mount.container_path,
|
||||
)
|
||||
continue
|
||||
# Ensure the host path exists before adding mapping
|
||||
if host_path.exists():
|
||||
mappings.append(
|
||||
PathMapping(
|
||||
container_path=container_path,
|
||||
local_path=str(host_path.resolve()),
|
||||
read_only=mount.read_only,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Mount host_path does not exist, skipping: %s -> %s",
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail if config loading fails
|
||||
logger.warning("Could not setup skills path mapping: %s", e, exc_info=True)
|
||||
logger.warning("Could not setup path mappings: %s", e, exc_info=True)
|
||||
|
||||
return mappings
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
|
||||
|
||||
class Sandbox(ABC):
|
||||
"""Abstract base class for sandbox environments"""
|
||||
@@ -61,6 +63,25 @@ class Sandbox(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
"""Find paths that match a glob pattern under a root directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
"""Search for matches inside text files under a directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_file(self, path: str, content: bytes) -> None:
|
||||
"""Update a file with binary content.
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
import fnmatch
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
IGNORE_PATTERNS = [
|
||||
".git",
|
||||
".svn",
|
||||
".hg",
|
||||
".bzr",
|
||||
"node_modules",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".env",
|
||||
"env",
|
||||
".tox",
|
||||
".nox",
|
||||
".eggs",
|
||||
"*.egg-info",
|
||||
"site-packages",
|
||||
"dist",
|
||||
"build",
|
||||
".next",
|
||||
".nuxt",
|
||||
".output",
|
||||
".turbo",
|
||||
"target",
|
||||
"out",
|
||||
".idea",
|
||||
".vscode",
|
||||
"*.swp",
|
||||
"*.swo",
|
||||
"*~",
|
||||
".project",
|
||||
".classpath",
|
||||
".settings",
|
||||
".DS_Store",
|
||||
"Thumbs.db",
|
||||
"desktop.ini",
|
||||
"*.lnk",
|
||||
"*.log",
|
||||
"*.tmp",
|
||||
"*.temp",
|
||||
"*.bak",
|
||||
"*.cache",
|
||||
".cache",
|
||||
"logs",
|
||||
".coverage",
|
||||
"coverage",
|
||||
".nyc_output",
|
||||
"htmlcov",
|
||||
".pytest_cache",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
]
|
||||
|
||||
DEFAULT_MAX_FILE_SIZE_BYTES = 1_000_000
|
||||
DEFAULT_LINE_SUMMARY_LENGTH = 200
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GrepMatch:
|
||||
path: str
|
||||
line_number: int
|
||||
line: str
|
||||
|
||||
|
||||
def should_ignore_name(name: str) -> bool:
|
||||
for pattern in IGNORE_PATTERNS:
|
||||
if fnmatch.fnmatch(name, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def should_ignore_path(path: str) -> bool:
|
||||
return any(should_ignore_name(segment) for segment in path.replace("\\", "/").split("/") if segment)
|
||||
|
||||
|
||||
def path_matches(pattern: str, rel_path: str) -> bool:
|
||||
path = PurePosixPath(rel_path)
|
||||
if path.match(pattern):
|
||||
return True
|
||||
if pattern.startswith("**/"):
|
||||
return path.match(pattern[3:])
|
||||
return False
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = DEFAULT_LINE_SUMMARY_LENGTH) -> str:
|
||||
line = line.rstrip("\n\r")
|
||||
if len(line) <= max_chars:
|
||||
return line
|
||||
return line[: max_chars - 3] + "..."
|
||||
|
||||
|
||||
def is_binary_file(path: Path, sample_size: int = 8192) -> bool:
|
||||
try:
|
||||
with path.open("rb") as handle:
|
||||
return b"\0" in handle.read(sample_size)
|
||||
except OSError:
|
||||
return True
|
||||
|
||||
|
||||
def find_glob_matches(root: Path, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
matches: list[str] = []
|
||||
truncated = False
|
||||
root = root.resolve()
|
||||
|
||||
if not root.exists():
|
||||
raise FileNotFoundError(root)
|
||||
if not root.is_dir():
|
||||
raise NotADirectoryError(root)
|
||||
|
||||
for current_root, dirs, files in os.walk(root):
|
||||
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
|
||||
# root is already resolved; os.walk builds current_root by joining under root,
|
||||
# so relative_to() works without an extra stat()/resolve() per directory.
|
||||
rel_dir = Path(current_root).relative_to(root)
|
||||
|
||||
if include_dirs:
|
||||
for name in dirs:
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(str(Path(current_root) / name))
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
for name in files:
|
||||
if should_ignore_name(name):
|
||||
continue
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(str(Path(current_root) / name))
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
return matches, truncated
|
||||
|
||||
|
||||
def find_grep_matches(
|
||||
root: Path,
|
||||
pattern: str,
|
||||
*,
|
||||
glob_pattern: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
max_file_size: int = DEFAULT_MAX_FILE_SIZE_BYTES,
|
||||
line_summary_length: int = DEFAULT_LINE_SUMMARY_LENGTH,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
matches: list[GrepMatch] = []
|
||||
truncated = False
|
||||
root = root.resolve()
|
||||
|
||||
if not root.exists():
|
||||
raise FileNotFoundError(root)
|
||||
if not root.is_dir():
|
||||
raise NotADirectoryError(root)
|
||||
|
||||
regex_source = re.escape(pattern) if literal else pattern
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
regex = re.compile(regex_source, flags)
|
||||
|
||||
# Skip lines longer than this to prevent ReDoS on minified / no-newline files.
|
||||
_max_line_chars = line_summary_length * 10
|
||||
|
||||
for current_root, dirs, files in os.walk(root):
|
||||
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
|
||||
rel_dir = Path(current_root).relative_to(root)
|
||||
|
||||
for name in files:
|
||||
if should_ignore_name(name):
|
||||
continue
|
||||
|
||||
candidate_path = Path(current_root) / name
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
|
||||
if glob_pattern is not None and not path_matches(glob_pattern, rel_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
if candidate_path.is_symlink():
|
||||
continue
|
||||
file_path = candidate_path.resolve()
|
||||
if not file_path.is_relative_to(root):
|
||||
continue
|
||||
if file_path.stat().st_size > max_file_size or is_binary_file(file_path):
|
||||
continue
|
||||
with file_path.open(encoding="utf-8", errors="replace") as handle:
|
||||
for line_number, line in enumerate(handle, start=1):
|
||||
if len(line) > _max_line_chars:
|
||||
continue
|
||||
if regex.search(line):
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=str(file_path),
|
||||
line_number=line_number,
|
||||
line=truncate_line(line, line_summary_length),
|
||||
)
|
||||
)
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
return matches, truncated
|
||||
@@ -7,17 +7,21 @@ from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.sandbox.exceptions import (
|
||||
SandboxError,
|
||||
SandboxNotFoundError,
|
||||
SandboxRuntimeError,
|
||||
)
|
||||
from deerflow.sandbox.file_operation_lock import get_file_operation_lock
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
|
||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])/(?:[^\s\"'`;&|<>()]+)")
|
||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
|
||||
_FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE)
|
||||
_LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
|
||||
"/bin/",
|
||||
"/usr/bin/",
|
||||
@@ -29,6 +33,10 @@ _LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
|
||||
|
||||
_DEFAULT_SKILLS_CONTAINER_PATH = "/mnt/skills"
|
||||
_ACP_WORKSPACE_VIRTUAL_PATH = "/mnt/acp-workspace"
|
||||
_DEFAULT_GLOB_MAX_RESULTS = 200
|
||||
_MAX_GLOB_MAX_RESULTS = 1000
|
||||
_DEFAULT_GREP_MAX_RESULTS = 100
|
||||
_MAX_GREP_MAX_RESULTS = 500
|
||||
|
||||
|
||||
def _get_skills_container_path() -> str:
|
||||
@@ -111,6 +119,54 @@ def _is_acp_workspace_path(path: str) -> bool:
|
||||
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
|
||||
|
||||
|
||||
def _get_custom_mounts():
|
||||
"""Get custom volume mounts from sandbox config.
|
||||
|
||||
Result is cached after the first successful config load. If config loading
|
||||
fails an empty list is returned *without* caching so that a later call can
|
||||
pick up the real value once the config is available.
|
||||
"""
|
||||
cached = getattr(_get_custom_mounts, "_cached", None)
|
||||
if cached is not None:
|
||||
return cached
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = get_app_config()
|
||||
mounts = []
|
||||
if config.sandbox and config.sandbox.mounts:
|
||||
# Only include mounts whose host_path exists, consistent with
|
||||
# LocalSandboxProvider._setup_path_mappings() which also filters
|
||||
# by host_path.exists().
|
||||
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
|
||||
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
|
||||
return mounts
|
||||
except Exception:
|
||||
# If config loading fails, return an empty list without caching so that
|
||||
# a later call can retry once the config is available.
|
||||
return []
|
||||
|
||||
|
||||
def _is_custom_mount_path(path: str) -> bool:
|
||||
"""Check if path is under a custom mount container_path."""
|
||||
for mount in _get_custom_mounts():
|
||||
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_custom_mount_for_path(path: str):
|
||||
"""Get the mount config matching this path (longest prefix first)."""
|
||||
best = None
|
||||
for mount in _get_custom_mounts():
|
||||
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
|
||||
if best is None or len(mount.container_path) > len(best.container_path):
|
||||
best = mount
|
||||
return best
|
||||
|
||||
|
||||
def _extract_thread_id_from_thread_data(thread_data: "ThreadDataState | None") -> str | None:
|
||||
"""Extract thread_id from thread_data by inspecting workspace_path.
|
||||
|
||||
@@ -243,6 +299,69 @@ def _get_mcp_allowed_paths() -> list[str]:
|
||||
return allowed_paths
|
||||
|
||||
|
||||
def _get_tool_config_int(name: str, key: str, default: int) -> int:
|
||||
try:
|
||||
tool_config = get_app_config().get_tool_config(name)
|
||||
if tool_config is not None and key in tool_config.model_extra:
|
||||
value = tool_config.model_extra.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
except Exception:
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
|
||||
if value <= 0:
|
||||
return default
|
||||
return min(value, upper_bound)
|
||||
|
||||
|
||||
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
|
||||
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
|
||||
configured_max_results = _clamp_max_results(
|
||||
_get_tool_config_int(name, "max_results", default),
|
||||
default=default,
|
||||
upper_bound=upper_bound,
|
||||
)
|
||||
return min(requested_max_results, configured_max_results)
|
||||
|
||||
|
||||
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
|
||||
validate_local_tool_path(path, thread_data, read_only=True)
|
||||
if _is_skills_path(path):
|
||||
return _resolve_skills_path(path)
|
||||
if _is_acp_workspace_path(path):
|
||||
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
return _resolve_and_validate_user_data_path(path, thread_data)
|
||||
|
||||
|
||||
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
|
||||
if not matches:
|
||||
return f"No files matched under {root_path}"
|
||||
|
||||
lines = [f"Found {len(matches)} paths under {root_path}"]
|
||||
if truncated:
|
||||
lines[0] += f" (showing first {len(matches)})"
|
||||
lines.extend(f"{index}. {path}" for index, path in enumerate(matches, start=1))
|
||||
if truncated:
|
||||
lines.append("Results truncated. Narrow the path or pattern to see fewer matches.")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_grep_results(root_path: str, matches: list[GrepMatch], truncated: bool) -> str:
|
||||
if not matches:
|
||||
return f"No matches found under {root_path}"
|
||||
|
||||
lines = [f"Found {len(matches)} matches under {root_path}"]
|
||||
if truncated:
|
||||
lines[0] += f" (showing first {len(matches)})"
|
||||
lines.extend(f"{match.path}:{match.line_number}: {match.line}" for match in matches)
|
||||
if truncated:
|
||||
lines.append("Results truncated. Narrow the path or add a glob filter.")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _path_variants(path: str) -> set[str]:
|
||||
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
|
||||
|
||||
@@ -377,6 +496,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
|
||||
|
||||
result = pattern.sub(replace_acp, result)
|
||||
|
||||
# Custom mount host paths are masked by LocalSandbox._reverse_resolve_paths_in_output()
|
||||
|
||||
# Mask user-data host paths
|
||||
if thread_data is None:
|
||||
return result
|
||||
@@ -425,6 +546,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
|
||||
- ``/mnt/user-data/*`` — always allowed (read + write)
|
||||
- ``/mnt/skills/*`` — allowed only when *read_only* is True
|
||||
- ``/mnt/acp-workspace/*`` — allowed only when *read_only* is True
|
||||
- Custom mount paths (from config.yaml) — respects per-mount ``read_only`` flag
|
||||
|
||||
Args:
|
||||
path: The virtual path to validate.
|
||||
@@ -456,7 +578,14 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
|
||||
if path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
|
||||
return
|
||||
|
||||
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, or {_ACP_WORKSPACE_VIRTUAL_PATH}/ are allowed")
|
||||
# Custom mount paths — respect read_only config
|
||||
if _is_custom_mount_path(path):
|
||||
mount = _get_custom_mount_for_path(path)
|
||||
if mount and mount.read_only and not read_only:
|
||||
raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
|
||||
return
|
||||
|
||||
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
|
||||
|
||||
|
||||
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
|
||||
@@ -506,15 +635,21 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
|
||||
boundary and must not be treated as isolation from the host filesystem.
|
||||
|
||||
In local mode, commands must use virtual paths under /mnt/user-data for
|
||||
user data access. Skills paths under /mnt/skills and ACP workspace paths
|
||||
under /mnt/acp-workspace are allowed (path-traversal checks only; write
|
||||
prevention for bash commands is not enforced here).
|
||||
user data access. Skills paths under /mnt/skills, ACP workspace paths
|
||||
under /mnt/acp-workspace, and custom mount container paths (configured in
|
||||
config.yaml) are allowed (path-traversal checks only; write prevention
|
||||
for bash commands is not enforced here).
|
||||
A small allowlist of common system path prefixes is kept for executable
|
||||
and device references (e.g. /bin/sh, /dev/null).
|
||||
"""
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
|
||||
# Block file:// URLs which bypass the absolute-path regex but allow local file exfiltration
|
||||
file_url_match = _FILE_URL_PATTERN.search(command)
|
||||
if file_url_match:
|
||||
raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}")
|
||||
|
||||
unsafe_paths: list[str] = []
|
||||
allowed_paths = _get_mcp_allowed_paths()
|
||||
|
||||
@@ -538,6 +673,11 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
|
||||
_reject_path_traversal(absolute_path)
|
||||
continue
|
||||
|
||||
# Allow custom mount container paths
|
||||
if _is_custom_mount_path(absolute_path):
|
||||
_reject_path_traversal(absolute_path)
|
||||
continue
|
||||
|
||||
if any(absolute_path == prefix.rstrip("/") or absolute_path.startswith(prefix) for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES):
|
||||
continue
|
||||
|
||||
@@ -582,6 +722,8 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
|
||||
|
||||
result = acp_pattern.sub(replace_acp_match, result)
|
||||
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_paths_in_command()
|
||||
|
||||
# Replace user-data paths
|
||||
if VIRTUAL_PATH_PREFIX in result and thread_data is not None:
|
||||
pattern = re.compile(rf"{re.escape(VIRTUAL_PATH_PREFIX)}(/[^\s\"';&|<>()]*)?")
|
||||
@@ -757,6 +899,59 @@ def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState]
|
||||
runtime.state["thread_directories_created"] = True
|
||||
|
||||
|
||||
def _truncate_bash_output(output: str, max_chars: int) -> str:
|
||||
"""Middle-truncate bash output, preserving head and tail (50/50 split).
|
||||
|
||||
bash output may have errors at either end (stderr/stdout ordering is
|
||||
non-deterministic), so both ends are preserved equally.
|
||||
|
||||
The returned string (including the truncation marker) is guaranteed to be
|
||||
no longer than max_chars characters. Pass max_chars=0 to disable truncation
|
||||
and return the full output unchanged.
|
||||
"""
|
||||
if max_chars == 0:
|
||||
return output
|
||||
if len(output) <= max_chars:
|
||||
return output
|
||||
total_len = len(output)
|
||||
# Compute the exact worst-case marker length: skipped chars is at most
|
||||
# total_len, so this is a tight upper bound.
|
||||
marker_max_len = len(f"\n... [middle truncated: {total_len} chars skipped] ...\n")
|
||||
kept = max(0, max_chars - marker_max_len)
|
||||
if kept == 0:
|
||||
return output[:max_chars]
|
||||
head_len = kept // 2
|
||||
tail_len = kept - head_len
|
||||
skipped = total_len - kept
|
||||
marker = f"\n... [middle truncated: {skipped} chars skipped] ...\n"
|
||||
return f"{output[:head_len]}{marker}{output[-tail_len:] if tail_len > 0 else ''}"
|
||||
|
||||
|
||||
def _truncate_read_file_output(output: str, max_chars: int) -> str:
|
||||
"""Head-truncate read_file output, preserving the beginning of the file.
|
||||
|
||||
Source code and documents are read top-to-bottom; the head contains the
|
||||
most context (imports, class definitions, function signatures).
|
||||
|
||||
The returned string (including the truncation marker) is guaranteed to be
|
||||
no longer than max_chars characters. Pass max_chars=0 to disable truncation
|
||||
and return the full output unchanged.
|
||||
"""
|
||||
if max_chars == 0:
|
||||
return output
|
||||
if len(output) <= max_chars:
|
||||
return output
|
||||
total = len(output)
|
||||
# Compute the exact worst-case marker length: both numeric fields are at
|
||||
# their maximum (total chars), so this is a tight upper bound.
|
||||
marker_max_len = len(f"\n... [truncated: showing first {total} of {total} chars. Use start_line/end_line to read a specific range] ...")
|
||||
kept = max(0, max_chars - marker_max_len)
|
||||
if kept == 0:
|
||||
return output[:max_chars]
|
||||
marker = f"\n... [truncated: showing first {kept} of {total} chars. Use start_line/end_line to read a specific range] ..."
|
||||
return f"{output[:kept]}{marker}"
|
||||
|
||||
|
||||
@tool("bash", parse_docstring=True)
|
||||
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||
"""Execute a bash command in a Linux environment.
|
||||
@@ -781,9 +976,23 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
command = replace_virtual_paths_in_command(command, thread_data)
|
||||
command = _apply_cwd_prefix(command, thread_data)
|
||||
output = sandbox.execute_command(command)
|
||||
return mask_local_paths_in_output(output, thread_data)
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
sandbox_cfg = get_app_config().sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
return sandbox.execute_command(command)
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
sandbox_cfg = get_app_config().sandbox
|
||||
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
|
||||
except Exception:
|
||||
max_chars = 20000
|
||||
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except PermissionError as e:
|
||||
@@ -811,8 +1020,9 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
path = _resolve_skills_path(path)
|
||||
elif _is_acp_workspace_path(path):
|
||||
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
else:
|
||||
elif not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
children = sandbox.list_dir(path)
|
||||
if not children:
|
||||
return "(empty)"
|
||||
@@ -827,6 +1037,126 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("glob", parse_docstring=True)
|
||||
def glob_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
include_dirs: bool = False,
|
||||
max_results: int = _DEFAULT_GLOB_MAX_RESULTS,
|
||||
) -> str:
|
||||
"""Find files or directories that match a glob pattern under a root directory.
|
||||
|
||||
Args:
|
||||
description: Explain why you are searching for these paths in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
pattern: The glob pattern to match relative to the root path, for example `**/*.py`.
|
||||
path: The **absolute** root directory to search under.
|
||||
include_dirs: Whether matching directories should also be returned. Default is False.
|
||||
max_results: Maximum number of paths to return. Default is 200.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
effective_max_results = _resolve_max_results(
|
||||
"glob",
|
||||
max_results,
|
||||
default=_DEFAULT_GLOB_MAX_RESULTS,
|
||||
upper_bound=_MAX_GLOB_MAX_RESULTS,
|
||||
)
|
||||
thread_data = None
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
path = _resolve_local_read_path(path, thread_data)
|
||||
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
|
||||
if thread_data is not None:
|
||||
matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
|
||||
return _format_glob_results(requested_path, matches, truncated)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: Directory not found: {requested_path}"
|
||||
except NotADirectoryError:
|
||||
return f"Error: Path is not a directory: {requested_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("grep", parse_docstring=True)
|
||||
def grep_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = _DEFAULT_GREP_MAX_RESULTS,
|
||||
) -> str:
|
||||
"""Search for matching lines inside text files under a root directory.
|
||||
|
||||
Args:
|
||||
description: Explain why you are searching file contents in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
pattern: The string or regex pattern to search for.
|
||||
path: The **absolute** root directory to search under.
|
||||
glob: Optional glob filter for candidate files, for example `**/*.py`.
|
||||
literal: Whether to treat `pattern` as a plain string. Default is False.
|
||||
case_sensitive: Whether matching is case-sensitive. Default is False.
|
||||
max_results: Maximum number of matching lines to return. Default is 100.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
effective_max_results = _resolve_max_results(
|
||||
"grep",
|
||||
max_results,
|
||||
default=_DEFAULT_GREP_MAX_RESULTS,
|
||||
upper_bound=_MAX_GREP_MAX_RESULTS,
|
||||
)
|
||||
thread_data = None
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
path = _resolve_local_read_path(path, thread_data)
|
||||
matches, truncated = sandbox.grep(
|
||||
path,
|
||||
pattern,
|
||||
glob=glob,
|
||||
literal=literal,
|
||||
case_sensitive=case_sensitive,
|
||||
max_results=effective_max_results,
|
||||
)
|
||||
if thread_data is not None:
|
||||
matches = [
|
||||
GrepMatch(
|
||||
path=mask_local_paths_in_output(match.path, thread_data),
|
||||
line_number=match.line_number,
|
||||
line=match.line,
|
||||
)
|
||||
for match in matches
|
||||
]
|
||||
return _format_grep_results(requested_path, matches, truncated)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: Directory not found: {requested_path}"
|
||||
except NotADirectoryError:
|
||||
return f"Error: Path is not a directory: {requested_path}"
|
||||
except re.error as e:
|
||||
return f"Error: Invalid regex pattern: {e}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("read_file", parse_docstring=True)
|
||||
def read_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
@@ -854,14 +1184,22 @@ def read_file_tool(
|
||||
path = _resolve_skills_path(path)
|
||||
elif _is_acp_workspace_path(path):
|
||||
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
else:
|
||||
elif not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "(empty)"
|
||||
if start_line is not None and end_line is not None:
|
||||
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
|
||||
return content
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
sandbox_cfg = get_app_config().sandbox
|
||||
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
|
||||
except Exception:
|
||||
max_chars = 50000
|
||||
return _truncate_read_file_output(content, max_chars)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
@@ -896,8 +1234,11 @@ def write_file_tool(
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
sandbox.write_file(path, content, append)
|
||||
if not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
with get_file_operation_lock(sandbox, path):
|
||||
sandbox.write_file(path, content, append)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
@@ -937,17 +1278,20 @@ def str_replace_tool(
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "OK"
|
||||
if old_str not in content:
|
||||
return f"Error: String to replace not found in file: {requested_path}"
|
||||
if replace_all:
|
||||
content = content.replace(old_str, new_str)
|
||||
else:
|
||||
content = content.replace(old_str, new_str, 1)
|
||||
sandbox.write_file(path, content)
|
||||
if not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
with get_file_operation_lock(sandbox, path):
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "OK"
|
||||
if old_str not in content:
|
||||
return f"Error: String to replace not found in file: {requested_path}"
|
||||
if replace_all:
|
||||
content = content.replace(old_str, new_str)
|
||||
else:
|
||||
content = content.replace(old_str, new_str, 1)
|
||||
sandbox.write_file(path, content)
|
||||
return "OK"
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
@@ -33,15 +33,72 @@ def parse_skill_file(skill_file: Path, category: str, relative_path: Path | None
|
||||
|
||||
front_matter = front_matter_match.group(1)
|
||||
|
||||
# Parse YAML front matter (simple key-value parsing)
|
||||
# Parse YAML front matter with basic multiline string support
|
||||
metadata = {}
|
||||
for line in front_matter.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
lines = front_matter.split("\n")
|
||||
current_key = None
|
||||
current_value = []
|
||||
is_multiline = False
|
||||
multiline_style = None
|
||||
indent_level = None
|
||||
|
||||
for line in lines:
|
||||
if is_multiline:
|
||||
if not line.strip():
|
||||
current_value.append("")
|
||||
continue
|
||||
|
||||
current_indent = len(line) - len(line.lstrip())
|
||||
|
||||
if indent_level is None:
|
||||
if current_indent > 0:
|
||||
indent_level = current_indent
|
||||
current_value.append(line[indent_level:])
|
||||
continue
|
||||
elif current_indent >= indent_level:
|
||||
current_value.append(line[indent_level:])
|
||||
continue
|
||||
|
||||
# If we reach here, it's either a new key or the end of multiline
|
||||
if current_key and is_multiline:
|
||||
if multiline_style == "|":
|
||||
metadata[current_key] = "\n".join(current_value).rstrip()
|
||||
else:
|
||||
text = "\n".join(current_value).rstrip()
|
||||
# Replace single newlines with spaces for folded blocks
|
||||
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
|
||||
|
||||
current_key = None
|
||||
current_value = []
|
||||
is_multiline = False
|
||||
multiline_style = None
|
||||
indent_level = None
|
||||
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
if ":" in line:
|
||||
# Handle nested dicts simply by ignoring indentation for now,
|
||||
# or just extracting top-level keys
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip()
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
if value in (">", "|"):
|
||||
current_key = key
|
||||
is_multiline = True
|
||||
multiline_style = value
|
||||
current_value = []
|
||||
indent_level = None
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
if current_key and is_multiline:
|
||||
if multiline_style == "|":
|
||||
metadata[current_key] = "\n".join(current_value).rstrip()
|
||||
else:
|
||||
text = "\n".join(current_value).rstrip()
|
||||
metadata[current_key] = re.sub(r"(?<!\n)\n(?!\n)", " ", text)
|
||||
|
||||
# Extract required fields
|
||||
name = metadata.get("name")
|
||||
|
||||
@@ -57,6 +57,42 @@ def _build_mcp_servers() -> dict[str, dict[str, Any]]:
|
||||
return build_servers_config(ExtensionsConfig.from_file())
|
||||
|
||||
|
||||
def _build_acp_mcp_servers() -> list[dict[str, Any]]:
|
||||
"""Build ACP ``mcpServers`` payload for ``new_session``.
|
||||
|
||||
The ACP client expects a list of server objects, while DeerFlow's MCP helper
|
||||
returns a name -> config mapping for the LangChain MCP adapter. This helper
|
||||
converts the enabled servers into the ACP wire format.
|
||||
"""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
enabled_servers = extensions_config.get_enabled_mcp_servers()
|
||||
|
||||
mcp_servers: list[dict[str, Any]] = []
|
||||
for name, server_config in enabled_servers.items():
|
||||
transport_type = server_config.type or "stdio"
|
||||
payload: dict[str, Any] = {"name": name, "type": transport_type}
|
||||
|
||||
if transport_type == "stdio":
|
||||
if not server_config.command:
|
||||
raise ValueError(f"MCP server '{name}' with stdio transport requires 'command' field")
|
||||
payload["command"] = server_config.command
|
||||
payload["args"] = server_config.args
|
||||
payload["env"] = [{"name": key, "value": value} for key, value in server_config.env.items()]
|
||||
elif transport_type in ("http", "sse"):
|
||||
if not server_config.url:
|
||||
raise ValueError(f"MCP server '{name}' with {transport_type} transport requires 'url' field")
|
||||
payload["url"] = server_config.url
|
||||
payload["headers"] = [{"name": key, "value": value} for key, value in server_config.headers.items()]
|
||||
else:
|
||||
raise ValueError(f"MCP server '{name}' has unsupported transport type: {transport_type}")
|
||||
|
||||
mcp_servers.append(payload)
|
||||
|
||||
return mcp_servers
|
||||
|
||||
|
||||
def _build_permission_response(options: list[Any], *, auto_approve: bool) -> Any:
|
||||
"""Build an ACP permission response.
|
||||
|
||||
@@ -173,7 +209,15 @@ def build_invoke_acp_agent_tool(agents: dict) -> BaseTool:
|
||||
cmd = agent_config.command
|
||||
args = agent_config.args or []
|
||||
physical_cwd = _get_work_dir(thread_id)
|
||||
mcp_servers = _build_mcp_servers()
|
||||
try:
|
||||
mcp_servers = _build_acp_mcp_servers()
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Invalid MCP server configuration for ACP agent '%s'; continuing without MCP servers: %s",
|
||||
agent,
|
||||
exc,
|
||||
)
|
||||
mcp_servers = []
|
||||
agent_env: dict[str, str] | None = None
|
||||
if agent_config.env:
|
||||
agent_env = {k: (os.environ.get(v[1:], "") if v.startswith("$") else v) for k, v in agent_config.env.items()}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .factory import build_tracing_callbacks
|
||||
|
||||
__all__ = ["build_tracing_callbacks"]
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config import (
|
||||
get_enabled_tracing_providers,
|
||||
get_tracing_config,
|
||||
validate_enabled_tracing_providers,
|
||||
)
|
||||
|
||||
|
||||
def _create_langsmith_tracer(config) -> Any:
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
return LangChainTracer(project_name=config.project)
|
||||
|
||||
|
||||
def _create_langfuse_handler(config) -> Any:
|
||||
from langfuse import Langfuse
|
||||
from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
|
||||
|
||||
# langfuse>=4 initializes project-specific credentials through the client
|
||||
# singleton; the LangChain callback then attaches to that configured client.
|
||||
Langfuse(
|
||||
secret_key=config.secret_key,
|
||||
public_key=config.public_key,
|
||||
host=config.host,
|
||||
)
|
||||
return LangfuseCallbackHandler(public_key=config.public_key)
|
||||
|
||||
|
||||
def build_tracing_callbacks() -> list[Any]:
|
||||
"""Build callbacks for all explicitly enabled tracing providers."""
|
||||
validate_enabled_tracing_providers()
|
||||
enabled_providers = get_enabled_tracing_providers()
|
||||
if not enabled_providers:
|
||||
return []
|
||||
|
||||
tracing_config = get_tracing_config()
|
||||
callbacks: list[Any] = []
|
||||
|
||||
for provider in enabled_providers:
|
||||
if provider == "langsmith":
|
||||
try:
|
||||
callbacks.append(_create_langsmith_tracer(tracing_config.langsmith))
|
||||
except Exception as exc: # pragma: no cover - exercised via tests with monkeypatch
|
||||
raise RuntimeError(f"LangSmith tracing initialization failed: {exc}") from exc
|
||||
elif provider == "langfuse":
|
||||
try:
|
||||
callbacks.append(_create_langfuse_handler(tracing_config.langfuse))
|
||||
except Exception as exc: # pragma: no cover - exercised via tests with monkeypatch
|
||||
raise RuntimeError(f"Langfuse tracing initialization failed: {exc}") from exc
|
||||
|
||||
return callbacks
|
||||
@@ -1,10 +1,22 @@
|
||||
"""File conversion utilities.
|
||||
|
||||
Converts document files (PDF, PPT, Excel, Word) to Markdown using markitdown.
|
||||
Converts document files (PDF, PPT, Excel, Word) to Markdown.
|
||||
|
||||
PDF conversion strategy (auto mode):
|
||||
1. Try pymupdf4llm if installed — better heading detection, faster on most files.
|
||||
2. If output is suspiciously short (< _MIN_CHARS_PER_PAGE chars/page, or < 200 chars
|
||||
total when page count is unavailable), treat as image-based and fall back to MarkItDown.
|
||||
3. If pymupdf4llm is not installed, use MarkItDown directly (existing behaviour).
|
||||
|
||||
Large files (> ASYNC_THRESHOLD_BYTES) are converted in a thread pool via
|
||||
asyncio.to_thread() to avoid blocking the event loop (fixes #1569).
|
||||
|
||||
No FastAPI or HTTP dependencies — pure utility functions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,28 +32,278 @@ CONVERTIBLE_EXTENSIONS = {
|
||||
".docx",
|
||||
}
|
||||
|
||||
# Files larger than this threshold are converted in a background thread.
|
||||
# Small files complete in < 1s synchronously; spawning a thread adds unnecessary
|
||||
# scheduling overhead for them.
|
||||
_ASYNC_THRESHOLD_BYTES = 1 * 1024 * 1024 # 1 MB
|
||||
|
||||
# If pymupdf4llm produces fewer characters *per page* than this threshold,
|
||||
# the PDF is likely image-based or encrypted — fall back to MarkItDown.
|
||||
# Rationale: normal text PDFs yield 200-2000 chars/page; image-based PDFs
|
||||
# yield close to 0. 50 chars/page gives a wide safety margin.
|
||||
# Falls back to absolute 200-char check when page count is unavailable.
|
||||
_MIN_CHARS_PER_PAGE = 50
|
||||
|
||||
|
||||
def _pymupdf_output_too_sparse(text: str, file_path: Path) -> bool:
|
||||
"""Return True if pymupdf4llm output is suspiciously short (image-based PDF).
|
||||
|
||||
Uses chars-per-page rather than an absolute threshold so that both short
|
||||
documents (few pages, few chars) and long documents (many pages, many chars)
|
||||
are handled correctly.
|
||||
"""
|
||||
chars = len(text.strip())
|
||||
doc = None
|
||||
pages: int | None = None
|
||||
try:
|
||||
import pymupdf
|
||||
|
||||
doc = pymupdf.open(str(file_path))
|
||||
pages = len(doc)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if doc is not None:
|
||||
try:
|
||||
doc.close()
|
||||
except Exception:
|
||||
pass
|
||||
if pages is not None and pages > 0:
|
||||
return (chars / pages) < _MIN_CHARS_PER_PAGE
|
||||
# Fallback: absolute threshold when page count is unavailable
|
||||
return chars < 200
|
||||
|
||||
|
||||
def _convert_pdf_with_pymupdf4llm(file_path: Path) -> str | None:
|
||||
"""Attempt PDF conversion with pymupdf4llm.
|
||||
|
||||
Returns the markdown text, or None if pymupdf4llm is not installed or
|
||||
if conversion fails (e.g. encrypted/corrupt PDF).
|
||||
"""
|
||||
try:
|
||||
import pymupdf4llm
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
try:
|
||||
return pymupdf4llm.to_markdown(str(file_path))
|
||||
except Exception:
|
||||
logger.exception("pymupdf4llm failed to convert %s; falling back to MarkItDown", file_path.name)
|
||||
return None
|
||||
|
||||
|
||||
def _convert_with_markitdown(file_path: Path) -> str:
|
||||
"""Convert any supported file to markdown text using MarkItDown."""
|
||||
from markitdown import MarkItDown
|
||||
|
||||
md = MarkItDown()
|
||||
return md.convert(str(file_path)).text_content
|
||||
|
||||
|
||||
def _do_convert(file_path: Path, pdf_converter: str) -> str:
|
||||
"""Synchronous conversion — called directly or via asyncio.to_thread.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
pdf_converter: "auto" | "pymupdf4llm" | "markitdown"
|
||||
"""
|
||||
is_pdf = file_path.suffix.lower() == ".pdf"
|
||||
|
||||
if is_pdf and pdf_converter != "markitdown":
|
||||
# Try pymupdf4llm first (auto or explicit)
|
||||
pymupdf_text = _convert_pdf_with_pymupdf4llm(file_path)
|
||||
|
||||
if pymupdf_text is not None:
|
||||
# pymupdf4llm is installed
|
||||
if pdf_converter == "pymupdf4llm":
|
||||
# Explicit — use as-is regardless of output length
|
||||
return pymupdf_text
|
||||
# auto mode: fall back if output looks like a failed parse.
|
||||
# Use chars-per-page to distinguish image-based PDFs (near 0) from
|
||||
# legitimately short documents.
|
||||
if not _pymupdf_output_too_sparse(pymupdf_text, file_path):
|
||||
return pymupdf_text
|
||||
logger.warning(
|
||||
"pymupdf4llm produced only %d chars for %s (likely image-based PDF); falling back to MarkItDown",
|
||||
len(pymupdf_text.strip()),
|
||||
file_path.name,
|
||||
)
|
||||
# pymupdf4llm not installed or fallback triggered → use MarkItDown
|
||||
|
||||
return _convert_with_markitdown(file_path)
|
||||
|
||||
|
||||
async def convert_file_to_markdown(file_path: Path) -> Path | None:
|
||||
"""Convert a file to markdown using markitdown.
|
||||
"""Convert a supported document file to Markdown.
|
||||
|
||||
PDF files are handled with a two-converter strategy (see module docstring).
|
||||
Large files (> 1 MB) are offloaded to a thread pool to avoid blocking the
|
||||
event loop.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to convert.
|
||||
|
||||
Returns:
|
||||
Path to the markdown file if conversion was successful, None otherwise.
|
||||
Path to the generated .md file, or None if conversion failed.
|
||||
"""
|
||||
try:
|
||||
from markitdown import MarkItDown
|
||||
pdf_converter = _get_pdf_converter()
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
md = MarkItDown()
|
||||
result = md.convert(str(file_path))
|
||||
if file_size > _ASYNC_THRESHOLD_BYTES:
|
||||
text = await asyncio.to_thread(_do_convert, file_path, pdf_converter)
|
||||
else:
|
||||
text = _do_convert(file_path, pdf_converter)
|
||||
|
||||
# Save as .md file with same name
|
||||
md_path = file_path.with_suffix(".md")
|
||||
md_path.write_text(result.text_content, encoding="utf-8")
|
||||
md_path.write_text(text, encoding="utf-8")
|
||||
|
||||
logger.info(f"Converted {file_path.name} to markdown: {md_path.name}")
|
||||
logger.info("Converted %s to markdown: %s (%d chars)", file_path.name, md_path.name, len(text))
|
||||
return md_path
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert {file_path.name} to markdown: {e}")
|
||||
logger.error("Failed to convert %s to markdown: %s", file_path.name, e)
|
||||
return None
|
||||
|
||||
|
||||
# Regex for bold-only lines that look like section headings.
|
||||
# Targets SEC filing structural headings that pymupdf4llm renders as **bold**
|
||||
# rather than # Markdown headings (because they use same font size as body text,
|
||||
# distinguished only by bold+caps formatting).
|
||||
#
|
||||
# Pattern requires ALL of:
|
||||
# 1. Entire line is a single **...** block (no surrounding prose)
|
||||
# 2. Starts with a recognised structural keyword:
|
||||
# - ITEM / PART / SECTION (with optional number/letter after)
|
||||
# - SCHEDULE, EXHIBIT, APPENDIX, ANNEX, CHAPTER
|
||||
# All-caps addresses, boilerplate ("CURRENT REPORT", "SIGNATURES",
|
||||
# "WASHINGTON, DC 20549") do NOT start with these keywords and are excluded.
|
||||
#
|
||||
# Chinese headings (第三节...) are already captured as standard # headings
|
||||
# by pymupdf4llm, so they don't need this pattern.
|
||||
_BOLD_HEADING_RE = re.compile(r"^\*\*((ITEM|PART|SECTION|SCHEDULE|EXHIBIT|APPENDIX|ANNEX|CHAPTER)\b[A-Z0-9 .,\-]*)\*\*\s*$")
|
||||
|
||||
# Regex for split-bold headings produced by pymupdf4llm when a heading spans
|
||||
# multiple text spans in the PDF (e.g. section number and title are separate spans).
|
||||
# Matches lines like: **1** **Introduction** or **3.2** **Multi-Head Attention**
|
||||
# Requirements:
|
||||
# 1. Entire line consists only of **...** blocks separated by whitespace (no prose)
|
||||
# 2. First block is a section number (digits and dots, e.g. "1", "3.2", "A.1")
|
||||
# 3. Second block must not be purely numeric/punctuation — excludes financial table
|
||||
# headers like **2023** **2022** **2021** while allowing non-ASCII titles such as
|
||||
# **1** **概述** or accented words (negative lookahead instead of [A-Za-z])
|
||||
# 4. At most two additional blocks (four total) with [^*]+ (no * inside) to keep
|
||||
# the regex linear and avoid ReDoS on attacker-controlled content
|
||||
_SPLIT_BOLD_HEADING_RE = re.compile(r"^\*\*[\dA-Z][\d\.]*\*\*\s+\*\*(?!\d[\d\s.,\-–—/:()%]*\*\*)[^*]+\*\*(?:\s+\*\*[^*]+\*\*){0,2}\s*$")
|
||||
|
||||
# Maximum number of outline entries injected into the agent context.
|
||||
# Keeps prompt size bounded even for very long documents.
|
||||
MAX_OUTLINE_ENTRIES = 50
|
||||
|
||||
_ALLOWED_PDF_CONVERTERS = {"auto", "pymupdf4llm", "markitdown"}
|
||||
|
||||
|
||||
def _clean_bold_title(raw: str) -> str:
|
||||
"""Normalise a title string that may contain pymupdf4llm bold artefacts.
|
||||
|
||||
pymupdf4llm sometimes emits adjacent bold spans as ``**A** **B**`` instead
|
||||
of a single ``**A B**`` block. This helper merges those fragments and then
|
||||
strips the outermost ``**...**`` wrapper so the caller gets plain text.
|
||||
|
||||
Examples::
|
||||
|
||||
"**Overview**" → "Overview"
|
||||
"**UNITED STATES** **SECURITIES**" → "UNITED STATES SECURITIES"
|
||||
"plain text" → "plain text" (unchanged)
|
||||
"""
|
||||
# Merge adjacent bold spans: "** **" → " "
|
||||
merged = re.sub(r"\*\*\s*\*\*", " ", raw).strip()
|
||||
# Strip outermost **...** if the whole string is wrapped
|
||||
if m := re.fullmatch(r"\*\*(.+?)\*\*", merged, re.DOTALL):
|
||||
return m.group(1).strip()
|
||||
return merged
|
||||
|
||||
|
||||
def extract_outline(md_path: Path) -> list[dict]:
|
||||
"""Extract document outline (headings) from a Markdown file.
|
||||
|
||||
Recognises three heading styles produced by pymupdf4llm:
|
||||
|
||||
1. Standard Markdown headings: lines starting with one or more '#'.
|
||||
Inline ``**...**`` wrappers and adjacent bold spans (``** **``) are
|
||||
cleaned so the title is plain text.
|
||||
|
||||
2. Bold-only structural headings: ``**ITEM 1. BUSINESS**``, ``**PART II**``,
|
||||
etc. SEC filings use bold+caps for section headings with the same font
|
||||
size as body text, so pymupdf4llm cannot promote them to # headings.
|
||||
|
||||
3. Split-bold headings: ``**1** **Introduction**``, ``**3.2** **Attention**``.
|
||||
pymupdf4llm emits these when the section number and title text are
|
||||
separate spans in the underlying PDF (common in academic papers).
|
||||
|
||||
Args:
|
||||
md_path: Path to the .md file.
|
||||
|
||||
Returns:
|
||||
List of dicts with keys: title (str), line (int, 1-based).
|
||||
When the outline is truncated at MAX_OUTLINE_ENTRIES, a sentinel entry
|
||||
``{"truncated": True}`` is appended as the last element so callers can
|
||||
render a "showing first N headings" hint without re-scanning the file.
|
||||
Returns an empty list if the file cannot be read or has no headings.
|
||||
"""
|
||||
outline: list[dict] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for lineno, line in enumerate(f, 1):
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
|
||||
# Style 1: standard Markdown heading
|
||||
if stripped.startswith("#"):
|
||||
title = _clean_bold_title(stripped.lstrip("#").strip())
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
# Style 2: single bold block with SEC structural keyword
|
||||
elif m := _BOLD_HEADING_RE.match(stripped):
|
||||
title = m.group(1).strip()
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
# Style 3: split-bold heading — **<num>** **<title>**
|
||||
# Regex already enforces max 4 blocks and non-numeric second block.
|
||||
elif _SPLIT_BOLD_HEADING_RE.match(stripped):
|
||||
title = " ".join(re.findall(r"\*\*([^*]+)\*\*", stripped))
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
if len(outline) >= MAX_OUTLINE_ENTRIES:
|
||||
outline.append({"truncated": True})
|
||||
break
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
return outline
|
||||
|
||||
|
||||
def _get_pdf_converter() -> str:
|
||||
"""Read pdf_converter setting from app config, defaulting to 'auto'.
|
||||
|
||||
Normalizes the value to lowercase and validates it against the allowed set
|
||||
so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
|
||||
fall through to unexpected behaviour.
|
||||
"""
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
cfg = get_app_config()
|
||||
uploads_cfg = getattr(cfg, "uploads", None)
|
||||
if uploads_cfg is not None:
|
||||
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
|
||||
if raw not in _ALLOWED_PDF_CONVERTERS:
|
||||
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
|
||||
return "auto"
|
||||
return raw
|
||||
except Exception:
|
||||
pass
|
||||
return "auto"
|
||||
|
||||
@@ -14,6 +14,7 @@ dependencies = [
|
||||
"langchain-deepseek>=1.0.1",
|
||||
"langchain-mcp-adapters>=0.1.0",
|
||||
"langchain-openai>=1.1.7",
|
||||
"langfuse>=3.4.1",
|
||||
"langgraph>=1.0.6,<1.0.10",
|
||||
"langgraph-api>=0.7.0,<0.8.0",
|
||||
"langgraph-cli>=0.4.14",
|
||||
@@ -44,6 +45,9 @@ postgres = [
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
pymupdf = ["pymupdf4llm>=0.0.17"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
Reference in New Issue
Block a user