mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
Merge branch 'main' into rayhpeng/persistence-scaffold
# Conflicts: # backend/Dockerfile # backend/uv.lock
This commit is contained in:
@@ -21,6 +21,7 @@ class ConversationContext:
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
@@ -44,6 +45,7 @@ class MemoryUpdateQueue:
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
@@ -52,6 +54,7 @@ class MemoryUpdateQueue:
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
@@ -63,11 +66,13 @@ class MemoryUpdateQueue:
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
@@ -130,6 +135,7 @@ class MemoryUpdateQueue:
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
reinforcement_detected=context.reinforcement_detected,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
||||
@@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
|
||||
stripped = content.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
return stripped
|
||||
return stripped.casefold()
|
||||
|
||||
|
||||
class MemoryUpdater:
|
||||
@@ -272,6 +272,7 @@ class MemoryUpdater:
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
@@ -280,6 +281,7 @@ class MemoryUpdater:
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@@ -310,6 +312,14 @@ class MemoryUpdater:
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
if reinforcement_detected:
|
||||
reinforcement_hint = (
|
||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||
"Record the confirmed approach, style, or preference as a fact with category "
|
||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||
)
|
||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
@@ -441,6 +451,7 @@ def update_memory_from_conversation(
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
@@ -449,9 +460,10 @@ def update_memory_from_conversation(
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
|
||||
@@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = (
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
_REINFORCEMENT_PATTERNS = (
|
||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
@@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||
|
||||
Complements detect_correction() by identifying when the user confirms the
|
||||
agent's approach was correct. This allows the memory system to record what
|
||||
worked well, not just what went wrong.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale signals from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
@@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -15,6 +15,11 @@ class SubagentOverrideConfig(BaseModel):
|
||||
ge=1,
|
||||
description="Timeout in seconds for this subagent (None = use global default)",
|
||||
)
|
||||
max_turns: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
description="Maximum turns for this subagent (None = use global or builtin default)",
|
||||
)
|
||||
|
||||
|
||||
class SubagentsAppConfig(BaseModel):
|
||||
@@ -25,6 +30,11 @@ class SubagentsAppConfig(BaseModel):
|
||||
ge=1,
|
||||
description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
|
||||
)
|
||||
max_turns: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
description="Optional default max-turn override for all subagents (None = keep builtin defaults)",
|
||||
)
|
||||
agents: dict[str, SubagentOverrideConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Per-agent configuration overrides keyed by agent name",
|
||||
@@ -44,6 +54,15 @@ class SubagentsAppConfig(BaseModel):
|
||||
return override.timeout_seconds
|
||||
return self.timeout_seconds
|
||||
|
||||
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
|
||||
"""Get the effective max_turns for a specific agent."""
|
||||
override = self.agents.get(agent_name)
|
||||
if override is not None and override.max_turns is not None:
|
||||
return override.max_turns
|
||||
if self.max_turns is not None:
|
||||
return self.max_turns
|
||||
return builtin_default
|
||||
|
||||
|
||||
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
|
||||
|
||||
@@ -58,8 +77,26 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
|
||||
global _subagents_config
|
||||
_subagents_config = SubagentsAppConfig(**config_dict)
|
||||
|
||||
overrides_summary = {name: f"{override.timeout_seconds}s" for name, override in _subagents_config.agents.items() if override.timeout_seconds is not None}
|
||||
overrides_summary = {}
|
||||
for name, override in _subagents_config.agents.items():
|
||||
parts = []
|
||||
if override.timeout_seconds is not None:
|
||||
parts.append(f"timeout={override.timeout_seconds}s")
|
||||
if override.max_turns is not None:
|
||||
parts.append(f"max_turns={override.max_turns}")
|
||||
if parts:
|
||||
overrides_summary[name] = ", ".join(parts)
|
||||
|
||||
if overrides_summary:
|
||||
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, per-agent overrides={overrides_summary}")
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
overrides_summary,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, no per-agent overrides")
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
)
|
||||
|
||||
@@ -366,12 +366,17 @@ def _path_variants(path: str) -> set[str]:
|
||||
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
|
||||
|
||||
|
||||
def _path_separator_for_style(path: str) -> str:
|
||||
return "\\" if "\\" in path and "/" not in path else "/"
|
||||
|
||||
|
||||
def _join_path_preserving_style(base: str, relative: str) -> str:
|
||||
if not relative:
|
||||
return base
|
||||
if "/" in base and "\\" not in base:
|
||||
return f"{base.rstrip('/')}/{relative}"
|
||||
return str(Path(base) / relative)
|
||||
separator = _path_separator_for_style(base)
|
||||
normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
|
||||
stripped_base = base.rstrip("/\\")
|
||||
return f"{stripped_base}{separator}{normalized_relative}"
|
||||
|
||||
|
||||
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
|
||||
@@ -416,7 +421,10 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
|
||||
return actual_base
|
||||
if path.startswith(f"{virtual_base}/"):
|
||||
rest = path[len(virtual_base) :].lstrip("/")
|
||||
return _join_path_preserving_style(actual_base, rest)
|
||||
result = _join_path_preserving_style(actual_base, rest)
|
||||
if path.endswith("/") and not result.endswith(("/", "\\")):
|
||||
result += _path_separator_for_style(actual_base)
|
||||
return result
|
||||
|
||||
return path
|
||||
|
||||
@@ -801,7 +809,8 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
|
||||
return sandbox
|
||||
|
||||
|
||||
@@ -836,7 +845,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
if sandbox_id is not None:
|
||||
sandbox = get_sandbox_provider().get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
# Sandbox was released, fall through to acquire new one
|
||||
|
||||
@@ -858,7 +868,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
|
||||
|
||||
|
||||
@@ -43,5 +43,5 @@ You have access to the sandbox environment:
|
||||
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
|
||||
disallowed_tools=["task", "ask_clarification", "present_files"],
|
||||
model="inherit",
|
||||
max_turns=30,
|
||||
max_turns=60,
|
||||
)
|
||||
|
||||
@@ -44,5 +44,5 @@ You have access to the same sandbox environment as the parent agent:
|
||||
tools=None, # Inherit all tools from parent
|
||||
disallowed_tools=["task", "ask_clarification", "present_files"], # Prevent nesting and clarification
|
||||
model="inherit",
|
||||
max_turns=50,
|
||||
max_turns=100,
|
||||
)
|
||||
|
||||
@@ -28,9 +28,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
|
||||
|
||||
app_config = get_subagents_app_config()
|
||||
effective_timeout = app_config.get_timeout_for(name)
|
||||
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
|
||||
|
||||
overrides = {}
|
||||
if effective_timeout != config.timeout_seconds:
|
||||
logger.debug(f"Subagent '{name}': timeout overridden by config.yaml ({config.timeout_seconds}s -> {effective_timeout}s)")
|
||||
config = replace(config, timeout_seconds=effective_timeout)
|
||||
logger.debug(
|
||||
"Subagent '%s': timeout overridden by config.yaml (%ss -> %ss)",
|
||||
name,
|
||||
config.timeout_seconds,
|
||||
effective_timeout,
|
||||
)
|
||||
overrides["timeout_seconds"] = effective_timeout
|
||||
if effective_max_turns != config.max_turns:
|
||||
logger.debug(
|
||||
"Subagent '%s': max_turns overridden by config.yaml (%s -> %s)",
|
||||
name,
|
||||
config.max_turns,
|
||||
effective_max_turns,
|
||||
)
|
||||
overrides["max_turns"] = effective_max_turns
|
||||
if overrides:
|
||||
config = replace(config, **overrides)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user