Merge branch 'main' into rayhpeng/persistence-scaffold

# Conflicts:
#	.env.example
#	backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
This commit is contained in:
rayhpeng
2026-04-04 21:28:07 +08:00
180 changed files with 10945 additions and 787 deletions
@@ -345,6 +345,8 @@ def make_lead_agent(config: RunnableConfig):
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
),
state_schema=ThreadState,
)
@@ -8,6 +8,14 @@ from deerflow.subagents import get_available_subagent_names
logger = logging.getLogger(__name__)
def _get_enabled_skills():
try:
return list(load_skills(enabled_only=True))
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
return []
def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
@@ -386,7 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
Returns the <skill_system>...</skill_system> block listing all enabled skills,
suitable for injection into any agent's system prompt.
"""
skills = load_skills(enabled_only=True)
skills = _get_enabled_skills()
try:
from deerflow.config import get_app_config
@@ -402,6 +410,10 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
if available_skills is not None:
skills = [skill for skill in skills if skill.name in available_skills]
# Check again after filtering
if not skills:
return ""
skill_items = "\n".join(
f" <skill>\n <name>{skill.name}</name>\n <description>{skill.description}</description>\n <location>{skill.get_container_file_path(container_base_path)}</location>\n </skill>" for skill in skills
)
@@ -446,7 +458,7 @@ def get_deferred_tools_prompt_section() -> str:
if not get_app_config().tool_search.enabled:
return ""
except FileNotFoundError:
except Exception:
return ""
registry = get_deferred_registry()
@@ -29,6 +29,17 @@ Instructions:
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
3. Update the memory sections as needed following the detailed length guidelines below
Before extracting facts, perform a structured reflection on the conversation:
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
If yes, record them as facts with the most appropriate category and confidence.
{correction_hint}
Memory Section Guidelines:
**User Context** (Current state - concise summaries):
@@ -62,6 +73,7 @@ Memory Section Guidelines:
* context: Background facts (job title, projects, locations, languages)
* behavior: Working patterns, communication habits, problem-solving approaches
* goal: Stated objectives, learning targets, project ambitions
* correction: Explicit agent mistakes or user corrections, including the correct approach
- Confidence levels:
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
* 0.7-0.8: Strongly implied from actions/discussions
@@ -94,7 +106,7 @@ Output Format (JSON):
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
}},
"newFacts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
],
"factsToRemove": ["fact_id_1", "fact_id_2"]
}}
@@ -104,6 +116,8 @@ Important Rules:
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
- Include specific metrics, version numbers, and proper nouns in facts
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
- Remove facts that are contradicted by new information
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
Keep 3-5 concurrent focus themes that are still active and relevant
@@ -126,7 +140,7 @@ Message:
Extract facts in this JSON format:
{{
"facts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
]
}}
@@ -136,6 +150,7 @@ Categories:
- context: Background context (location, job, projects)
- behavior: Behavioral patterns
- goal: User's goals or objectives
- correction: Explicit corrections or mistakes to avoid repeating
Rules:
- Only extract clear, specific facts
@@ -231,6 +246,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
if earlier.get("summary"):
history_sections.append(f"Earlier: {earlier['summary']}")
background = history_data.get("longTermBackground", {})
if background.get("summary"):
history_sections.append(f"Background: {background['summary']}")
if history_sections:
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
@@ -262,7 +281,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
continue
category = str(fact.get("category", "context")).strip() or "context"
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
line = f"- [{category} | {confidence:.2f}] {content}"
source_error = fact.get("sourceError")
if category == "correction" and isinstance(source_error, str) and source_error.strip():
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
else:
line = f"- [{category} | {confidence:.2f}] {content}"
# Each additional line is preceded by a newline (except the first).
line_text = ("\n" + line) if fact_lines else line
@@ -20,6 +20,7 @@ class ConversationContext:
messages: list[Any]
timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None
correction_detected: bool = False
class MemoryUpdateQueue:
@@ -37,25 +38,38 @@ class MemoryUpdateQueue:
self._timer: threading.Timer | None = None
self._processing = False
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
def add(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
Args:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
"""
config = get_memory_config()
if not config.enabled:
return
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
)
with self._lock:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
)
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
@@ -115,6 +129,7 @@ class MemoryUpdateQueue:
messages=context.messages,
thread_id=context.thread_id,
agent_name=context.agent_name,
correction_detected=context.correction_detected,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)
@@ -266,13 +266,20 @@ class MemoryUpdater:
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
def update_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
Returns:
True if update was successful, False otherwise.
@@ -295,9 +302,19 @@ class MemoryUpdater:
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
# Call LLM
@@ -383,6 +400,8 @@ class MemoryUpdater:
confidence = fact.get("confidence", 0.5)
if confidence >= config.fact_confidence_threshold:
raw_content = fact.get("content", "")
if not isinstance(raw_content, str):
continue
normalized_content = raw_content.strip()
fact_key = _fact_content_key(normalized_content)
if fact_key is not None and fact_key in existing_fact_keys:
@@ -396,6 +415,11 @@ class MemoryUpdater:
"createdAt": now,
"source": thread_id or "unknown",
}
source_error = fact.get("sourceError")
if isinstance(source_error, str):
normalized_source_error = source_error.strip()
if normalized_source_error:
fact_entry["sourceError"] = normalized_source_error
current_memory["facts"].append(fact_entry)
if fact_key is not None:
existing_fact_keys.add(fact_key)
@@ -412,16 +436,22 @@ class MemoryUpdater:
return current_memory
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
def update_memory_from_conversation(
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name)
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
@@ -0,0 +1,275 @@
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import (
ModelCallResult,
ModelRequest,
ModelResponse,
)
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
_BUSY_PATTERNS = (
"server busy",
"temporarily unavailable",
"try again later",
"please retry",
"please try again",
"overloaded",
"high demand",
"rate limit",
"负载较高",
"服务繁忙",
"稍后重试",
"请稍后重试",
)
_QUOTA_PATTERNS = (
"insufficient_quota",
"quota",
"billing",
"credit",
"payment",
"余额不足",
"超出限额",
"额度不足",
"欠费",
)
_AUTH_PATTERNS = (
"authentication",
"unauthorized",
"invalid api key",
"invalid_api_key",
"permission",
"forbidden",
"access denied",
"无权",
"未授权",
)
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Retry transient LLM errors and surface graceful assistant messages."""
retry_max_attempts: int = 3
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
error_code = _extract_error_code(exc)
status_code = _extract_status_code(exc)
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
return False, "quota"
if _matches_any(lowered, _AUTH_PATTERNS):
return False, "auth"
exc_name = exc.__class__.__name__
if exc_name in {
"APITimeoutError",
"APIConnectionError",
"InternalServerError",
}:
return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES:
return True, "transient"
if _matches_any(lowered, _BUSY_PATTERNS):
return True, "busy"
return False, "generic"
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
retry_after = _extract_retry_after_ms(exc)
if retry_after is not None:
return retry_after
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
return min(backoff, self.retry_cap_delay_ms)
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
seconds = max(1, round(wait_ms / 1000))
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
if reason == "auth":
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
if reason in {"busy", "transient"}:
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}"
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try:
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer(
{
"type": "llm_retry",
"attempt": attempt,
"max_attempts": self.retry_max_attempts,
"wait_ms": wait_ms,
"reason": reason,
"message": self._build_retry_message(attempt, wait_ms, reason),
}
)
except Exception:
logger.debug("Failed to emit llm_retry event", exc_info=True)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
attempt = 1
while True:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
time.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
attempt = 1
while True:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
await asyncio.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
return any(pattern in detail for pattern in patterns)
def _extract_error_code(exc: BaseException) -> Any:
for attr in ("code", "error_code"):
value = getattr(exc, attr, None)
if value not in (None, ""):
return value
body = getattr(exc, "body", None)
if isinstance(body, dict):
error = body.get("error")
if isinstance(error, dict):
for key in ("code", "type"):
value = error.get(key)
if value not in (None, ""):
return value
return None
def _extract_status_code(exc: BaseException) -> int | None:
for attr in ("status_code", "status"):
value = getattr(exc, attr, None)
if isinstance(value, int):
return value
response = getattr(exc, "response", None)
status = getattr(response, "status_code", None)
return status if isinstance(status, int) else None
def _extract_retry_after_ms(exc: BaseException) -> int | None:
response = getattr(exc, "response", None)
headers = getattr(response, "headers", None)
if headers is None:
return None
raw = None
header_name = ""
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
header_name = key
if hasattr(headers, "get"):
raw = headers.get(key)
if raw:
break
if not raw:
return None
try:
multiplier = 1 if "ms" in header_name.lower() else 1000
return max(0, int(float(raw) * multiplier))
except (TypeError, ValueError):
try:
target = parsedate_to_datetime(str(raw))
delta = target.timestamp() - time.time()
return max(0, int(delta * 1000))
except (TypeError, ValueError, OverflowError):
return None
def _extract_error_detail(exc: BaseException) -> str:
detail = str(exc).strip()
if detail:
return detail
message = getattr(exc, "message", None)
if isinstance(message, str) and message.strip():
return message.strip()
return exc.__class__.__name__
@@ -182,6 +182,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None, False
@staticmethod
def _append_text(content: str | list | None, text: str) -> str | list:
"""Append *text* to AIMessage content, handling str, list, and None.
When content is a list of content blocks (e.g. Anthropic thinking mode),
we append a new ``{"type": "text", ...}`` block instead of concatenating
a string to a list, which would raise ``TypeError``.
"""
if content is None:
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
# Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}"
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
@@ -192,7 +209,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
stripped_msg = last_msg.model_copy(
update={
"tool_calls": [],
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
}
)
return {"messages": [stripped_msg]}
@@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState):
pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
@@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content = getattr(msg, "content", "")
if isinstance(content, list):
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
content_str = str(content)
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
@@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
return None
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
queue = get_memory_queue()
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
)
return None
@@ -116,44 +116,33 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None."""
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = model.invoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
except Exception:
logger.exception("Failed to generate title (sync)")
title = self._fallback_title(user_msg)
return {"title": title}
_, user_msg = self._build_title_prompt(state)
return {"title": self._fallback_title(user_msg)}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Asynchronously generate a title. Returns state update or None."""
"""Generate a title asynchronously and fall back locally on failure."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
prompt, user_msg = self._build_title_prompt(state)
try:
response = await model.ainvoke(prompt, config=self._get_runnable_config())
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt)
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
if title:
return {"title": title}
except Exception:
logger.exception("Failed to generate title (async)")
title = self._fallback_title(user_msg)
return {"title": title}
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg)}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
@@ -72,6 +72,7 @@ def _build_runtime_middlewares(
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Build shared base middlewares for agent execution."""
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
@@ -90,6 +91,8 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
@@ -135,6 +138,6 @@ def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentM
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
include_uploads=False,
include_dangling_tool_call_patch=False,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
@@ -10,10 +10,52 @@ from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__)
_OUTLINE_PREVIEW_LINES = 5
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
"""Return the document outline and fallback preview for *file_path*.
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
pipeline.
Returns:
(outline, preview) where:
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
Empty when no headings are found or no .md exists.
- preview: first few non-empty lines of the .md, used as a content
anchor when outline is empty so the agent has some context.
Empty when outline is non-empty (no fallback needed).
"""
md_path = file_path.with_suffix(".md")
if not md_path.is_file():
return [], []
outline = extract_outline(md_path)
if outline:
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
return outline, []
# outline is empty — read the first few non-empty lines as a content preview
preview: list[str] = []
try:
with md_path.open(encoding="utf-8") as f:
for line in f:
stripped = line.strip()
if stripped:
preview.append(stripped)
if len(preview) >= _OUTLINE_PREVIEW_LINES:
break
except Exception:
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
return [], preview
class UploadsMiddlewareState(AgentState):
"""State schema for uploads middleware."""
@@ -39,12 +81,38 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
"""Append a single file entry (name, size, path, optional outline) to lines."""
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
outline = file.get("outline") or []
if outline:
truncated = outline[-1].get("truncated", False)
visible = [e for e in outline if not e.get("truncated")]
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
for entry in visible:
lines.append(f" L{entry['line']}: {entry['title']}")
if truncated:
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
else:
preview = file.get("outline_preview") or []
if preview:
lines.append(" No structural headings detected. Document begins with:")
for text in preview:
lines.append(f" > {text}")
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
lines.append("")
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
"""Create a formatted message listing uploaded files.
Args:
new_files: Files uploaded in the current message.
historical_files: Files uploaded in previous messages.
Each file dict may contain an optional ``outline`` key — a list of
``{title, line}`` dicts extracted from the converted Markdown file.
Returns:
Formatted string inside <uploaded_files> tags.
@@ -55,25 +123,24 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
lines.append("")
if new_files:
for file in new_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
self._format_file_entry(file, lines)
else:
lines.append("(empty)")
lines.append("")
if historical_files:
lines.append("The following files were uploaded in previous messages and are still available:")
lines.append("")
for file in historical_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
self._format_file_entry(file, lines)
lines.append("You can read these files using the `read_file` tool with the paths shown above.")
lines.append("To work with these files:")
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
lines.append("- Use `glob` to find files by name pattern")
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
lines.append("</uploaded_files>")
return "\n".join(lines)
@@ -147,6 +214,13 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
@@ -159,15 +233,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
for file_path in sorted(uploads_dir.iterdir()):
if file_path.is_file() and file_path.name not in new_filenames:
stat = file_path.stat()
outline, preview = _extract_outline_for_file(file_path)
historical_files.append(
{
"filename": file_path.name,
"size": stat.st_size,
"path": f"/mnt/user-data/uploads/{file_path.name}",
"extension": file_path.suffix,
"outline": outline,
"outline_preview": preview,
}
)
# Attach outlines to new files as well
if uploads_dir:
for file in new_files:
phys_path = uploads_dir / file["filename"]
outline, preview = _extract_outline_for_file(phys_path)
file["outline"] = outline
file["outline_preview"] = preview
if not new_files and not historical_files:
return None