Merge branch 'main' into rayhpeng/persistence-scaffold

# Conflicts:
#	backend/Dockerfile
#	backend/uv.lock
This commit is contained in:
rayhpeng
2026-04-05 23:40:49 +08:00
37 changed files with 2277 additions and 458 deletions
@@ -21,6 +21,7 @@ class ConversationContext:
timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
class MemoryUpdateQueue:
@@ -44,6 +45,7 @@ class MemoryUpdateQueue:
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
@@ -52,6 +54,7 @@ class MemoryUpdateQueue:
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
if not config.enabled:
@@ -63,11 +66,13 @@ class MemoryUpdateQueue:
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
# Check if this thread already has a pending update
@@ -130,6 +135,7 @@ class MemoryUpdateQueue:
thread_id=context.thread_id,
agent_name=context.agent_name,
correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)
@@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
stripped = content.strip()
if not stripped:
return None
return stripped
return stripped.casefold()
class MemoryUpdater:
@@ -272,6 +272,7 @@ class MemoryUpdater:
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
@@ -280,6 +281,7 @@ class MemoryUpdater:
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if update was successful, False otherwise.
@@ -310,6 +312,14 @@ class MemoryUpdater:
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
@@ -441,6 +451,7 @@ def update_memory_from_conversation(
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation.
@@ -449,9 +460,10 @@ def update_memory_from_conversation(
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
@@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = (
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool:
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
return None
@@ -15,6 +15,11 @@ class SubagentOverrideConfig(BaseModel):
ge=1,
description="Timeout in seconds for this subagent (None = use global default)",
)
max_turns: int | None = Field(
default=None,
ge=1,
description="Maximum turns for this subagent (None = use global or builtin default)",
)
class SubagentsAppConfig(BaseModel):
@@ -25,6 +30,11 @@ class SubagentsAppConfig(BaseModel):
ge=1,
description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
)
max_turns: int | None = Field(
default=None,
ge=1,
description="Optional default max-turn override for all subagents (None = keep builtin defaults)",
)
agents: dict[str, SubagentOverrideConfig] = Field(
default_factory=dict,
description="Per-agent configuration overrides keyed by agent name",
@@ -44,6 +54,15 @@ class SubagentsAppConfig(BaseModel):
return override.timeout_seconds
return self.timeout_seconds
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
"""Get the effective max_turns for a specific agent."""
override = self.agents.get(agent_name)
if override is not None and override.max_turns is not None:
return override.max_turns
if self.max_turns is not None:
return self.max_turns
return builtin_default
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
@@ -58,8 +77,26 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {name: f"{override.timeout_seconds}s" for name, override in _subagents_config.agents.items() if override.timeout_seconds is not None}
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if parts:
overrides_summary[name] = ", ".join(parts)
if overrides_summary:
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, per-agent overrides={overrides_summary}")
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary,
)
else:
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, no per-agent overrides")
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -366,12 +366,17 @@ def _path_variants(path: str) -> set[str]:
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
def _path_separator_for_style(path: str) -> str:
return "\\" if "\\" in path and "/" not in path else "/"
def _join_path_preserving_style(base: str, relative: str) -> str:
if not relative:
return base
if "/" in base and "\\" not in base:
return f"{base.rstrip('/')}/{relative}"
return str(Path(base) / relative)
separator = _path_separator_for_style(base)
normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
stripped_base = base.rstrip("/\\")
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
@@ -416,7 +421,10 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
return actual_base
if path.startswith(f"{virtual_base}/"):
rest = path[len(virtual_base) :].lstrip("/")
return _join_path_preserving_style(actual_base, rest)
result = _join_path_preserving_style(actual_base, rest)
if path.endswith("/") and not result.endswith(("/", "\\")):
result += _path_separator_for_style(actual_base)
return result
return path
@@ -801,7 +809,8 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
return sandbox
@@ -836,7 +845,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id)
if sandbox is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
# Sandbox was released, fall through to acquire new one
@@ -858,7 +868,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
@@ -43,5 +43,5 @@ You have access to the sandbox environment:
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
disallowed_tools=["task", "ask_clarification", "present_files"],
model="inherit",
max_turns=30,
max_turns=60,
)
@@ -44,5 +44,5 @@ You have access to the same sandbox environment as the parent agent:
tools=None, # Inherit all tools from parent
disallowed_tools=["task", "ask_clarification", "present_files"], # Prevent nesting and clarification
model="inherit",
max_turns=50,
max_turns=100,
)
@@ -28,9 +28,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
app_config = get_subagents_app_config()
effective_timeout = app_config.get_timeout_for(name)
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
overrides = {}
if effective_timeout != config.timeout_seconds:
logger.debug(f"Subagent '{name}': timeout overridden by config.yaml ({config.timeout_seconds}s -> {effective_timeout}s)")
config = replace(config, timeout_seconds=effective_timeout)
logger.debug(
"Subagent '%s': timeout overridden by config.yaml (%ss -> %ss)",
name,
config.timeout_seconds,
effective_timeout,
)
overrides["timeout_seconds"] = effective_timeout
if effective_max_turns != config.max_turns:
logger.debug(
"Subagent '%s': max_turns overridden by config.yaml (%s -> %s)",
name,
config.max_turns,
effective_max_turns,
)
overrides["max_turns"] = effective_max_turns
if overrides:
config = replace(config, **overrides)
return config