mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
Merge branch 'main' into rayhpeng/persistence-scaffold
# Conflicts: # .env.example # backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
This commit is contained in:
@@ -345,6 +345,8 @@ def make_lead_agent(config: RunnableConfig):
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||
),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,14 @@ from deerflow.subagents import get_available_subagent_names
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
try:
|
||||
return list(load_skills(enabled_only=True))
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
return []
|
||||
|
||||
|
||||
def _build_subagent_section(max_concurrent: int) -> str:
|
||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||
|
||||
@@ -386,7 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
Returns the <skill_system>...</skill_system> block listing all enabled skills,
|
||||
suitable for injection into any agent's system prompt.
|
||||
"""
|
||||
skills = load_skills(enabled_only=True)
|
||||
skills = _get_enabled_skills()
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
@@ -402,6 +410,10 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
if available_skills is not None:
|
||||
skills = [skill for skill in skills if skill.name in available_skills]
|
||||
|
||||
# Check again after filtering
|
||||
if not skills:
|
||||
return ""
|
||||
|
||||
skill_items = "\n".join(
|
||||
f" <skill>\n <name>{skill.name}</name>\n <description>{skill.description}</description>\n <location>{skill.get_container_file_path(container_base_path)}</location>\n </skill>" for skill in skills
|
||||
)
|
||||
@@ -446,7 +458,7 @@ def get_deferred_tools_prompt_section() -> str:
|
||||
|
||||
if not get_app_config().tool_search.enabled:
|
||||
return ""
|
||||
except FileNotFoundError:
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
registry = get_deferred_registry()
|
||||
|
||||
@@ -29,6 +29,17 @@ Instructions:
|
||||
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
|
||||
3. Update the memory sections as needed following the detailed length guidelines below
|
||||
|
||||
Before extracting facts, perform a structured reflection on the conversation:
|
||||
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
|
||||
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
|
||||
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
|
||||
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
|
||||
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
|
||||
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
|
||||
If yes, record them as facts with the most appropriate category and confidence.
|
||||
|
||||
{correction_hint}
|
||||
|
||||
Memory Section Guidelines:
|
||||
|
||||
**User Context** (Current state - concise summaries):
|
||||
@@ -62,6 +73,7 @@ Memory Section Guidelines:
|
||||
* context: Background facts (job title, projects, locations, languages)
|
||||
* behavior: Working patterns, communication habits, problem-solving approaches
|
||||
* goal: Stated objectives, learning targets, project ambitions
|
||||
* correction: Explicit agent mistakes or user corrections, including the correct approach
|
||||
- Confidence levels:
|
||||
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
|
||||
* 0.7-0.8: Strongly implied from actions/discussions
|
||||
@@ -94,7 +106,7 @@ Output Format (JSON):
|
||||
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"newFacts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
],
|
||||
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
||||
}}
|
||||
@@ -104,6 +116,8 @@ Important Rules:
|
||||
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
||||
- Include specific metrics, version numbers, and proper nouns in facts
|
||||
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
||||
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
|
||||
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
|
||||
- Remove facts that are contradicted by new information
|
||||
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
||||
Keep 3-5 concurrent focus themes that are still active and relevant
|
||||
@@ -126,7 +140,7 @@ Message:
|
||||
Extract facts in this JSON format:
|
||||
{{
|
||||
"facts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal", "confidence": 0.0-1.0 }}
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
]
|
||||
}}
|
||||
|
||||
@@ -136,6 +150,7 @@ Categories:
|
||||
- context: Background context (location, job, projects)
|
||||
- behavior: Behavioral patterns
|
||||
- goal: User's goals or objectives
|
||||
- correction: Explicit corrections or mistakes to avoid repeating
|
||||
|
||||
Rules:
|
||||
- Only extract clear, specific facts
|
||||
@@ -231,6 +246,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
if earlier.get("summary"):
|
||||
history_sections.append(f"Earlier: {earlier['summary']}")
|
||||
|
||||
background = history_data.get("longTermBackground", {})
|
||||
if background.get("summary"):
|
||||
history_sections.append(f"Background: {background['summary']}")
|
||||
|
||||
if history_sections:
|
||||
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
|
||||
|
||||
@@ -262,7 +281,11 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
continue
|
||||
category = str(fact.get("category", "context")).strip() or "context"
|
||||
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
source_error = fact.get("sourceError")
|
||||
if category == "correction" and isinstance(source_error, str) and source_error.strip():
|
||||
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
|
||||
else:
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
|
||||
@@ -20,6 +20,7 @@ class ConversationContext:
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
@@ -37,25 +38,38 @@ class MemoryUpdateQueue:
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
|
||||
def add(
|
||||
self,
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
# If so, replace it with the newer one
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
@@ -115,6 +129,7 @@ class MemoryUpdateQueue:
|
||||
messages=context.messages,
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
||||
@@ -266,13 +266,20 @@ class MemoryUpdater:
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory(
|
||||
self,
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@@ -295,9 +302,19 @@ class MemoryUpdater:
|
||||
return False
|
||||
|
||||
# Build prompt
|
||||
correction_hint = ""
|
||||
if correction_detected:
|
||||
correction_hint = (
|
||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
conversation=conversation_text,
|
||||
correction_hint=correction_hint,
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
@@ -383,6 +400,8 @@ class MemoryUpdater:
|
||||
confidence = fact.get("confidence", 0.5)
|
||||
if confidence >= config.fact_confidence_threshold:
|
||||
raw_content = fact.get("content", "")
|
||||
if not isinstance(raw_content, str):
|
||||
continue
|
||||
normalized_content = raw_content.strip()
|
||||
fact_key = _fact_content_key(normalized_content)
|
||||
if fact_key is not None and fact_key in existing_fact_keys:
|
||||
@@ -396,6 +415,11 @@ class MemoryUpdater:
|
||||
"createdAt": now,
|
||||
"source": thread_id or "unknown",
|
||||
}
|
||||
source_error = fact.get("sourceError")
|
||||
if isinstance(source_error, str):
|
||||
normalized_source_error = source_error.strip()
|
||||
if normalized_source_error:
|
||||
fact_entry["sourceError"] = normalized_source_error
|
||||
current_memory["facts"].append(fact_entry)
|
||||
if fact_key is not None:
|
||||
existing_fact_keys.add(fact_key)
|
||||
@@ -412,16 +436,22 @@ class MemoryUpdater:
|
||||
return current_memory
|
||||
|
||||
|
||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
def update_memory_from_conversation(
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
|
||||
|
||||
+275
@@ -0,0 +1,275 @@
|
||||
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
_BUSY_PATTERNS = (
|
||||
"server busy",
|
||||
"temporarily unavailable",
|
||||
"try again later",
|
||||
"please retry",
|
||||
"please try again",
|
||||
"overloaded",
|
||||
"high demand",
|
||||
"rate limit",
|
||||
"负载较高",
|
||||
"服务繁忙",
|
||||
"稍后重试",
|
||||
"请稍后重试",
|
||||
)
|
||||
_QUOTA_PATTERNS = (
|
||||
"insufficient_quota",
|
||||
"quota",
|
||||
"billing",
|
||||
"credit",
|
||||
"payment",
|
||||
"余额不足",
|
||||
"超出限额",
|
||||
"额度不足",
|
||||
"欠费",
|
||||
)
|
||||
_AUTH_PATTERNS = (
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"permission",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"无权",
|
||||
"未授权",
|
||||
)
|
||||
|
||||
|
||||
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Retry transient LLM errors and surface graceful assistant messages."""
|
||||
|
||||
retry_max_attempts: int = 3
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
error_code = _extract_error_code(exc)
|
||||
status_code = _extract_status_code(exc)
|
||||
|
||||
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
|
||||
return False, "quota"
|
||||
if _matches_any(lowered, _AUTH_PATTERNS):
|
||||
return False, "auth"
|
||||
|
||||
exc_name = exc.__class__.__name__
|
||||
if exc_name in {
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
"InternalServerError",
|
||||
}:
|
||||
return True, "transient"
|
||||
if status_code in _RETRIABLE_STATUS_CODES:
|
||||
return True, "transient"
|
||||
if _matches_any(lowered, _BUSY_PATTERNS):
|
||||
return True, "busy"
|
||||
|
||||
return False, "generic"
|
||||
|
||||
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
|
||||
retry_after = _extract_retry_after_ms(exc)
|
||||
if retry_after is not None:
|
||||
return retry_after
|
||||
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
|
||||
return min(backoff, self.retry_cap_delay_ms)
|
||||
|
||||
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
|
||||
seconds = max(1, round(wait_ms / 1000))
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
|
||||
if reason == "auth":
|
||||
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
||||
if reason in {"busy", "transient"}:
|
||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||
return f"LLM request failed: {detail}"
|
||||
|
||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
writer(
|
||||
{
|
||||
"type": "llm_retry",
|
||||
"attempt": attempt,
|
||||
"max_attempts": self.retry_max_attempts,
|
||||
"wait_ms": wait_ms,
|
||||
"reason": reason,
|
||||
"message": self._build_retry_message(attempt, wait_ms, reason),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to emit llm_retry event", exc_info=True)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
time.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||
return any(pattern in detail for pattern in patterns)
|
||||
|
||||
|
||||
def _extract_error_code(exc: BaseException) -> Any:
|
||||
for attr in ("code", "error_code"):
|
||||
value = getattr(exc, attr, None)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
|
||||
body = getattr(exc, "body", None)
|
||||
if isinstance(body, dict):
|
||||
error = body.get("error")
|
||||
if isinstance(error, dict):
|
||||
for key in ("code", "type"):
|
||||
value = error.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_status_code(exc: BaseException) -> int | None:
|
||||
for attr in ("status_code", "status"):
|
||||
value = getattr(exc, attr, None)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
response = getattr(exc, "response", None)
|
||||
status = getattr(response, "status_code", None)
|
||||
return status if isinstance(status, int) else None
|
||||
|
||||
|
||||
def _extract_retry_after_ms(exc: BaseException) -> int | None:
|
||||
response = getattr(exc, "response", None)
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
raw = None
|
||||
header_name = ""
|
||||
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
|
||||
header_name = key
|
||||
if hasattr(headers, "get"):
|
||||
raw = headers.get(key)
|
||||
if raw:
|
||||
break
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
try:
|
||||
multiplier = 1 if "ms" in header_name.lower() else 1000
|
||||
return max(0, int(float(raw) * multiplier))
|
||||
except (TypeError, ValueError):
|
||||
try:
|
||||
target = parsedate_to_datetime(str(raw))
|
||||
delta = target.timestamp() - time.time()
|
||||
return max(0, int(delta * 1000))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_detail(exc: BaseException) -> str:
|
||||
detail = str(exc).strip()
|
||||
if detail:
|
||||
return detail
|
||||
message = getattr(exc, "message", None)
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
return exc.__class__.__name__
|
||||
@@ -182,6 +182,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
return None, False
|
||||
|
||||
@staticmethod
|
||||
def _append_text(content: str | list | None, text: str) -> str | list:
|
||||
"""Append *text* to AIMessage content, handling str, list, and None.
|
||||
|
||||
When content is a list of content blocks (e.g. Anthropic thinking mode),
|
||||
we append a new ``{"type": "text", ...}`` block instead of concatenating
|
||||
a string to a list, which would raise ``TypeError``.
|
||||
"""
|
||||
if content is None:
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
@@ -192,7 +209,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
stripped_msg = last_msg.model_copy(
|
||||
update={
|
||||
"tool_calls": [],
|
||||
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
|
||||
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
|
||||
}
|
||||
)
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
@@ -14,6 +14,21 @@ from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
@@ -21,6 +36,22 @@ class MemoryMiddlewareState(AgentState):
|
||||
pass
|
||||
|
||||
|
||||
def _extract_message_text(message: Any) -> str:
|
||||
"""Extract plain text from message content for filtering and signal detection."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
elif isinstance(part, dict):
|
||||
text_val = part.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
"""Filter messages to keep only user inputs and final assistant responses.
|
||||
|
||||
@@ -44,18 +75,13 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
Returns:
|
||||
Filtered list containing only user inputs and final assistant responses.
|
||||
"""
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
|
||||
filtered = []
|
||||
skip_next_ai = False
|
||||
for msg in messages:
|
||||
msg_type = getattr(msg, "type", None)
|
||||
|
||||
if msg_type == "human":
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, list):
|
||||
content = " ".join(p.get("text", "") for p in content if isinstance(p, dict))
|
||||
content_str = str(content)
|
||||
content_str = _extract_message_text(msg)
|
||||
if "<uploaded_files>" in content_str:
|
||||
# Strip the ephemeral upload block; keep the user's real question.
|
||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||
@@ -87,6 +113,25 @@ def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
return filtered
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
|
||||
"""Detect explicit user corrections in recent conversation turns.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale corrections from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
@@ -150,7 +195,13 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
return None
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -116,44 +116,33 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
return config
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Synchronously generate a title. Returns state update or None."""
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
config = get_title_config()
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
|
||||
try:
|
||||
response = model.invoke(prompt, config=self._get_runnable_config())
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate title (sync)")
|
||||
title = self._fallback_title(user_msg)
|
||||
|
||||
return {"title": title}
|
||||
_, user_msg = self._build_title_prompt(state)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Asynchronously generate a title. Returns state update or None."""
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
config = get_title_config()
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
|
||||
try:
|
||||
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
response = await model.ainvoke(prompt)
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
if title:
|
||||
return {"title": title}
|
||||
except Exception:
|
||||
logger.exception("Failed to generate title (async)")
|
||||
title = self._fallback_title(user_msg)
|
||||
|
||||
return {"title": title}
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
|
||||
+4
-1
@@ -72,6 +72,7 @@ def _build_runtime_middlewares(
|
||||
lazy_init: bool = True,
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
@@ -90,6 +91,8 @@ def _build_runtime_middlewares(
|
||||
|
||||
middlewares.append(DanglingToolCallMiddleware())
|
||||
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.guardrails_config import get_guardrails_config
|
||||
|
||||
@@ -135,6 +138,6 @@ def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentM
|
||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=False,
|
||||
include_dangling_tool_call_patch=False,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
|
||||
@@ -10,10 +10,52 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_OUTLINE_PREVIEW_LINES = 5
|
||||
|
||||
|
||||
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
|
||||
"""Return the document outline and fallback preview for *file_path*.
|
||||
|
||||
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
|
||||
pipeline.
|
||||
|
||||
Returns:
|
||||
(outline, preview) where:
|
||||
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
|
||||
Empty when no headings are found or no .md exists.
|
||||
- preview: first few non-empty lines of the .md, used as a content
|
||||
anchor when outline is empty so the agent has some context.
|
||||
Empty when outline is non-empty (no fallback needed).
|
||||
"""
|
||||
md_path = file_path.with_suffix(".md")
|
||||
if not md_path.is_file():
|
||||
return [], []
|
||||
|
||||
outline = extract_outline(md_path)
|
||||
if outline:
|
||||
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
|
||||
return outline, []
|
||||
|
||||
# outline is empty — read the first few non-empty lines as a content preview
|
||||
preview: list[str] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
preview.append(stripped)
|
||||
if len(preview) >= _OUTLINE_PREVIEW_LINES:
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
|
||||
return [], preview
|
||||
|
||||
|
||||
class UploadsMiddlewareState(AgentState):
|
||||
"""State schema for uploads middleware."""
|
||||
|
||||
@@ -39,12 +81,38 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
super().__init__()
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
|
||||
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
|
||||
"""Append a single file entry (name, size, path, optional outline) to lines."""
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
outline = file.get("outline") or []
|
||||
if outline:
|
||||
truncated = outline[-1].get("truncated", False)
|
||||
visible = [e for e in outline if not e.get("truncated")]
|
||||
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
|
||||
for entry in visible:
|
||||
lines.append(f" L{entry['line']}: {entry['title']}")
|
||||
if truncated:
|
||||
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
|
||||
else:
|
||||
preview = file.get("outline_preview") or []
|
||||
if preview:
|
||||
lines.append(" No structural headings detected. Document begins with:")
|
||||
for text in preview:
|
||||
lines.append(f" > {text}")
|
||||
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("")
|
||||
|
||||
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
|
||||
"""Create a formatted message listing uploaded files.
|
||||
|
||||
Args:
|
||||
new_files: Files uploaded in the current message.
|
||||
historical_files: Files uploaded in previous messages.
|
||||
Each file dict may contain an optional ``outline`` key — a list of
|
||||
``{title, line}`` dicts extracted from the converted Markdown file.
|
||||
|
||||
Returns:
|
||||
Formatted string inside <uploaded_files> tags.
|
||||
@@ -55,25 +123,24 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
lines.append("")
|
||||
if new_files:
|
||||
for file in new_files:
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
lines.append("")
|
||||
self._format_file_entry(file, lines)
|
||||
else:
|
||||
lines.append("(empty)")
|
||||
lines.append("")
|
||||
|
||||
if historical_files:
|
||||
lines.append("The following files were uploaded in previous messages and are still available:")
|
||||
lines.append("")
|
||||
for file in historical_files:
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
lines.append("")
|
||||
self._format_file_entry(file, lines)
|
||||
|
||||
lines.append("You can read these files using the `read_file` tool with the paths shown above.")
|
||||
lines.append("To work with these files:")
|
||||
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
|
||||
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
|
||||
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Use `glob` to find files by name pattern")
|
||||
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
|
||||
lines.append("</uploaded_files>")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -147,6 +214,13 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
|
||||
# Resolve uploads directory for existence checks
|
||||
thread_id = (runtime.context or {}).get("thread_id")
|
||||
if thread_id is None:
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
@@ -159,15 +233,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
for file_path in sorted(uploads_dir.iterdir()):
|
||||
if file_path.is_file() and file_path.name not in new_filenames:
|
||||
stat = file_path.stat()
|
||||
outline, preview = _extract_outline_for_file(file_path)
|
||||
historical_files.append(
|
||||
{
|
||||
"filename": file_path.name,
|
||||
"size": stat.st_size,
|
||||
"path": f"/mnt/user-data/uploads/{file_path.name}",
|
||||
"extension": file_path.suffix,
|
||||
"outline": outline,
|
||||
"outline_preview": preview,
|
||||
}
|
||||
)
|
||||
|
||||
# Attach outlines to new files as well
|
||||
if uploads_dir:
|
||||
for file in new_files:
|
||||
phys_path = uploads_dir / file["filename"]
|
||||
outline, preview = _extract_outline_for_file(phys_path)
|
||||
file["outline"] = outline
|
||||
file["outline_preview"] = preview
|
||||
|
||||
if not new_files and not historical_files:
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user