Compare commits

..

5 Commits

Author SHA1 Message Date
Willem Jiang 11a362e5e5 Merge branch 'main' into rayhpeng/fix-run-manager-store-atomicity 2026-05-22 21:38:23 +08:00
rayhpeng 85402405ec clarify run creation rollback on cancellation 2026-05-22 18:10:53 +08:00
rayhpeng 43eb643910 fix run manager test cleanup await 2026-05-22 18:01:55 +08:00
rayhpeng f3e3a350ce fix run creation cancellation rollback 2026-05-22 17:58:28 +08:00
rayhpeng 0fae7c9cbb fix runtime run creation persistence atomicity 2026-05-22 11:10:06 +08:00
36 changed files with 285 additions and 1778 deletions
-35
View File
@@ -37,36 +37,11 @@ if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime import RunRecord
T = TypeVar("T") T = TypeVar("T")
async def _mark_latest_recovered_threads_error(
run_manager: RunManager,
thread_store: ThreadMetaStore,
recovered_runs: list[RunRecord],
) -> None:
"""Mark thread status as error only when its newest run was recovered."""
recovered_by_thread: dict[str, set[str]] = {}
for record in recovered_runs:
recovered_by_thread.setdefault(record.thread_id, set()).add(record.run_id)
for thread_id, recovered_run_ids in recovered_by_thread.items():
try:
latest_runs = await run_manager.list_by_thread(thread_id, user_id=None, limit=1)
except Exception:
logger.warning("Failed to find latest run for thread %s during run reconciliation", thread_id, exc_info=True)
continue
if not latest_runs or latest_runs[0].run_id not in recovered_run_ids:
continue
try:
await thread_store.update_status(thread_id, "error", user_id=None)
except Exception:
logger.warning("Failed to mark thread %s as error during run reconciliation", thread_id, exc_info=True)
def get_config() -> AppConfig: def get_config() -> AppConfig:
"""Return the freshest ``AppConfig`` for the current request. """Return the freshest ``AppConfig`` for the current request.
@@ -163,16 +138,6 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
# RunManager with store backing for persistence # RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store) app.state.run_manager = RunManager(store=app.state.run_store)
if getattr(config.database, "backend", None) == "sqlite":
from deerflow.utils.time import now_iso
# Startup-only recovery: clean shutdowns return no active rows and
# the thread-status update below becomes a no-op.
recovered_runs = await app.state.run_manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before=now_iso(),
)
await _mark_latest_recovered_threads_error(app.state.run_manager, app.state.thread_store, recovered_runs)
try: try:
yield yield
+2 -25
View File
@@ -66,14 +66,6 @@ class RunResponse(BaseModel):
multitask_strategy: str = "reject" multitask_strategy: str = "reject"
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
class ThreadTokenUsageModelBreakdown(BaseModel): class ThreadTokenUsageModelBreakdown(BaseModel):
@@ -119,14 +111,6 @@ def _record_to_response(record: RunRecord) -> RunResponse:
multitask_strategy=record.multitask_strategy, multitask_strategy=record.multitask_strategy,
created_at=record.created_at, created_at=record.created_at,
updated_at=record.updated_at, updated_at=record.updated_at,
total_input_tokens=record.total_input_tokens,
total_output_tokens=record.total_output_tokens,
total_tokens=record.total_tokens,
llm_call_count=record.llm_call_count,
lead_agent_tokens=record.lead_agent_tokens,
subagent_tokens=record.subagent_tokens,
middleware_tokens=record.middleware_tokens,
message_count=record.message_count,
) )
@@ -418,15 +402,8 @@ async def list_run_events(
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse) @router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True) @require_permission("threads", "read", owner_check=True)
async def thread_token_usage( async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
thread_id: str,
request: Request,
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
) -> ThreadTokenUsageResponse:
"""Thread-level token usage aggregation.""" """Thread-level token usage aggregation."""
run_store = get_run_store(request) run_store = get_run_store(request)
if include_active: agg = await run_store.aggregate_tokens_by_thread(thread_id)
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
else:
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return ThreadTokenUsageResponse(thread_id=thread_id, **agg) return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
@@ -15,7 +15,6 @@ to the end of the message list as before_model + add_messages reducer would do.
import json import json
import logging import logging
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import override from typing import override
@@ -110,10 +109,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
This normalizes model-bound causal order before provider serialization while This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged. preserving already-valid transcripts unchanged.
""" """
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque) tool_messages_by_id: dict[str, ToolMessage] = {}
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage): if isinstance(msg, ToolMessage):
tool_messages_by_id[msg.tool_call_id].append(msg) tool_messages_by_id.setdefault(msg.tool_call_id, msg)
tool_call_ids: set[str] = set() tool_call_ids: set[str] = set()
for msg in messages: for msg in messages:
@@ -125,6 +124,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
tool_call_ids.add(tc_id) tool_call_ids.add(tc_id)
patched: list = [] patched: list = []
consumed_tool_msg_ids: set[str] = set()
patch_count = 0 patch_count = 0
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for tc in self._message_tool_calls(msg): for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if not tc_id: if not tc_id or tc_id in consumed_tool_msg_ids:
continue continue
tool_msg_queue = tool_messages_by_id.get(tc_id) existing_tool_msg = tool_messages_by_id.get(tc_id)
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
if existing_tool_msg is not None: if existing_tool_msg is not None:
patched.append(existing_tool_msg) patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else: else:
patched.append( patched.append(
ToolMessage( ToolMessage(
@@ -152,6 +152,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error", status="error",
) )
) )
consumed_tool_msg_ids.add(tc_id)
patch_count += 1 patch_count += 1
if patched == messages: if patched == messages:
@@ -94,35 +94,25 @@ class RunRepository(RunStore):
created_at=None, created_at=None,
follow_up_to_run_id=None, follow_up_to_run_id=None,
): ):
"""Insert or update a run row.
``RunManager`` retries ``put`` after transient SQLite failures. Making
this operation idempotent prevents a successful-but-unacknowledged first
commit from turning the retry into a primary-key failure.
"""
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put") resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
now = datetime.now(UTC) now = datetime.now(UTC)
created = datetime.fromisoformat(created_at) if created_at else now row = RunRow(
values = { run_id=run_id,
"thread_id": thread_id, thread_id=thread_id,
"assistant_id": assistant_id, assistant_id=assistant_id,
"user_id": resolved_user_id, user_id=resolved_user_id,
"model_name": self._normalize_model_name(model_name), model_name=self._normalize_model_name(model_name),
"status": status, status=status,
"multitask_strategy": multitask_strategy, multitask_strategy=multitask_strategy,
"metadata_json": self._safe_json(metadata) or {}, metadata_json=self._safe_json(metadata) or {},
"kwargs_json": self._safe_json(kwargs) or {}, kwargs_json=self._safe_json(kwargs) or {},
"error": error, error=error,
"follow_up_to_run_id": follow_up_to_run_id, follow_up_to_run_id=follow_up_to_run_id,
"updated_at": now, created_at=datetime.fromisoformat(created_at) if created_at else now,
} updated_at=now,
)
async with self._sf() as session: async with self._sf() as session:
row = await session.get(RunRow, run_id) session.add(row)
if row is None:
session.add(RunRow(run_id=run_id, created_at=created, **values))
else:
for key, value in values.items():
setattr(row, key, value)
await session.commit() await session.commit()
async def get( async def get(
@@ -156,14 +146,13 @@ class RunRepository(RunStore):
result = await session.execute(stmt) result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()] return [self._row_to_dict(r) for r in result.scalars()]
async def update_status(self, run_id, status, *, error=None) -> bool: async def update_status(self, run_id, status, *, error=None):
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)} values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
if error is not None: if error is not None:
values["error"] = error values["error"] = error
async with self._sf() as session: async with self._sf() as session:
result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit() await session.commit()
return result.rowcount != 0
async def update_model_name(self, run_id, model_name): async def update_model_name(self, run_id, model_name):
async with self._sf() as session: async with self._sf() as session:
@@ -198,26 +187,6 @@ class RunRepository(RunStore):
result = await session.execute(stmt) result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()] return [self._row_to_dict(r) for r in result.scalars()]
async def list_inflight(self, *, before=None):
"""Return persisted active runs for startup recovery."""
if before is None:
before_dt = datetime.now(UTC)
elif isinstance(before, datetime):
before_dt = before
else:
before_dt = datetime.fromisoformat(before)
stmt = (
select(RunRow)
.where(
RunRow.status.in_(("pending", "running")),
RunRow.created_at <= before_dt,
)
.order_by(RunRow.created_at.asc())
)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_run_completion( async def update_run_completion(
self, self,
run_id: str, run_id: str,
@@ -234,11 +203,8 @@ class RunRepository(RunStore):
last_ai_message: str | None = None, last_ai_message: str | None = None,
first_human_message: str | None = None, first_human_message: str | None = None,
error: str | None = None, error: str | None = None,
) -> bool: ) -> None:
"""Update status + token usage + convenience fields on run completion. """Update status + token usage + convenience fields on run completion."""
Returns ``False`` when no run row matched the requested ``run_id``.
"""
values: dict[str, Any] = { values: dict[str, Any] = {
"status": status, "status": status,
"total_input_tokens": total_input_tokens, "total_input_tokens": total_input_tokens,
@@ -258,52 +224,12 @@ class RunRepository(RunStore):
if error is not None: if error is not None:
values["error"] = error values["error"] = error
async with self._sf() as session: async with self._sf() as session:
result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
return result.rowcount != 0
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Update token usage + convenience fields while a run is still active."""
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
optional_counters = {
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
for key, value in optional_counters.items():
if value is not None:
values[key] = value
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
await session.commit() await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query.""" """Aggregate token usage via a single SQL GROUP BY query."""
statuses = ("success", "error", "running") if include_active else ("success", "error") _completed = RunRow.status.in_(("success", "error"))
_completed = RunRow.status.in_(statuses)
_thread = RunRow.thread_id == thread_id _thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown") model_name = func.coalesce(RunRow.model_name, "unknown")
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time import time
from collections.abc import Awaitable, Callable, Mapping from collections.abc import Mapping
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
@@ -46,8 +46,6 @@ class RunJournal(BaseCallbackHandler):
*, *,
track_token_usage: bool = True, track_token_usage: bool = True,
flush_threshold: int = 20, flush_threshold: int = 20,
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
progress_flush_interval: float = 5.0,
): ):
super().__init__() super().__init__()
self.run_id = run_id self.run_id = run_id
@@ -55,16 +53,10 @@ class RunJournal(BaseCallbackHandler):
self._store = event_store self._store = event_store
self._track_tokens = track_token_usage self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold self._flush_threshold = flush_threshold
self._progress_reporter = progress_reporter
self._progress_flush_interval = progress_flush_interval
# Write buffer # Write buffer
self._buffer: list[dict] = [] self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set() self._pending_flush_tasks: set[asyncio.Task[None]] = set()
self._pending_progress_task: asyncio.Task[None] | None = None
self._pending_progress_delayed = False
self._progress_dirty = False
self._last_progress_flush = 0.0
# Token accumulators # Token accumulators
self._total_input_tokens = 0 self._total_input_tokens = 0
@@ -302,8 +294,6 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
if messages: if messages:
self._counted_message_llm_run_ids.add(str(run_id)) self._counted_message_llm_run_ids.add(str(run_id))
@@ -455,8 +445,6 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
def set_first_human_message(self, content: str) -> None: def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields.""" """Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None self._first_human_msg = content[:2000] if content else None
@@ -486,14 +474,6 @@ class RunJournal(BaseCallbackHandler):
"""Force flush remaining buffer. Called in worker's finally block.""" """Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks: if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._pending_progress_task is not None and not self._pending_progress_task.done():
if self._pending_progress_delayed:
self._pending_progress_task.cancel()
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
self._progress_dirty = False
self._pending_progress_delayed = False
break
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
while self._buffer: while self._buffer:
batch = self._buffer[: self._flush_threshold] batch = self._buffer[: self._flush_threshold]
@@ -504,57 +484,6 @@ class RunJournal(BaseCallbackHandler):
self._buffer = batch + self._buffer self._buffer = batch + self._buffer
raise raise
def _schedule_progress_flush(self) -> None:
"""Best-effort throttled progress snapshot for active run visibility."""
if self._progress_reporter is None:
return
now = time.monotonic()
elapsed = now - self._last_progress_flush
if elapsed < self._progress_flush_interval:
self._progress_dirty = True
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
return
if self._pending_progress_task is not None and not self._pending_progress_task.done():
self._progress_dirty = True
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
self._progress_dirty = False
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
def _schedule_delayed_progress_flush(self, delay: float) -> None:
if self._pending_progress_task is not None and not self._pending_progress_task.done():
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
delay = max(0.0, delay)
self._pending_progress_delayed = delay > 0
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
if self._progress_reporter is None:
return
if delay > 0:
self._pending_progress_delayed = True
await asyncio.sleep(delay)
self._pending_progress_delayed = False
dirty_before_write = self._progress_dirty
self._progress_dirty = False
snapshot_to_write = snapshot or self.get_completion_data()
try:
await self._progress_reporter(snapshot_to_write)
self._last_progress_flush = time.monotonic()
except Exception:
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
if dirty_before_write or self._progress_dirty:
self._progress_dirty = False
self._pending_progress_task = None
self._schedule_delayed_progress_flush(self._progress_flush_interval)
def get_completion_data(self) -> dict: def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion.""" """Return accumulated token and message data for run completion."""
return { return {
@@ -4,9 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import sqlite3
import uuid import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -19,57 +17,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_RETRYABLE_SQLITE_MESSAGES = (
"database is locked",
"database table is locked",
"database is busy",
)
_RETRYABLE_SQLITE_ERROR_CODES = {
sqlite3.SQLITE_BUSY,
sqlite3.SQLITE_LOCKED,
}
def _is_retryable_persistence_error(exc: BaseException) -> bool:
"""Return True for transient SQLite persistence failures.
SQLite lock contention normally surfaces through either sqlite3 exceptions
or SQLAlchemy wrappers. The short bounded retry here protects run status
finalization from transient writer pressure without hiding permanent
failures forever.
"""
pending: list[BaseException] = [exc]
seen: set[int] = set()
while pending:
current = pending.pop()
if id(current) in seen:
continue
seen.add(id(current))
message = str(current).lower()
if any(fragment in message for fragment in _RETRYABLE_SQLITE_MESSAGES):
return True
if isinstance(current, (sqlite3.OperationalError, sqlite3.DatabaseError)):
error_code = getattr(current, "sqlite_errorcode", None)
if error_code in _RETRYABLE_SQLITE_ERROR_CODES:
return True
for chained in (getattr(current, "orig", None), current.__cause__, current.__context__):
if isinstance(chained, BaseException):
pending.append(chained)
return False
@dataclass(frozen=True)
class PersistenceRetryPolicy:
"""Bounded retry policy for short run-store writes."""
max_attempts: int = 5
initial_delay: float = 0.05
max_delay: float = 1.0
backoff_factor: float = 2.0
@dataclass @dataclass
class RunRecord: class RunRecord:
@@ -91,16 +38,6 @@ class RunRecord:
error: str | None = None error: str | None = None
model_name: str | None = None model_name: str | None = None
store_only: bool = False store_only: bool = False
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
last_ai_message: str | None = None
first_human_message: str | None = None
class RunManager: class RunManager:
@@ -111,100 +48,41 @@ class RunManager:
that run history survives process restarts. that run history survives process restarts.
""" """
def __init__( def __init__(self, store: RunStore | None = None) -> None:
self,
store: RunStore | None = None,
*,
persistence_retry_policy: PersistenceRetryPolicy | None = None,
) -> None:
self._runs: dict[str, RunRecord] = {} self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._store = store self._store = store
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
@staticmethod async def _persist_new_run_to_store(self, record: RunRecord) -> None:
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: """Persist a newly created run record to the backing store.
return {
"thread_id": record.thread_id,
"assistant_id": record.assistant_id,
"status": record.status.value,
"multitask_strategy": record.multitask_strategy,
"metadata": record.metadata or {},
"kwargs": record.kwargs or {},
"error": error if error is not None else record.error,
"created_at": record.created_at,
"model_name": record.model_name,
}
async def _call_store_with_retry( Initial run creation is part of the run visibility boundary: callers
self, should not observe a run in memory unless its backing store row exists.
operation_name: str, Unlike follow-up status/model updates, failures are propagated so the
run_id: str, caller can treat creation as failed.
operation: Callable[[], Awaitable[Any]], """
) -> Any:
"""Run a short store operation with bounded retries for SQLite pressure."""
policy = self._persistence_retry_policy
attempt = 1
delay = policy.initial_delay
while True:
try:
return await operation()
except Exception as exc:
retryable = _is_retryable_persistence_error(exc)
if attempt >= policy.max_attempts or not retryable:
raise
logger.warning(
"Transient persistence failure during %s for run %s (attempt %d/%d); retrying",
operation_name,
run_id,
attempt,
policy.max_attempts,
exc_info=True,
)
if delay > 0:
await asyncio.sleep(delay)
delay = min(policy.max_delay, delay * policy.backoff_factor if delay else policy.initial_delay)
attempt += 1
async def _persist_snapshot_to_store(self, run_id: str, payload: dict[str, Any]) -> bool:
"""Best-effort persist a previously captured run snapshot."""
if self._store is None: if self._store is None:
return True return
try: await self._store.put(
await self._call_store_with_retry(
"put",
run_id,
lambda: self._store.put(run_id, **payload),
)
return True
except Exception:
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
return False
async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool:
"""Best-effort persist run record to backing store."""
return await self._persist_snapshot_to_store(
record.run_id, record.run_id,
self._store_put_payload(record, error=error), thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
multitask_strategy=record.multitask_strategy,
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
model_name=record.model_name,
) )
async def _persist_status(self, record: RunRecord, status: RunStatus, *, error: str | None = None) -> bool: async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
"""Best-effort persist a status transition to the backing store.""" """Best-effort persist a status transition to the backing store."""
if self._store is None: if self._store is None:
return True return
row_recovery_payload = self._store_put_payload(record, error=error)
try: try:
updated = await self._call_store_with_retry( await self._store.update_status(run_id, status.value, error=error)
"update_status",
record.run_id,
lambda: self._store.update_status(record.run_id, status.value, error=error),
)
if updated is False:
return await self._persist_snapshot_to_store(record.run_id, row_recovery_payload)
return True
except Exception: except Exception:
logger.warning("Failed to persist status update for run %s", record.run_id, exc_info=True) logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
return False
@staticmethod @staticmethod
def _record_from_store(row: dict[str, Any]) -> RunRecord: def _record_from_store(row: dict[str, Any]) -> RunRecord:
@@ -227,72 +105,15 @@ class RunManager:
error=row.get("error"), error=row.get("error"),
model_name=row.get("model_name"), model_name=row.get("model_name"),
store_only=True, store_only=True,
total_input_tokens=row.get("total_input_tokens") or 0,
total_output_tokens=row.get("total_output_tokens") or 0,
total_tokens=row.get("total_tokens") or 0,
llm_call_count=row.get("llm_call_count") or 0,
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
subagent_tokens=row.get("subagent_tokens") or 0,
middleware_tokens=row.get("middleware_tokens") or 0,
message_count=row.get("message_count") or 0,
last_ai_message=row.get("last_ai_message"),
first_human_message=row.get("first_human_message"),
) )
async def update_run_completion(self, run_id: str, **kwargs) -> None: async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store.""" """Persist token usage and completion data to the backing store."""
row_recovery_payload: dict[str, Any] | None = None if self._store is not None:
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
for key, value in kwargs.items():
if key == "status":
continue
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
row_recovery_payload = self._store_put_payload(record, error=kwargs.get("error"))
if self._store is None:
return
try:
updated = await self._call_store_with_retry(
"update_run_completion",
run_id,
lambda: self._store.update_run_completion(run_id, **kwargs),
)
if updated is False:
if row_recovery_payload is None:
logger.warning("Failed to recreate missing run %s for completion persistence", run_id)
return
if not await self._persist_snapshot_to_store(run_id, row_recovery_payload):
return
recovered = await self._call_store_with_retry(
"update_run_completion",
run_id,
lambda: self._store.update_run_completion(run_id, **kwargs),
)
if recovered is False:
logger.warning("Run completion update for %s affected no rows after row recreation", run_id)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
"""Persist a running token/message snapshot without changing status."""
should_persist = True
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
should_persist = record.status == RunStatus.running
if record is not None and should_persist:
for key, value in kwargs.items():
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if should_persist and self._store is not None:
try: try:
await self._store.update_run_progress(run_id, **kwargs) await self._store.update_run_completion(run_id, **kwargs)
except Exception: except Exception:
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True) logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def create( async def create(
self, self,
@@ -321,7 +142,16 @@ class RunManager:
) )
async with self._lock: async with self._lock:
self._runs[run_id] = record self._runs[run_id] = record
await self._persist_to_store(record) persisted = False
try:
await self._persist_new_run_to_store(record)
persisted = True
except Exception:
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
raise
finally:
if not persisted:
self._runs.pop(run_id, None)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -408,7 +238,7 @@ class RunManager:
record.updated_at = _now_iso() record.updated_at = _now_iso()
if error is not None: if error is not None:
record.error = error record.error = error
await self._persist_status(record, status, error=error) await self._persist_status(run_id, status, error=error)
logger.info("Run %s -> %s", run_id, status.value) logger.info("Run %s -> %s", run_id, status.value)
async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: async def _persist_model_name(self, run_id: str, model_name: str | None) -> None:
@@ -416,11 +246,7 @@ class RunManager:
if self._store is None: if self._store is None:
return return
try: try:
await self._call_store_with_retry( await self._store.update_model_name(run_id, model_name)
"update_model_name",
run_id,
lambda: self._store.update_model_name(run_id, model_name),
)
except Exception: except Exception:
logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True)
@@ -463,7 +289,7 @@ class RunManager:
record.task.cancel() record.task.cancel()
record.status = RunStatus.interrupted record.status = RunStatus.interrupted
record.updated_at = _now_iso() record.updated_at = _now_iso()
await self._persist_status(record, RunStatus.interrupted) await self._persist_status(run_id, RunStatus.interrupted)
logger.info("Run %s cancelled (action=%s)", run_id, action) logger.info("Run %s cancelled (action=%s)", run_id, action)
return True return True
@@ -491,7 +317,7 @@ class RunManager:
now = _now_iso() now = _now_iso()
_supported_strategies = ("reject", "interrupt", "rollback") _supported_strategies = ("reject", "interrupt", "rollback")
interrupted_records: list[RunRecord] = [] interrupted_run_ids: list[str] = []
async with self._lock: async with self._lock:
if multitask_strategy not in _supported_strategies: if multitask_strategy not in _supported_strategies:
@@ -503,16 +329,8 @@ class RunManager:
raise ConflictError(f"Thread {thread_id} already has an active run") raise ConflictError(f"Thread {thread_id} already has an active run")
if multitask_strategy in ("interrupt", "rollback") and inflight: if multitask_strategy in ("interrupt", "rollback") and inflight:
for r in inflight:
r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
interrupted_records.append(r)
logger.info( logger.info(
"Cancelled %d inflight run(s) on thread %s (strategy=%s)", "Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)",
len(inflight), len(inflight),
thread_id, thread_id,
multitask_strategy, multitask_strategy,
@@ -532,67 +350,32 @@ class RunManager:
model_name=model_name, model_name=model_name,
) )
self._runs[run_id] = record self._runs[run_id] = record
persisted = False
try:
await self._persist_new_run_to_store(record)
persisted = True
except Exception:
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
raise
finally:
if not persisted:
self._runs.pop(run_id, None)
for interrupted_record in interrupted_records: if multitask_strategy in ("interrupt", "rollback") and inflight:
await self._persist_status(interrupted_record, RunStatus.interrupted) for r in inflight:
await self._persist_to_store(record) r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
interrupted_run_ids.append(r.run_id)
for interrupted_run_id in interrupted_run_ids:
await self._persist_status(interrupted_run_id, RunStatus.interrupted)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
async def reconcile_orphaned_inflight_runs(
self,
*,
error: str,
before: str | None = None,
) -> list[RunRecord]:
"""Mark persisted active runs as failed when no local task owns them.
Gateway runs are process-local: the asyncio task and abort event live in
memory, while the run row is durable. After a SQLite-backed gateway
restart, any persisted ``pending`` or ``running`` row created before
startup cannot still have a local worker. This recovery step turns that
ambiguous state into an explicit error instead of letting the UI show an
indefinite active run.
"""
if self._store is None:
return []
try:
rows = await self._call_store_with_retry(
"list_inflight",
"*",
lambda: self._store.list_inflight(before=before),
)
except Exception:
logger.warning("Failed to list orphaned inflight runs for reconciliation", exc_info=True)
return []
recovered: list[RunRecord] = []
now = _now_iso()
for row in rows:
try:
record = self._record_from_store(row)
except Exception:
logger.warning("Failed to map orphaned run row during reconciliation", exc_info=True)
continue
async with self._lock:
live_record = self._runs.get(record.run_id)
if live_record is not None and live_record.status in (RunStatus.pending, RunStatus.running):
continue
record.status = RunStatus.error
record.error = error
record.updated_at = now
persisted = await self._persist_status(record, RunStatus.error, error=error)
if not persisted:
logger.warning("Skipped orphaned run %s recovery because error status was not persisted", record.run_id)
continue
recovered.append(record)
if recovered:
logger.warning("Recovered %d orphaned inflight run(s) as error", len(recovered))
return recovered
async def has_inflight(self, thread_id: str) -> bool: async def has_inflight(self, thread_id: str) -> bool:
"""Return ``True`` if *thread_id* has a pending or running run.""" """Return ``True`` if *thread_id* has a pending or running run."""
async with self._lock: async with self._lock:
@@ -59,12 +59,7 @@ class RunStore(abc.ABC):
status: str, status: str,
*, *,
error: str | None = None, error: str | None = None,
) -> bool | None: ) -> None:
"""Update a run status.
Returns ``False`` when the store can prove no row was updated. Older or
lightweight stores may return ``None`` when they cannot report rowcount.
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
@@ -97,42 +92,15 @@ class RunStore(abc.ABC):
last_ai_message: str | None = None, last_ai_message: str | None = None,
first_human_message: str | None = None, first_human_message: str | None = None,
error: str | None = None, error: str | None = None,
) -> bool | None:
"""Persist final completion fields.
Returns ``False`` when the store can prove no row was updated.
"""
pass
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None: ) -> None:
"""Persist a best-effort running snapshot without changing run status.""" pass
return None
@abc.abstractmethod @abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
async def list_inflight(self, *, before: str | None = None) -> list[dict[str, Any]]: async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
"""Return persisted runs that are still ``pending`` or ``running``."""
pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread. """Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens, Returns a dict with keys: total_tokens, total_input_tokens,
@@ -65,8 +65,6 @@ class MemoryRunStore(RunStore):
if error is not None: if error is not None:
self._runs[run_id]["error"] = error self._runs[run_id]["error"] = error
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_model_name(self, run_id, model_name): async def update_model_name(self, run_id, model_name):
if run_id in self._runs: if run_id in self._runs:
@@ -83,15 +81,6 @@ class MemoryRunStore(RunStore):
if value is not None: if value is not None:
self._runs[run_id][key] = value self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None): async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat() now = before or datetime.now(UTC).isoformat()
@@ -99,15 +88,8 @@ class MemoryRunStore(RunStore):
results.sort(key=lambda r: r["created_at"]) results.sort(key=lambda r: r["created_at"])
return results return results
async def list_inflight(self, *, before=None): async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
now = before or datetime.now(UTC).isoformat() completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
results = [r for r in self._runs.values() if r["status"] in ("pending", "running") and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
by_model: dict[str, dict] = {} by_model: dict[str, dict] = {}
for r in completed: for r in completed:
model = r.get("model_name") or "unknown" model = r.get("model_name") or "unknown"
@@ -153,6 +153,8 @@ async def run_agent(
journal = None journal = None
journal = None
# Track whether "events" was requested but skipped # Track whether "events" was requested but skipped
if "events" in requested_modes: if "events" in requested_modes:
logger.info( logger.info(
@@ -175,7 +177,6 @@ async def run_agent(
thread_id=thread_id, thread_id=thread_id,
event_store=event_store, event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True), track_token_usage=getattr(run_events_config, "track_token_usage", True),
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
) )
# 1. Mark running # 1. Mark running
@@ -218,70 +218,6 @@ class TestBuildPatchedMessagesPatching:
assert mw._build_patched_messages(msgs) is None assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_tool_msg("web_search:11", "web_search"),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "web_search:11"
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
result = _tool_msg("web_search:11", "web_search")
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
result,
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1] is result
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware() mw = DanglingToolCallMiddleware()
msgs = [ msgs = [
-127
View File
@@ -1,127 +0,0 @@
"""Gateway startup recovery for stale persisted runs."""
from __future__ import annotations
from contextlib import asynccontextmanager
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
import deerflow.runtime as runtime_module
from app.gateway import deps as gateway_deps
from deerflow.persistence import engine as engine_module
from deerflow.persistence import thread_meta as thread_meta_module
from deerflow.runtime.checkpointer import async_provider as checkpointer_module
from deerflow.runtime.events import store as event_store_module
@asynccontextmanager
async def _fake_context(value):
yield value
class _FakeRunManager:
"""RunManager double that records startup reconciliation calls."""
instances: list[_FakeRunManager] = []
recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
latest_by_thread: dict[str, list[SimpleNamespace]] = {}
def __init__(self, *, store):
self.store = store
self.reconcile_calls: list[dict] = []
self.list_by_thread_calls: list[dict] = []
_FakeRunManager.instances.append(self)
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
self.reconcile_calls.append({"error": error, "before": before})
return self.recovered_runs
async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100):
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
class _FakeThreadStore:
def __init__(self) -> None:
self.status_updates: list[tuple[str, str, str | None]] = []
async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None:
self.status_updates.append((thread_id, status, user_id))
@pytest.mark.anyio
async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch):
"""SQLite startup should recover stale active runs before serving requests."""
app = FastAPI()
config = SimpleNamespace(
database=SimpleNamespace(backend="sqlite"),
run_events=SimpleNamespace(backend="memory"),
)
thread_store = _FakeThreadStore()
_FakeRunManager.instances.clear()
_FakeRunManager.recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
_FakeRunManager.latest_by_thread = {}
async def fake_init_engine_from_config(_database):
return None
async def fake_close_engine():
return None
monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)
async with gateway_deps.langgraph_runtime(app, config):
pass
assert len(_FakeRunManager.instances) == 1
assert _FakeRunManager.instances[0].reconcile_calls
assert _FakeRunManager.instances[0].reconcile_calls[0]["error"]
assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
assert thread_store.status_updates == [("thread-1", "error", None)]
@pytest.mark.anyio
async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch):
"""Startup recovery should not let an old orphaned run overwrite a newer terminal thread state."""
app = FastAPI()
config = SimpleNamespace(
database=SimpleNamespace(backend="sqlite"),
run_events=SimpleNamespace(backend="memory"),
)
thread_store = _FakeThreadStore()
_FakeRunManager.instances.clear()
_FakeRunManager.recovered_runs = [SimpleNamespace(run_id="old-running", thread_id="thread-1")]
_FakeRunManager.latest_by_thread = {"thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")]}
async def fake_init_engine_from_config(_database):
return None
async def fake_close_engine():
return None
monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)
async with gateway_deps.langgraph_runtime(app, config):
pass
assert len(_FakeRunManager.instances) == 1
assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
assert thread_store.status_updates == []
-104
View File
@@ -714,110 +714,6 @@ class TestExternalUsageRecords:
assert j._subagent_tokens == 0 assert j._subagent_tokens == 0
class TestProgressSnapshots:
@pytest.mark.anyio
async def test_on_llm_end_reports_progress_snapshot(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0,
)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
await j.flush()
assert snapshots
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["message_count"] == 1
assert snapshots[-1]["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
snapshots: list[dict] = []
trailing_seen = asyncio.Event()
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
if snapshot["total_tokens"] == 45:
trailing_seen.set()
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0.01,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
await j.flush()
assert len(snapshots) >= 2
assert snapshots[-1]["total_tokens"] == 45
assert snapshots[-1]["llm_call_count"] == 2
assert snapshots[-1]["last_ai_message"] == "Second"
@pytest.mark.anyio
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=10.0,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.sleep(0)
assert snapshots[-1]["total_tokens"] == 15
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(j.flush(), timeout=0.2)
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["last_ai_message"] == "First"
class TestChatModelStartHumanMessage: class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message.""" """Tests for on_chat_model_start extracting the first human message."""
+122 -240
View File
@@ -1,15 +1,11 @@
"""Tests for RunManager.""" """Tests for RunManager."""
import logging import asyncio
import re import re
import sqlite3
from typing import Any
import pytest import pytest
from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError
from deerflow.runtime import DisconnectMode, RunManager, RunStatus from deerflow.runtime import DisconnectMode, RunManager, RunStatus
from deerflow.runtime.runs.manager import PersistenceRetryPolicy
from deerflow.runtime.runs.store.memory import MemoryRunStore from deerflow.runtime.runs.store.memory import MemoryRunStore
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -20,92 +16,6 @@ def manager() -> RunManager:
return RunManager() return RunManager()
class FlakyStatusRunStore(MemoryRunStore):
"""Memory run store that simulates transient SQLite status-write failures."""
def __init__(self, *, status_failures: int) -> None:
super().__init__()
self.status_failures = status_failures
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
if self.status_failures > 0:
self.status_failures -= 1
raise sqlite3.OperationalError("database is locked")
return await super().update_status(run_id, status, error=error)
class MissingRowStatusRunStore(MemoryRunStore):
"""Memory run store that reports a missing row for status updates."""
async def update_status(self, run_id, status, *, error=None):
await super().update_status(run_id, status, error=error)
return False
class PermanentStatusRunStore(MemoryRunStore):
"""Memory run store that simulates a permanent SQLAlchemy write failure."""
def __init__(self) -> None:
super().__init__()
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
raise SQLAlchemyDatabaseError(
"UPDATE runs SET status = :status WHERE run_id = :run_id",
{"status": status, "run_id": run_id},
sqlite3.DatabaseError("no such table: runs"),
)
class FailingStatusRunStore(MemoryRunStore):
"""Memory run store that always fails status updates."""
def __init__(self) -> None:
super().__init__()
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
raise sqlite3.OperationalError("database is locked")
class MissingCompletionRunStore(MemoryRunStore):
"""Memory run store that reports one missing row for completion updates."""
def __init__(self) -> None:
super().__init__()
self.completion_update_attempts = 0
async def update_run_completion(self, run_id, *, status, **kwargs):
self.completion_update_attempts += 1
if self.completion_update_attempts == 1:
return False
return await super().update_run_completion(run_id, status=status, **kwargs)
class AlwaysMissingCompletionRunStore(MemoryRunStore):
"""Memory run store that keeps reporting missing rows for completion updates."""
def __init__(self) -> None:
super().__init__()
self.completion_update_attempts = 0
async def update_run_completion(self, run_id, *, status, **kwargs):
self.completion_update_attempts += 1
return False
async def _stored_statuses(store: MemoryRunStore, *run_ids: str) -> dict[str, Any]:
rows = {}
for run_id in run_ids:
row = await store.get(run_id)
rows[run_id] = row["status"] if row else None
return rows
@pytest.mark.anyio @pytest.mark.anyio
async def test_create_and_get(manager: RunManager): async def test_create_and_get(manager: RunManager):
"""Created run should be retrievable with new fields.""" """Created run should be retrievable with new fields."""
@@ -171,155 +81,6 @@ async def test_cancel_persists_interrupted_status_to_store():
assert stored["status"] == "interrupted" assert stored["status"] == "interrupted"
@pytest.mark.anyio
async def test_status_persistence_retries_transient_sqlite_lock():
"""Transient SQLite lock errors should not leave a final status stale."""
store = FlakyStatusRunStore(status_failures=2)
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
await manager.set_status(record.run_id, RunStatus.success)
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "success"
assert store.status_update_attempts >= 4
@pytest.mark.anyio
async def test_status_persistence_recreates_missing_store_row():
"""A final status update should recreate a run row if initial persistence was lost."""
store = MissingRowStatusRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await store.delete(record.run_id)
await manager.set_status(record.run_id, RunStatus.error, error="boom")
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "error"
assert stored["error"] == "boom"
@pytest.mark.anyio
async def test_status_persistence_does_not_retry_permanent_sqlalchemy_errors():
"""Permanent SQLAlchemy failures should not be retried as SQLite pressure."""
store = PermanentStatusRunStore()
manager = RunManager(
store=store,
persistence_retry_policy=PersistenceRetryPolicy(max_attempts=5, initial_delay=0),
)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.error, error="boom")
assert store.status_update_attempts == 1
@pytest.mark.anyio
async def test_completion_persistence_recreates_missing_store_row():
"""Completion updates should recreate a missing row and persist final counters."""
store = MissingCompletionRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
await manager.set_status(record.run_id, RunStatus.success)
await store.delete(record.run_id)
await manager.update_run_completion(
record.run_id,
status="success",
total_tokens=42,
llm_call_count=2,
last_ai_message="done",
)
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "success"
assert stored["total_tokens"] == 42
assert stored["llm_call_count"] == 2
assert stored["last_ai_message"] == "done"
assert store.completion_update_attempts == 2
@pytest.mark.anyio
async def test_completion_persistence_warns_when_recreated_row_still_missing(caplog):
"""A second zero-row completion update after recreation should not be silent."""
store = AlwaysMissingCompletionRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.success)
caplog.set_level(logging.WARNING, logger="deerflow.runtime.runs.manager")
await manager.update_run_completion(record.run_id, status="success", total_tokens=42)
assert store.completion_update_attempts == 2
assert "affected no rows after row recreation" in caplog.text
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_marks_stale_rows_error():
"""Startup recovery should turn persisted active rows into explicit errors."""
store = MemoryRunStore()
await store.put("pending-run", thread_id="thread-1", status="pending", created_at="2026-01-01T00:00:00+00:00")
await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:01+00:00")
await store.put("success-run", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:02+00:00")
manager = RunManager(store=store)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before="2026-01-01T00:00:02+00:00",
)
assert {record.run_id for record in recovered} == {"pending-run", "running-run"}
assert await _stored_statuses(store, "pending-run", "running-run", "success-run") == {
"pending-run": "error",
"running-run": "error",
"success-run": "success",
}
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_live_local_run():
"""Startup recovery should not mark an active row orphaned when this worker owns it."""
store = MemoryRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
)
stored = await store.get(record.run_id)
assert recovered == []
assert stored["status"] == "running"
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_rows_when_error_status_is_not_persisted():
"""Startup recovery must not report a row as recovered if the error update failed."""
store = FailingStatusRunStore()
await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:00+00:00")
manager = RunManager(
store=store,
persistence_retry_policy=PersistenceRetryPolicy(max_attempts=2, initial_delay=0),
)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before="2026-01-01T00:00:01+00:00",
)
stored = await store.get("running-run")
assert recovered == []
assert stored["status"] == "running"
assert store.status_update_attempts == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_cancel_not_inflight(manager: RunManager): async def test_cancel_not_inflight(manager: RunManager):
"""Cancelling a completed run should return False.""" """Cancelling a completed run should return False."""
@@ -471,6 +232,81 @@ async def test_create_record_is_not_store_only(manager: RunManager):
assert record.store_only is False assert record.store_only is False
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_failure():
"""create() must fail and hide the run when the initial store write fails."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
store.put = AsyncMock(side_effect=RuntimeError("db down"))
manager = RunManager(store=store)
with pytest.raises(RuntimeError, match="db down"):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_cancellation():
"""create() must also roll back when cancelled during the initial store write."""
store = MemoryRunStore()
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
manager = RunManager(store=store)
with pytest.raises(asyncio.CancelledError):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_does_not_expose_run_until_store_persist_completes():
"""Concurrent readers must wait until the new run has been persisted."""
store = MemoryRunStore()
manager = RunManager(store=store)
original_put = store.put
put_started = asyncio.Event()
allow_put = asyncio.Event()
async def blocking_put(run_id, **kwargs):
put_started.set()
await allow_put.wait()
return await original_put(run_id, **kwargs)
store.put = blocking_put
create_task = asyncio.create_task(manager.create("thread-1"))
list_task = None
try:
await put_started.wait()
list_task = asyncio.create_task(manager.list_by_thread("thread-1"))
await asyncio.sleep(0)
assert not list_task.done()
allow_put.set()
record = await create_task
runs = await list_task
assert [run.run_id for run in runs] == [record.run_id]
finally:
allow_put.set()
cleanup_tasks = []
for task in (list_task, create_task):
if task is None:
continue
if not task.done():
task.cancel()
cleanup_tasks.append(task)
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_prefers_in_memory_record_over_store(): async def test_get_prefers_in_memory_record_over_store():
"""In-memory records retain task/control state when store has same run.""" """In-memory records retain task/control state when store has same run."""
@@ -558,6 +394,52 @@ async def test_create_or_reject_interrupt_persists_interrupted_status_to_store()
assert stored_old["status"] == "interrupted" assert stored_old["status"] == "interrupted"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_fails():
"""A failed new-run persist must not cancel the existing inflight run."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
store.put = AsyncMock(side_effect=RuntimeError("db down"))
with pytest.raises(RuntimeError, match="db down"):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled():
"""Cancellation during new-run persist must not cancel the existing run."""
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
with pytest.raises(asyncio.CancelledError):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio @pytest.mark.anyio
async def test_create_or_reject_rollback_persists_interrupted_status_to_store(): async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
"""rollback strategy should persist interrupted status for old runs.""" """rollback strategy should persist interrupted status for old runs."""
+2 -169
View File
@@ -10,7 +10,6 @@ from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository from deerflow.persistence.run import RunRepository
from deerflow.runtime import RunManager, RunStatus from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.base import RunStore
async def _make_repo(tmp_path): async def _make_repo(tmp_path):
@@ -27,45 +26,6 @@ async def _cleanup():
await close_engine() await close_engine()
class _CustomRunStoreWithoutProgress(RunStore):
async def put(self, *args, **kwargs):
return None
async def get(self, *args, **kwargs):
return None
async def list_by_thread(self, *args, **kwargs):
return []
async def update_status(self, *args, **kwargs):
return None
async def delete(self, *args, **kwargs):
return None
async def update_model_name(self, *args, **kwargs):
return None
async def update_run_completion(self, *args, **kwargs):
return None
async def list_pending(self, *args, **kwargs):
return []
async def list_inflight(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
store = _CustomRunStoreWithoutProgress()
await store.update_run_progress("r1", total_tokens=1)
class TestRunRepository: class TestRunRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_and_get(self, tmp_path): async def test_put_and_get(self, tmp_path):
@@ -78,19 +38,6 @@ class TestRunRepository:
assert row["status"] == "pending" assert row["status"] == "pending"
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_put_is_idempotent_for_retried_writes(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", assistant_id="old-agent", status="pending")
await repo.put("r1", thread_id="t1", assistant_id="new-agent", status="running", error="retry")
row = await repo.get("r1")
assert row["assistant_id"] == "new-agent"
assert row["status"] == "running"
assert row["error"] == "retry"
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_missing_returns_none(self, tmp_path): async def test_get_missing_returns_none(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -101,19 +48,11 @@ class TestRunRepository:
async def test_update_status(self, tmp_path): async def test_update_status(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1") await repo.put("r1", thread_id="t1")
updated = await repo.update_status("r1", "running") await repo.update_status("r1", "running")
row = await repo.get("r1") row = await repo.get("r1")
assert updated is True
assert row["status"] == "running" assert row["status"] == "running"
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_status_returns_false_for_missing_row(self, tmp_path):
repo = await _make_repo(tmp_path)
updated = await repo.update_status("missing", "error", error="lost")
assert updated is False
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_status_with_error(self, tmp_path): async def test_update_status_with_error(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -170,24 +109,11 @@ class TestRunRepository:
assert all(r["status"] == "pending" for r in pending) assert all(r["status"] == "pending" for r in pending)
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_list_inflight_returns_pending_and_running_before_cutoff(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("pending-old", thread_id="t1", status="pending", created_at="2026-01-01T00:00:00+00:00")
await repo.put("running-old", thread_id="t1", status="running", created_at="2026-01-01T00:00:01+00:00")
await repo.put("success-old", thread_id="t1", status="success", created_at="2026-01-01T00:00:02+00:00")
await repo.put("pending-new", thread_id="t1", status="pending", created_at="2026-01-01T00:00:03+00:00")
inflight = await repo.list_inflight(before="2026-01-01T00:00:02+00:00")
assert [row["run_id"] for row in inflight] == ["pending-old", "running-old"]
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_run_completion(self, tmp_path): async def test_update_run_completion(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running") await repo.put("r1", thread_id="t1", status="running")
updated = await repo.update_run_completion( await repo.update_run_completion(
"r1", "r1",
status="success", status="success",
total_input_tokens=100, total_input_tokens=100,
@@ -202,7 +128,6 @@ class TestRunRepository:
first_human_message="What is the meaning?", first_human_message="What is the meaning?",
) )
row = await repo.get("r1") row = await repo.get("r1")
assert updated is True
assert row["status"] == "success" assert row["status"] == "success"
assert row["total_tokens"] == 150 assert row["total_tokens"] == 150
assert row["llm_call_count"] == 2 assert row["llm_call_count"] == 2
@@ -212,13 +137,6 @@ class TestRunRepository:
assert row["first_human_message"] == "What is the meaning?" assert row["first_human_message"] == "What is the meaning?"
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion_returns_false_for_missing_row(self, tmp_path):
repo = await _make_repo(tmp_path)
updated = await repo.update_run_completion("missing", status="error", total_tokens=1)
assert updated is False
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_metadata_preserved(self, tmp_path): async def test_metadata_preserved(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -252,69 +170,6 @@ class TestRunRepository:
assert row["total_tokens"] == 100 assert row["total_tokens"] == 100
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
message_count=2,
last_ai_message="partial answer",
)
row = await repo.get("r1")
assert row["status"] == "running"
assert row["total_tokens"] == 50
assert row["llm_call_count"] == 1
assert row["message_count"] == 2
assert row["last_ai_message"] == "partial answer"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
lead_agent_tokens=30,
subagent_tokens=20,
message_count=2,
)
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
row = await repo.get("r1")
assert row["total_input_tokens"] == 40
assert row["total_output_tokens"] == 10
assert row["total_tokens"] == 60
assert row["llm_call_count"] == 1
assert row["lead_agent_tokens"] == 30
assert row["subagent_tokens"] == 20
assert row["message_count"] == 2
assert row["last_ai_message"] == "updated"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
row = await repo.get("r1")
assert row["status"] == "success"
assert row["total_tokens"] == 100
assert row["llm_call_count"] == 1
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path): async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -370,28 +225,6 @@ class TestRunRepository:
} }
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
without_active = await repo.aggregate_tokens_by_thread("t1")
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
assert without_active["total_tokens"] == 100
assert without_active["total_runs"] == 1
assert with_active["total_tokens"] == 125
assert with_active["total_runs"] == 2
assert with_active["by_caller"] == {
"lead_agent": 120,
"subagent": 5,
"middleware": 0,
}
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path): async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first.""" """list_by_thread returns newest first."""
-27
View File
@@ -53,30 +53,3 @@ def test_thread_token_usage_returns_stable_shape():
}, },
} }
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1") run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
def test_thread_token_usage_can_include_active_runs():
run_store = MagicMock()
run_store.aggregate_tokens_by_thread = AsyncMock(
return_value={
"total_tokens": 175,
"total_input_tokens": 120,
"total_output_tokens": 55,
"total_runs": 3,
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
"by_caller": {
"lead_agent": 145,
"subagent": 25,
"middleware": 5,
},
},
)
app = _make_app(run_store)
with TestClient(app) as client:
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
assert response.status_code == 200
assert response.json()["total_tokens"] == 175
assert response.json()["total_runs"] == 3
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
-4
View File
@@ -18,7 +18,3 @@ lint:
format: format:
pnpm format:write pnpm format:write
build-static:
NEXT_CONFIG_BUILD_OUTPUT=standalone SKIP_ENV_VALIDATION=1 NEXT_PUBLIC_STATIC_WEBSITE_ONLY=true pnpm build
@if [ -d .next/static ]; then mkdir -p .next/standalone/.next && cp -R .next/static .next/standalone/.next/static; fi
-4
View File
@@ -16,10 +16,6 @@ const withNextra = nextra({});
/** @type {import("next").NextConfig} */ /** @type {import("next").NextConfig} */
const config = { const config = {
output:
process.env.NEXT_CONFIG_BUILD_OUTPUT === "standalone"
? "standalone"
: undefined,
i18n: { i18n: {
locales: ["en", "zh"], locales: ["en", "zh"],
defaultLocale: "en", defaultLocale: "en",
@@ -32,7 +32,7 @@ Even with digital Leicas, photographers often emulate film characteristics: natu
### Image 1: Parisian Decisive Moment ### Image 1: Parisian Decisive Moment
![Paris Decisive Moment](/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-paris-decisive-moment.jpg) ![Paris Decisive Moment](/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-paris-decisive-moment.jpg)
This image captures the essence of Cartier-Bresson's philosophy. A woman in a red coat leaps over a puddle while a cyclist passes in perfect synchrony. The composition follows the rule of thirds, with the subject positioned at the intersection of grid lines. Shot with a simulated Leica M11 and 35mm Summicron lens at f/2.8, the image features shallow depth of field, natural film grain, and the warm, muted color palette characteristic of Leica photography. This image captures the essence of Cartier-Bresson's philosophy. A woman in a red coat leaps over a puddle while a cyclist passes in perfect synchrony. The composition follows the rule of thirds, with the subject positioned at the intersection of grid lines. Shot with a simulated Leica M11 and 35mm Summicron lens at f/2.8, the image features shallow depth of field, natural film grain, and the warm, muted color palette characteristic of Leica photography.
@@ -40,7 +40,7 @@ The "decisive moment" here isn't just about timing—it's about the alignment of
### Image 2: Tokyo Night Reflections ### Image 2: Tokyo Night Reflections
![Tokyo Night Scene](/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-tokyo-night.jpg) ![Tokyo Night Scene](/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-tokyo-night.jpg)
Moving to Shinjuku, Tokyo, this image explores the atmospheric possibilities of Leica's legendary Noctilux lens. Simulating a Leica M10-P with a 50mm f/0.95 Noctilux wide open, the image creates extremely shallow depth of field with beautiful bokeh balls from neon signs reflected in wet pavement. Moving to Shinjuku, Tokyo, this image explores the atmospheric possibilities of Leica's legendary Noctilux lens. Simulating a Leica M10-P with a 50mm f/0.95 Noctilux wide open, the image creates extremely shallow depth of field with beautiful bokeh balls from neon signs reflected in wet pavement.
@@ -48,7 +48,7 @@ A salaryman waits under glowing kanji signs, steam rising from a nearby ramen sh
### Image 3: New York City Candid ### Image 3: New York City Candid
![NYC Candid Scene](/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-nyc-candid.jpg) ![NYC Candid Scene](/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-nyc-candid.jpg)
This Chinatown scene demonstrates the documentary power of Leica's Q2 camera with its fixed 28mm Summilux lens. The wide angle captures environmental context while maintaining intimate proximity to the subjects. A fishmonger hands a live fish to a customer while tourists photograph the scene—a moment of cultural contrast and authentic urban life. This Chinatown scene demonstrates the documentary power of Leica's Q2 camera with its fixed 28mm Summilux lens. The wide angle captures environmental context while maintaining intimate proximity to the subjects. A fishmonger hands a live fish to a customer while tourists photograph the scene—a moment of cultural contrast and authentic urban life.
@@ -1,19 +1,19 @@
import { isStaticWebsiteOnly } from "@/core/static-mode"; "use client";
import { DEMO_THREAD_IDS } from "@/core/threads/static-demo";
import { ChatProviders } from "./providers"; import { PromptInputProvider } from "@/components/ai-elements/prompt-input";
import { ArtifactsProvider } from "@/components/workspace/artifacts";
export function generateStaticParams() { import { SubtasksProvider } from "@/core/tasks/context";
if (!isStaticWebsiteOnly()) {
return [];
}
return DEMO_THREAD_IDS.map((thread_id) => ({ thread_id }));
}
export default function ChatLayout({ export default function ChatLayout({
children, children,
}: { }: {
children: React.ReactNode; children: React.ReactNode;
}) { }) {
return <ChatProviders>{children}</ChatProviders>; return (
<SubtasksProvider>
<ArtifactsProvider>
<PromptInputProvider>{children}</PromptInputProvider>
</ArtifactsProvider>
</SubtasksProvider>
);
} }
@@ -227,7 +227,6 @@ export default function ChatPage() {
isWelcomeMode && <Welcome mode={settings.context.mode} /> isWelcomeMode && <Welcome mode={settings.context.mode} />
} }
disabled={ disabled={
isMock ||
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" || env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
isUploading isUploading
} }
@@ -1,15 +0,0 @@
"use client";
import { PromptInputProvider } from "@/components/ai-elements/prompt-input";
import { ArtifactsProvider } from "@/components/workspace/artifacts";
import { SubtasksProvider } from "@/core/tasks/context";
export function ChatProviders({ children }: { children: React.ReactNode }) {
return (
<SubtasksProvider>
<ArtifactsProvider>
<PromptInputProvider>{children}</PromptInputProvider>
</ArtifactsProvider>
</SubtasksProvider>
);
}
+6 -8
View File
@@ -43,14 +43,12 @@ export default async function WorkspaceLayout({
> >
Retry Retry
</Link> </Link>
<form action="/api/v1/auth/logout" method="post"> <Link
<button href="/api/v1/auth/logout"
type="submit" className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm" >
> Logout &amp; Reset
Logout &amp; Reset </Link>
</button>
</form>
</div> </div>
</div> </div>
); );
@@ -83,7 +83,7 @@ export function ArtifactFileDetail({
const isSupportPreview = useMemo(() => { const isSupportPreview = useMemo(() => {
return language === "html" || language === "markdown"; return language === "html" || language === "markdown";
}, [language]); }, [language]);
const { content, url } = useArtifactContent({ const { content } = useArtifactContent({
threadId, threadId,
filepath: filepathFromProps, filepath: filepathFromProps,
enabled: isCodeFile && !isWriteFile, enabled: isCodeFile && !isWriteFile,
@@ -254,9 +254,7 @@ export function ArtifactFileDetail({
(language === "markdown" || language === "html") && ( (language === "markdown" || language === "html") && (
<ArtifactFilePreview <ArtifactFilePreview
content={displayContent} content={displayContent}
isWriteFile={isWriteFile}
language={language ?? "text"} language={language ?? "text"}
url={url}
/> />
)} )}
{isCodeFile && viewMode === "code" && ( {isCodeFile && viewMode === "code" && (
@@ -279,33 +277,27 @@ export function ArtifactFileDetail({
export function ArtifactFilePreview({ export function ArtifactFilePreview({
content, content,
isWriteFile,
language, language,
url,
}: { }: {
content: string; content: string;
isWriteFile: boolean;
language: string; language: string;
url?: string;
}) { }) {
const [htmlPreviewUrl, setHtmlPreviewUrl] = useState<string>(); const [htmlPreviewUrl, setHtmlPreviewUrl] = useState<string>();
useEffect(() => { useEffect(() => {
if (language !== "html" || isWriteFile) { if (language !== "html") {
setHtmlPreviewUrl(undefined); setHtmlPreviewUrl(undefined);
return; return;
} }
const blob = new Blob([htmlWithBaseHref(content ?? "", url)], { const blob = new Blob([content ?? ""], { type: "text/html" });
type: "text/html", const url = URL.createObjectURL(blob);
}); setHtmlPreviewUrl(url);
const objectUrl = URL.createObjectURL(blob);
setHtmlPreviewUrl(objectUrl);
return () => { return () => {
URL.revokeObjectURL(objectUrl); URL.revokeObjectURL(url);
}; };
}, [content, isWriteFile, language, url]); }, [content, language]);
if (language === "markdown") { if (language === "markdown") {
return ( return (
@@ -326,35 +318,9 @@ export function ArtifactFilePreview({
className="size-full" className="size-full"
title="Artifact preview" title="Artifact preview"
sandbox="allow-scripts allow-forms" sandbox="allow-scripts allow-forms"
src={isWriteFile ? undefined : htmlPreviewUrl} src={htmlPreviewUrl}
srcDoc={isWriteFile ? content : undefined}
/> />
); );
} }
return null; return null;
} }
function htmlWithBaseHref(content: string, url?: string) {
if (!url || /<base\s/i.exec(content)) {
return content;
}
const baseHref = htmlBaseHref(url);
const baseElement = `<base href="${escapeHtmlAttribute(baseHref)}">`;
if (/<head[^>]*>/i.exec(content)) {
return content.replace(/<head([^>]*)>/i, `<head$1>${baseElement}`);
}
return `${baseElement}${content}`;
}
function htmlBaseHref(url: string) {
const baseUrl = new URL(url, window.location.href);
baseUrl.pathname = baseUrl.pathname.replace(/\/[^/]*$/, "/");
baseUrl.search = "";
baseUrl.hash = "";
return baseUrl.toString();
}
function escapeHtmlAttribute(value: string) {
return value.replaceAll("&", "&amp;").replaceAll('"', "&quot;");
}
+17 -17
View File
@@ -20,27 +20,27 @@ If you want to understand how DeerFlow works, start with the Introduction. If yo
Start with the conceptual overview first. Start with the conceptual overview first.
- [Introduction](./docs/introduction) - [Introduction](/docs/introduction)
- [Why DeerFlow](./docs/introduction/why-deerflow) - [Why DeerFlow](/docs/introduction/why-deerflow)
- [Harness vs App](./docs/introduction/harness-vs-app) - [Harness vs App](/docs/introduction/harness-vs-app)
### If you want to build with DeerFlow ### If you want to build with DeerFlow
Start with the Harness section. This path is for teams who want to integrate DeerFlow capabilities into their own system or build a custom agent product on top of the DeerFlow runtime. Start with the Harness section. This path is for teams who want to integrate DeerFlow capabilities into their own system or build a custom agent product on top of the DeerFlow runtime.
- [DeerFlow Harness](./docs/harness) - [DeerFlow Harness](/docs/harness)
- [Quick Start](./docs/harness/quick-start) - [Quick Start](/docs/harness/quick-start)
- [Configuration](./docs/harness/configuration) - [Configuration](/docs/harness/configuration)
- [Customization](./docs/harness/customization) - [Customization](/docs/harness/customization)
### If you want to deploy and use DeerFlow ### If you want to deploy and use DeerFlow
Start with the App section. This path is for teams who want to run DeerFlow as a complete application and understand how to configure, operate, and use it in practice. Start with the App section. This path is for teams who want to run DeerFlow as a complete application and understand how to configure, operate, and use it in practice.
- [DeerFlow App](./docs/app) - [DeerFlow App](/docs/app)
- [Quick Start](./docs/app/quick-start) - [Quick Start](/docs/app/quick-start)
- [Deployment Guide](./docs/app/deployment-guide) - [Deployment Guide](/docs/app/deployment-guide)
- [Workspace Usage](./docs/app/workspace-usage) - [Workspace Usage](/docs/app/workspace-usage)
## Documentation structure ## Documentation structure
@@ -79,17 +79,17 @@ The App section is written for teams who want to deploy DeerFlow as a usable pro
The Tutorials section is for hands-on, task-oriented learning. The Tutorials section is for hands-on, task-oriented learning.
- [Tutorials](./docs/tutorials) - [Tutorials](/docs/tutorials)
### Reference ### Reference
The Reference section is for detailed lookup material, including configuration, runtime modes, APIs, and source-oriented mapping. The Reference section is for detailed lookup material, including configuration, runtime modes, APIs, and source-oriented mapping.
- [Reference](./docs/reference) - [Reference](/docs/reference)
## Choose the right path ## Choose the right path
- If you are **evaluating the project**, start with [Introduction](./docs/introduction). - If you are **evaluating the project**, start with [Introduction](/docs/introduction).
- If you are **building your own agent system**, start with [DeerFlow Harness](./docs/harness). - If you are **building your own agent system**, start with [DeerFlow Harness](/docs/harness).
- If you are **deploying DeerFlow for users**, start with [DeerFlow App](./docs/app). - If you are **deploying DeerFlow for users**, start with [DeerFlow App](/docs/app).
- If you want to **learn by doing**, go to [Tutorials](./docs/tutorials). - If you want to **learn by doing**, go to [Tutorials](/docs/tutorials).
@@ -1,9 +0,0 @@
---
title: DeerFlow 2.0 M1
description: DeerFlow 2.0 M1 is officially in RC. Here's what you need to know.
date: 2026-05-30
tags:
- Release
---
## DeerFlow 2.0 M1 Release
+11 -11
View File
@@ -20,27 +20,27 @@ DeerFlow 是一个用于构建和运行 Agent 系统的框架。它提供了一
先从概念概述开始。 先从概念概述开始。
- [简介](./docs/introduction) - [简介](/docs/introduction)
- [为什么选择 DeerFlow](./docs/introduction/why-deerflow) - [为什么选择 DeerFlow](/docs/introduction/why-deerflow)
- [Harness 与应用的区别](./docs/introduction/harness-vs-app) - [Harness 与应用的区别](/docs/introduction/harness-vs-app)
### 如果你想基于 DeerFlow 进行开发 ### 如果你想基于 DeerFlow 进行开发
从 Harness 章节开始。这条路径适合想将 DeerFlow 功能集成到自己系统中,或基于 DeerFlow 运行时构建自定义 Agent 产品的团队。 从 Harness 章节开始。这条路径适合想将 DeerFlow 功能集成到自己系统中,或基于 DeerFlow 运行时构建自定义 Agent 产品的团队。
- [DeerFlow Harness](./docs/harness) - [DeerFlow Harness](/docs/harness)
- [快速上手](./docs/harness/quick-start) - [快速上手](/docs/harness/quick-start)
- [配置](./docs/harness/configuration) - [配置](/docs/harness/configuration)
- [自定义与扩展](./docs/harness/customization) - [自定义与扩展](/docs/harness/customization)
### 如果你想部署和使用 DeerFlow ### 如果你想部署和使用 DeerFlow
从应用章节开始。这条路径适合想将 DeerFlow 作为完整应用运行,并了解如何配置、运维和实际使用的团队。 从应用章节开始。这条路径适合想将 DeerFlow 作为完整应用运行,并了解如何配置、运维和实际使用的团队。
- [DeerFlow 应用](./docs/application) - [DeerFlow 应用](/docs/application)
- [快速上手](./docs/application/quick-start) - [快速上手](/docs/application/quick-start)
- [部署指南](./docs/application/deployment-guide) - [部署指南](/docs/application/deployment-guide)
- [工作区使用](./docs/application/workspace-usage) - [工作区使用](/docs/application/workspace-usage)
## 文档结构 ## 文档结构
-49
View File
@@ -3,13 +3,6 @@
import { Client as LangGraphClient } from "@langchain/langgraph-sdk/client"; import { Client as LangGraphClient } from "@langchain/langgraph-sdk/client";
import { getLangGraphBaseURL } from "../config"; import { getLangGraphBaseURL } from "../config";
import { isStaticWebsiteOnly } from "../static-mode";
import {
loadStaticDemoThread,
loadStaticDemoThreads,
staticDemoThreadState,
} from "../threads/static-demo";
import type { AgentThreadState } from "../threads/types";
import { isStateChangingMethod, readCsrfCookie } from "./fetcher"; import { isStateChangingMethod, readCsrfCookie } from "./fetcher";
import { sanitizeRunStreamOptions } from "./stream-mode"; import { sanitizeRunStreamOptions } from "./stream-mode";
@@ -39,10 +32,6 @@ function injectCsrfHeader(_url: URL, init: RequestInit): RequestInit {
} }
function createCompatibleClient(isMock?: boolean): LangGraphClient { function createCompatibleClient(isMock?: boolean): LangGraphClient {
if (isStaticWebsiteOnly() && !isMock) {
return createStaticClient();
}
const apiUrl = getLangGraphBaseURL(isMock); const apiUrl = getLangGraphBaseURL(isMock);
console.log(`Creating API client with base URL: ${apiUrl}`); console.log(`Creating API client with base URL: ${apiUrl}`);
const client = new LangGraphClient({ const client = new LangGraphClient({
@@ -69,44 +58,6 @@ function createCompatibleClient(isMock?: boolean): LangGraphClient {
return client; return client;
} }
function createStaticClient(): LangGraphClient {
const apiUrl =
typeof window === "undefined"
? "http://localhost:3000"
: window.location.origin;
const client = new LangGraphClient({ apiUrl });
client.threads.search = (async (query) => {
return loadStaticDemoThreads(query);
}) as typeof client.threads.search;
client.threads.get = (async (threadId) => {
return loadStaticDemoThread(threadId);
}) as typeof client.threads.get;
client.threads.getState = (async (threadId) => {
return staticDemoThreadState(await loadStaticDemoThread(threadId));
}) as typeof client.threads.getState;
client.threads.getHistory = (async (threadId) => {
return [staticDemoThreadState(await loadStaticDemoThread(threadId))];
}) as typeof client.threads.getHistory;
client.threads.update = (async (threadId) => {
return loadStaticDemoThread(threadId);
}) as typeof client.threads.update;
client.runs.list = (async () => []) as typeof client.runs.list;
client.runs.stream = async function* () {
/* empty */
} as typeof client.runs.stream;
client.runs.joinStream = async function* () {
/* empty */
} as typeof client.runs.joinStream;
return client as LangGraphClient<AgentThreadState>;
}
const _clients = new Map<string, LangGraphClient>(); const _clients = new Map<string, LangGraphClient>();
export function getAPIClient(isMock?: boolean): LangGraphClient { export function getAPIClient(isMock?: boolean): LangGraphClient {
const cacheKey = isMock ? "mock" : "default"; const cacheKey = isMock ? "mock" : "default";
-20
View File
@@ -1,5 +1,4 @@
import { getBackendBaseURL } from "../config"; import { getBackendBaseURL } from "../config";
import { isStaticWebsiteOnly } from "../static-mode";
import type { AgentThread } from "../threads"; import type { AgentThread } from "../threads";
export function urlOfArtifact({ export function urlOfArtifact({
@@ -13,9 +12,6 @@ export function urlOfArtifact({
download?: boolean; download?: boolean;
isMock?: boolean; isMock?: boolean;
}) { }) {
if (isStaticWebsiteOnly()) {
return staticDemoArtifactURL({ filepath, threadId, download });
}
if (isMock) { if (isMock) {
return `${getBackendBaseURL()}/mock/api/threads/${threadId}/artifacts${filepath}${download ? "?download=true" : ""}`; return `${getBackendBaseURL()}/mock/api/threads/${threadId}/artifacts${filepath}${download ? "?download=true" : ""}`;
} }
@@ -27,21 +23,5 @@ export function extractArtifactsFromThread(thread: AgentThread) {
} }
export function resolveArtifactURL(absolutePath: string, threadId: string) { export function resolveArtifactURL(absolutePath: string, threadId: string) {
if (isStaticWebsiteOnly()) {
return staticDemoArtifactURL({ filepath: absolutePath, threadId });
}
return `${getBackendBaseURL()}/api/threads/${threadId}/artifacts${absolutePath}`; return `${getBackendBaseURL()}/api/threads/${threadId}/artifacts${absolutePath}`;
} }
function staticDemoArtifactURL({
filepath,
threadId,
download = false,
}: {
filepath: string;
threadId: string;
download?: boolean;
}) {
const demoPath = filepath.replace(/^\/mnt\//, "/");
return `${getBackendBaseURL()}/demo/threads/${threadId}${demoPath}${download ? "?download=true" : ""}`;
}
+3 -17
View File
@@ -10,8 +10,6 @@ import React, {
type ReactNode, type ReactNode,
} from "react"; } from "react";
import { isStaticWebsiteOnly } from "../static-mode";
import { type User, buildLoginUrl } from "./types"; import { type User, buildLoginUrl } from "./types";
// Re-export for consumers // Re-export for consumers
@@ -48,7 +46,6 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const router = useRouter(); const router = useRouter();
const pathname = usePathname(); const pathname = usePathname();
const staticMode = isStaticWebsiteOnly();
const isAuthenticated = user !== null; const isAuthenticated = user !== null;
@@ -57,8 +54,6 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
* Used when initialUser might be stale (e.g., after tab was inactive) * Used when initialUser might be stale (e.g., after tab was inactive)
*/ */
const refreshUser = useCallback(async () => { const refreshUser = useCallback(async () => {
if (staticMode) return;
try { try {
setIsLoading(true); setIsLoading(true);
const res = await fetch("/api/v1/auth/me", { const res = await fetch("/api/v1/auth/me", {
@@ -82,7 +77,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
} finally { } finally {
setIsLoading(false); setIsLoading(false);
} }
}, [staticMode, pathname, router]); }, [pathname, router]);
/** /**
* Logout - call FastAPI logout endpoint and clear local state * Logout - call FastAPI logout endpoint and clear local state
@@ -92,11 +87,6 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
// Immediately clear local state to prevent UI flicker // Immediately clear local state to prevent UI flicker
setUser(null); setUser(null);
if (staticMode) {
router.push("/");
return;
}
try { try {
await fetch("/api/v1/auth/logout", { await fetch("/api/v1/auth/logout", {
method: "POST", method: "POST",
@@ -109,7 +99,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
// Redirect to home page // Redirect to home page
router.push("/"); router.push("/");
}, [staticMode, router]); }, [router]);
/** /**
* Handle visibility change - refresh user when tab becomes visible again. * Handle visibility change - refresh user when tab becomes visible again.
@@ -118,8 +108,6 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
const lastCheckRef = React.useRef(0); const lastCheckRef = React.useRef(0);
useEffect(() => { useEffect(() => {
if (staticMode) return;
const handleVisibilityChange = () => { const handleVisibilityChange = () => {
if (document.visibilityState !== "visible" || user === null) return; if (document.visibilityState !== "visible" || user === null) return;
const now = Date.now(); const now = Date.now();
@@ -132,7 +120,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
return () => { return () => {
document.removeEventListener("visibilitychange", handleVisibilityChange); document.removeEventListener("visibilitychange", handleVisibilityChange);
}; };
}, [staticMode, user, refreshUser]); }, [user, refreshUser]);
const value: AuthContextType = { const value: AuthContextType = {
user, user,
@@ -167,8 +155,6 @@ export function useRequireAuth(): AuthContextType {
const pathname = usePathname(); const pathname = usePathname();
useEffect(() => { useEffect(() => {
if (isStaticWebsiteOnly()) return;
// Only redirect if we're sure user is not authenticated (not just loading) // Only redirect if we're sure user is not authenticated (not just loading)
if (!auth.isLoading && !auth.isAuthenticated) { if (!auth.isLoading && !auth.isAuthenticated) {
router.push(buildLoginUrl(pathname || "/workspace")); router.push(buildLoginUrl(pathname || "/workspace"));
-10
View File
@@ -1,9 +1,6 @@
import { cookies } from "next/headers"; import { cookies } from "next/headers";
import { isStaticWebsiteOnly } from "../static-mode";
import { getGatewayConfig } from "./gateway-config"; import { getGatewayConfig } from "./gateway-config";
import { STATIC_WEBSITE_USER } from "./static-user";
import { type AuthResult, userSchema } from "./types"; import { type AuthResult, userSchema } from "./types";
const SSR_AUTH_TIMEOUT_MS = 5_000; const SSR_AUTH_TIMEOUT_MS = 5_000;
@@ -13,13 +10,6 @@ const SSR_AUTH_TIMEOUT_MS = 5_000;
* Returns a tagged AuthResult — callers use exhaustive switch, no try/catch. * Returns a tagged AuthResult — callers use exhaustive switch, no try/catch.
*/ */
export async function getServerSideUser(): Promise<AuthResult> { export async function getServerSideUser(): Promise<AuthResult> {
if (isStaticWebsiteOnly()) {
return {
tag: "authenticated",
user: STATIC_WEBSITE_USER,
};
}
if (process.env.DEER_FLOW_AUTH_DISABLED === "1") { if (process.env.DEER_FLOW_AUTH_DISABLED === "1") {
return { return {
tag: "authenticated", tag: "authenticated",
-8
View File
@@ -1,8 +0,0 @@
import type { User } from "./types";
export const STATIC_WEBSITE_USER: User = {
id: "static-website-user",
email: "static@example.local",
system_role: "admin",
needs_setup: false,
};
-10
View File
@@ -1,18 +1,8 @@
import { getBackendBaseURL } from "../config"; import { getBackendBaseURL } from "../config";
import { isStaticWebsiteOnly } from "../static-mode";
import type { ModelsResponse } from "./types"; import type { ModelsResponse } from "./types";
const STATIC_MODELS_RESPONSE: ModelsResponse = {
models: [],
token_usage: { enabled: false },
};
export async function loadModels(): Promise<ModelsResponse> { export async function loadModels(): Promise<ModelsResponse> {
if (isStaticWebsiteOnly()) {
return STATIC_MODELS_RESPONSE;
}
const res = await fetch(`${getBackendBaseURL()}/api/models`); const res = await fetch(`${getBackendBaseURL()}/api/models`);
const data = (await res.json()) as Partial<ModelsResponse>; const data = (await res.json()) as Partial<ModelsResponse>;
return { return {
-5
View File
@@ -1,5 +0,0 @@
import { env } from "@/env";
export function isStaticWebsiteOnly() {
return env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true";
}
-87
View File
@@ -1,87 +0,0 @@
import type { ThreadState } from "@langchain/langgraph-sdk";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import type { AgentThread, AgentThreadState } from "./types";
export const DEMO_THREAD_IDS = [
"21cfea46-34bd-4aa6-9e1f-3009452fbeb9",
"3823e443-4e2b-4679-b496-a9506eae462b",
"4f3e55ee-f853-43db-bfb3-7d1a411f03cb",
"5aa47db1-d0cb-4eb9-aea5-3dac1b371c5a",
"7cfa5f8f-a2f8-47ad-acbd-da7137baf990",
"7f9dc56c-e49c-4671-a3d2-c492ff4dce0c",
"90040b36-7eba-4b97-ba89-02c3ad47a8b9",
"ad76c455-5bf9-4335-8517-fc03834ab828",
"b83fbb2a-4e36-4d82-9de0-7b2a02c2092a",
"c02bb4d5-4202-490e-ae8f-ff4864fc0d2e",
"d3e5adaf-084c-4dd5-9d29-94f1d6bccd98",
"f4125791-0128-402a-8ca9-50e0947557e4",
"fe3f7974-1bcb-4a01-a950-79673baafefd",
] as const;
export type ThreadSearchParams = NonNullable<
Parameters<ThreadsClient["search"]>[0]
>;
export async function loadStaticDemoThreads(
params: ThreadSearchParams = {},
): Promise<AgentThread[]> {
const threads = await Promise.all(
DEMO_THREAD_IDS.map((threadId) => loadStaticDemoThread(threadId)),
);
const sortBy = params.sortBy ?? "updated_at";
const sortOrder = params.sortOrder ?? "desc";
const sortedThreads = [...threads].sort((a, b) => {
const aTimestamp = (a as unknown as Record<string, unknown>)[sortBy];
const bTimestamp = (b as unknown as Record<string, unknown>)[sortBy];
const aParsed = typeof aTimestamp === "string" ? Date.parse(aTimestamp) : 0;
const bParsed = typeof bTimestamp === "string" ? Date.parse(bTimestamp) : 0;
const aValue = Number.isNaN(aParsed) ? 0 : aParsed;
const bValue = Number.isNaN(bParsed) ? 0 : bParsed;
return sortOrder === "asc" ? aValue - bValue : bValue - aValue;
});
const offset = Math.max(0, Math.floor(params.offset ?? 0));
const limit =
typeof params.limit === "number"
? Math.max(0, Math.floor(params.limit))
: sortedThreads.length;
return sortedThreads.slice(offset, offset + limit);
}
export async function loadStaticDemoThread(
threadId: string,
): Promise<AgentThread> {
const response = await globalThis.fetch(
`/demo/threads/${encodeURIComponent(threadId)}/thread.json`,
);
if (!response.ok) {
throw new Error(`Failed to load demo thread ${threadId}`);
}
const thread = (await response.json()) as AgentThread;
return {
...thread,
thread_id: threadId,
updated_at: thread.updated_at ?? thread.created_at,
};
}
export function staticDemoThreadState(
thread: AgentThread,
): ThreadState<AgentThreadState> {
return {
values: thread.values,
next: [],
checkpoint: {
thread_id: thread.thread_id,
checkpoint_ns: "",
checkpoint_id: null,
checkpoint_map: null,
},
metadata: thread.metadata ?? null,
created_at: thread.updated_at ?? thread.created_at ?? null,
parent_checkpoint: null,
tasks: [],
};
}
@@ -1,69 +0,0 @@
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
const ENV_KEYS = [
"NEXT_PUBLIC_BACKEND_BASE_URL",
"NEXT_PUBLIC_STATIC_WEBSITE_ONLY",
] as const;
type EnvSnapshot = Partial<
Record<(typeof ENV_KEYS)[number], string | undefined>
>;
function snapshotEnv(): EnvSnapshot {
const snapshot: EnvSnapshot = {};
for (const key of ENV_KEYS) {
snapshot[key] = process.env[key];
}
return snapshot;
}
function setEnv(key: (typeof ENV_KEYS)[number], value: string | undefined) {
const env = process.env as Record<string, string | undefined>;
if (value === undefined) {
delete env[key];
} else {
env[key] = value;
}
}
function restoreEnv(snapshot: EnvSnapshot) {
for (const key of ENV_KEYS) {
setEnv(key, snapshot[key]);
}
}
async function loadFreshArtifactUtils() {
vi.resetModules();
return await import("@/core/artifacts/utils");
}
describe("artifact URL helpers", () => {
let saved: EnvSnapshot;
beforeEach(() => {
saved = snapshotEnv();
setEnv("NEXT_PUBLIC_BACKEND_BASE_URL", undefined);
setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined);
});
afterEach(() => {
restoreEnv(saved);
});
test("maps static demo artifact paths to bundled public files", async () => {
setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", "true");
const { resolveArtifactURL, urlOfArtifact } =
await loadFreshArtifactUtils();
expect(
urlOfArtifact({
filepath: "/mnt/user-data/outputs/index.html",
threadId: "thread-1",
}),
).toBe("/demo/threads/thread-1/user-data/outputs/index.html");
expect(
resolveArtifactURL("/mnt/user-data/outputs/style.css", "thread-1"),
).toBe("/demo/threads/thread-1/user-data/outputs/style.css");
});
});
@@ -1,77 +0,0 @@
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { STATIC_WEBSITE_USER } from "@/core/auth/static-user";
vi.mock("next/headers", () => ({
cookies: vi.fn(() => {
throw new Error("cookies should not be read in static website mode");
}),
}));
const ENV_KEYS = [
"DEER_FLOW_AUTH_DISABLED",
"NEXT_PUBLIC_STATIC_WEBSITE_ONLY",
] as const;
type EnvSnapshot = Partial<
Record<(typeof ENV_KEYS)[number], string | undefined>
>;
function snapshotEnv(): EnvSnapshot {
const snapshot: EnvSnapshot = {};
for (const key of ENV_KEYS) {
snapshot[key] = process.env[key];
}
return snapshot;
}
function setEnv(key: (typeof ENV_KEYS)[number], value: string | undefined) {
const env = process.env as Record<string, string | undefined>;
if (value === undefined) {
delete env[key];
} else {
env[key] = value;
}
}
function restoreEnv(snapshot: EnvSnapshot) {
for (const key of ENV_KEYS) {
setEnv(key, snapshot[key]);
}
}
async function loadFreshServerAuth() {
vi.resetModules();
return await import("@/core/auth/server");
}
describe("getServerSideUser", () => {
let saved: EnvSnapshot;
beforeEach(() => {
saved = snapshotEnv();
setEnv("DEER_FLOW_AUTH_DISABLED", undefined);
setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined);
});
afterEach(() => {
restoreEnv(saved);
vi.unstubAllGlobals();
});
test("bypasses gateway auth in static website mode", async () => {
setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", "true");
const fetchSpy = vi.fn(() => {
throw new Error("fetch should not be called in static website mode");
});
vi.stubGlobal("fetch", fetchSpy);
const { getServerSideUser } = await loadFreshServerAuth();
await expect(getServerSideUser()).resolves.toEqual({
tag: "authenticated",
user: STATIC_WEBSITE_USER,
});
expect(fetchSpy).not.toHaveBeenCalled();
});
});