fix(harness)!: hydrate runs from RunStore and persist interrupted status (#2932)
* fix(harness): hydrate run history from RunStore and persist cancellation status

  fix:
  - Make RunManager.get() async and hydrate from RunStore when in-memory record is missing
  - Merge store rows into list_by_thread() with in-memory precedence for active runs
  - Persist interrupted status to RunStore in cancel() and create_or_reject(interrupt|rollback)
  - Extract _persist_status() to reuse the best-effort store update pattern
  - Await run_mgr.get() in all gateway endpoints
  - Return 409 with distinct message for store-only runs not active on current worker

  Closes #2812, Closes #2813

  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(harness): consistent sort and guarded hydration in RunManager

  fix:
  - list_by_thread() now sorts by created_at desc (newest first) even when no RunStore is configured, matching the store-backed code path
  - guard _record_from_store() call sites in get() and list_by_thread() with best-effort error handling so a single malformed store row cannot turn read paths into 500s

  test:
  - update test_list_by_thread assertion to expect newest-first order
  - seed MemoryRunStore via public put() API instead of writing to _runs

* fix(harness): guard store-only runs from streaming and fix get() TOCTOU

  Add RunRecord.store_only flag set by _record_from_store so callers can distinguish hydrated history from live in-memory runs. join_run and stream_existing_run (action=None) now return 409 instead of hanging forever on an empty MemoryStreamBridge channel.

  Re-check _runs under lock after the store await in RunManager.get() so a concurrent create() that lands between the two checks returns the authoritative in-memory record rather than a stale store-hydrated copy.

  Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix(harness): reorder bridge fetch in join_run and make list_by_thread limit explicit

  Move get_stream_bridge() after the store_only guard in join_run so a missing bridge cannot produce 503 for historical runs before the 409 guard fires.

  Add limit parameter to RunManager.list_by_thread (default 100, matching the store's page size) and pass it explicitly to the store call. Update docstring to document the limit instead of claiming all runs are returned.

  Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix(harness): cap list_by_thread result to limit after merge

  Apply [:limit] to all return paths in list_by_thread so the method consistently returns at most limit records regardless of how many in-memory runs exist, making the limit parameter a true upper bound on the response size rather than just a store-query hint.

  Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix `list_by_thread` docstring

  Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix(runtime): add update_model_name to RunStore to prevent SQL integrity errors

  RunManager.update_model_name() was calling _persist_to_store() which uses RunStore.put(), but RunRepository.put() is insert-only. This caused integrity errors when updating model_name for existing runs in SQL-backed stores.

  fix:
  - Add abstract update_model_name method to RunStore base class
  - Implement update_model_name in MemoryRunStore
  - Implement update_model_name in RunRepository with proper normalization
  - Add _persist_model_name helper in RunManager
  - Update RunManager.update_model_name to use the new method

  test:
  - Add tests for update_model_name functionality
  - Add integration tests for RunManager with SQL-backed store

  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(runtime): handle NULL status/on_disconnect in _record_from_store

  `dict.get(key, default)` only uses the default when the key is absent, so a SQL row with an explicit NULL status would pass `None` to `RunStatus(None)` and raise, breaking hydration for otherwise valid rows. Switch to `row.get(...) or fallback` so both missing and NULL values get a safe default.

  Add tests for get() and list_by_thread() with a NULL status row to prevent regression.

  Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix(runs): address PR review feedback on store consistency changes

  - Fix list_by_thread limit semantics: pass store_limit = max(0, limit - len(memory_records)) to store so newer store records are not crowded out by in-memory records
  - Remove dead code: cancelled guard after raise is always True, simplify to if wait and record.task
  - Document _record_from_store NULL fallback policy (status→pending, on_disconnect→cancel) in docstring

  Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
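The NULL-handling fix in the commit message comes down to the difference between a missing key and a key whose value is `None`. A minimal sketch of that difference, assuming a stand-in `Status` enum (hypothetical, not the project's `RunStatus`) and plain dicts in place of store rows:

```python
from enum import Enum


class Status(Enum):
    # Hypothetical stand-in for the project's RunStatus enum.
    pending = "pending"
    success = "success"


def status_with_default(row: dict) -> Status:
    # dict.get only falls back when the key is absent, so an explicit
    # None (a SQL NULL) is passed straight to the enum constructor.
    return Status(row.get("status", Status.pending.value))


def status_with_or(row: dict) -> Status:
    # `or` treats both a missing key and None as "use the fallback".
    return Status(row.get("status") or Status.pending.value)


null_row = {"run_id": "r1", "status": None}

try:
    status_with_default(null_row)
    default_variant_raised = False
except ValueError:
    # Enum lookup by an unknown value raises ValueError.
    default_variant_raised = True
```

Note that `or` also replaces other falsy values such as an empty string, which is acceptable here only because no valid status is falsy.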
@@ -225,6 +225,12 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
 | **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
 | **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |

+**RunManager / RunStore contract**:
+- `RunManager.get()` is async; direct callers must `await` it.
+- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs.
+- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions.
+- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task.
+
 Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs.

 ### Sandbox System (`packages/harness/deerflow/sandbox/`)
@@ -22,7 +22,7 @@ from pydantic import BaseModel, Field
 from app.gateway.authz import require_permission
 from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
 from app.gateway.services import sse_consumer, start_run
-from deerflow.runtime import RunRecord, serialize_channel_values
+from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values

 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/api/threads", tags=["runs"])
@@ -94,6 +94,12 @@ class ThreadTokenUsageResponse(BaseModel):
 # ---------------------------------------------------------------------------


+def _cancel_conflict_detail(run_id: str, record: RunRecord) -> str:
+    if record.status in (RunStatus.pending, RunStatus.running):
+        return f"Run {run_id} is not active on this worker and cannot be cancelled"
+    return f"Run {run_id} is not cancellable (status: {record.status.value})"
+
+
 def _record_to_response(record: RunRecord) -> RunResponse:
     return RunResponse(
         run_id=record.run_id,
@@ -191,7 +197,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
     """Get details of a specific run."""
     run_mgr = get_run_manager(request)
     user_id = await get_current_user(request)
-    record = await run_mgr.aget(run_id, user_id=user_id)
+    record = await run_mgr.get(run_id, user_id=user_id)
     if record is None or record.thread_id != thread_id:
         raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
     return _record_to_response(record)
@@ -214,16 +220,13 @@ async def cancel_run(
     - wait=false: Return immediately with 202
     """
     run_mgr = get_run_manager(request)
-    record = run_mgr.get(run_id)
+    record = await run_mgr.get(run_id)
     if record is None or record.thread_id != thread_id:
         raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

     cancelled = await run_mgr.cancel(run_id, action=action)
     if not cancelled:
-        raise HTTPException(
-            status_code=409,
-            detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
-        )
+        raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record))

     if wait and record.task is not None:
         try:
@@ -239,12 +242,14 @@ async def cancel_run(
 @require_permission("runs", "read", owner_check=True)
 async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
     """Join an existing run's SSE stream."""
-    bridge = get_stream_bridge(request)
     run_mgr = get_run_manager(request)
-    record = run_mgr.get(run_id)
+    record = await run_mgr.get(run_id)
     if record is None or record.thread_id != thread_id:
         raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
+    if record.store_only:
+        raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")

+    bridge = get_stream_bridge(request)
     return StreamingResponse(
         sse_consumer(bridge, record, request, run_mgr),
         media_type="text/event-stream",
@@ -273,14 +278,18 @@ async def stream_existing_run(
     remaining buffered events so the client observes a clean shutdown.
     """
     run_mgr = get_run_manager(request)
-    record = run_mgr.get(run_id)
+    record = await run_mgr.get(run_id)
     if record is None or record.thread_id != thread_id:
         raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
+    if record.store_only and action is None:
+        raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")

     # Cancel if an action was requested (stop-button / interrupt flow)
     if action is not None:
         cancelled = await run_mgr.cancel(run_id, action=action)
-        if cancelled and wait and record.task is not None:
+        if not cancelled:
+            raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record))
+        if wait and record.task is not None:
             try:
                 await record.task
             except (asyncio.CancelledError, Exception):
@@ -151,6 +151,11 @@ class RunRepository(RunStore):
             await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
             await session.commit()

+    async def update_model_name(self, run_id, model_name):
+        async with self._sf() as session:
+            await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(model_name=self._normalize_model_name(model_name), updated_at=datetime.now(UTC)))
+            await session.commit()
+
     async def delete(
         self,
         run_id,
@@ -6,7 +6,7 @@ import asyncio
 import logging
 import uuid
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from deerflow.utils.time import now_iso as _now_iso

@@ -37,6 +37,7 @@ class RunRecord:
     abort_action: str = "interrupt"
     error: str | None = None
     model_name: str | None = None
+    store_only: bool = False


 class RunManager:
@@ -71,6 +72,38 @@ class RunManager:
         except Exception:
             logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)

+    async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
+        """Best-effort persist a status transition to the backing store."""
+        if self._store is None:
+            return
+        try:
+            await self._store.update_status(run_id, status.value, error=error)
+        except Exception:
+            logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
+
+    @staticmethod
+    def _record_from_store(row: dict[str, Any]) -> RunRecord:
+        """Build a read-only runtime record from a serialized store row.
+
+        NULL status/on_disconnect columns (e.g. from rows written before those
+        columns were added) default to ``pending`` and ``cancel`` respectively.
+        """
+        return RunRecord(
+            run_id=row["run_id"],
+            thread_id=row["thread_id"],
+            assistant_id=row.get("assistant_id"),
+            status=RunStatus(row.get("status") or RunStatus.pending.value),
+            on_disconnect=DisconnectMode(row.get("on_disconnect") or DisconnectMode.cancel.value),
+            multitask_strategy=row.get("multitask_strategy") or "reject",
+            metadata=row.get("metadata") or {},
+            kwargs=row.get("kwargs") or {},
+            created_at=row.get("created_at") or "",
+            updated_at=row.get("updated_at") or "",
+            error=row.get("error"),
+            model_name=row.get("model_name"),
+            store_only=True,
+        )
+
     async def update_run_completion(self, run_id: str, **kwargs) -> None:
         """Persist token usage and completion data to the backing store."""
         if self._store is not None:
@@ -110,61 +143,77 @@ class RunManager:
         logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
         return record

-    def get(self, run_id: str) -> RunRecord | None:
-        """Return an in-memory run record by ID, or ``None``."""
-        return self._runs.get(run_id)
+    async def get(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None:
+        """Return a run record by ID, or ``None``.

-    async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None:
-        """Return a run record by ID, checking the persistent store as fallback."""
-        record = self._runs.get(run_id)
+        Args:
+            run_id: The run ID to look up.
+            user_id: Optional user ID for permission filtering when hydrating from store.
+        """
+        async with self._lock:
+            record = self._runs.get(run_id)
         if record is not None:
             return record
-        if self._store is not None:
-            try:
-                d = await self._store.get(run_id, user_id=user_id)
-                if d is not None:
-                    return self._store_dict_to_record(d)
-            except Exception:
-                logger.warning("Failed to query store for run %s", run_id, exc_info=True)
-        return None
-
-    def _store_dict_to_record(self, d: dict) -> RunRecord:
-        """Convert a store dict back to a RunRecord for read-only use."""
-        return RunRecord(
-            run_id=d["run_id"],
-            thread_id=d["thread_id"],
-            assistant_id=d.get("assistant_id"),
-            status=RunStatus(d.get("status", RunStatus.error.value)),
-            on_disconnect=DisconnectMode.cancel,
-            multitask_strategy=d.get("multitask_strategy", "reject"),
-            metadata=d.get("metadata", {}),
-            kwargs=d.get("kwargs", {}),
-            created_at=d.get("created_at", ""),
-            updated_at=d.get("updated_at", ""),
-            model_name=d.get("model_name"),
-            error=d.get("error"),
-        )
-
-    async def list_by_thread(self, thread_id: str, *, user_id: str | None = None) -> list[RunRecord]:
-        """Return all runs for a given thread, oldest first."""
+        if self._store is None:
+            return None
+        try:
+            row = await self._store.get(run_id, user_id=user_id)
+        except Exception:
+            logger.warning("Failed to hydrate run %s from store", run_id, exc_info=True)
+            return None
+        # Re-check after store await: a concurrent create() may have inserted the
+        # in-memory record while the store call was in flight.
         async with self._lock:
-            in_memory = [r for r in self._runs.values() if r.thread_id == thread_id]
-            in_memory_ids = {r.run_id for r in in_memory}
+            record = self._runs.get(run_id)
+        if record is not None:
+            return record
+        if row is None:
+            return None
+        try:
+            return self._record_from_store(row)
+        except Exception:
+            logger.warning("Failed to map store row for run %s", run_id, exc_info=True)
+            return None

-        store_records: list[RunRecord] = []
-        if self._store is not None:
-            try:
-                store_dicts = await self._store.list_by_thread(thread_id, user_id=user_id)
-                for d in store_dicts:
-                    if d["run_id"] not in in_memory_ids:
-                        store_records.append(self._store_dict_to_record(d))
-            except Exception:
-                logger.warning("Failed to query store for thread %s runs", thread_id, exc_info=True)
-
-        return sorted(
-            in_memory + store_records,
-            key=lambda record: record.created_at or "",
-        )
+    async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None:
+        """Return a run record by ID, checking the persistent store as fallback.
+
+        Alias for :meth:`get` for backward compatibility.
+        """
+        return await self.get(run_id, user_id=user_id)
+
+    async def list_by_thread(self, thread_id: str, *, user_id: str | None = None, limit: int = 100) -> list[RunRecord]:
+        """Return runs for a given thread, newest first, at most ``limit`` records.
+
+        In-memory runs take precedence only when the same ``run_id`` exists in both
+        memory and the backing store. The merged result is then sorted newest-first
+        by ``created_at`` and trimmed to ``limit`` (default 100).
+
+        Args:
+            thread_id: The thread ID to filter by.
+            user_id: Optional user ID for permission filtering when hydrating from store.
+            limit: Maximum number of runs to return.
+        """
+        async with self._lock:
+            # Dict insertion order gives deterministic results when timestamps tie.
+            memory_records = [r for r in self._runs.values() if r.thread_id == thread_id]
+        if self._store is None:
+            return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
+        records_by_id = {record.run_id: record for record in memory_records}
+        store_limit = max(0, limit - len(memory_records))
+        try:
+            rows = await self._store.list_by_thread(thread_id, user_id=user_id, limit=store_limit)
+        except Exception:
+            logger.warning("Failed to hydrate runs for thread %s from store", thread_id, exc_info=True)
+            return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
+        for row in rows:
+            run_id = row.get("run_id")
+            if run_id and run_id not in records_by_id:
+                try:
+                    records_by_id[run_id] = self._record_from_store(row)
+                except Exception:
+                    logger.warning("Failed to map store row for run %s", run_id, exc_info=True)
+        return sorted(records_by_id.values(), key=lambda record: record.created_at, reverse=True)[:limit]

     async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
         """Transition a run to a new status."""
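The merge rule used by `list_by_thread` (in-memory wins per `run_id`, newest first, trimmed to `limit`) can be shown in isolation. This is a sketch under stated assumptions: `merge_runs` is a hypothetical helper, plain dicts stand in for both `RunRecord` objects and store rows, and the `store_limit` pre-trimming of the store query is left out.

```python
def merge_runs(memory: list[dict], store_rows: list[dict], limit: int = 100) -> list[dict]:
    # Hypothetical helper: plain dicts stand in for RunRecord and store rows.
    by_id = {r["run_id"]: r for r in memory}
    for row in store_rows:
        run_id = row.get("run_id")
        if run_id and run_id not in by_id:  # memory wins on duplicate ids
            by_id[run_id] = row
    # ISO-8601 timestamps with the same offset sort correctly as strings.
    return sorted(by_id.values(), key=lambda r: r["created_at"], reverse=True)[:limit]


memory = [{"run_id": "b", "created_at": "2026-01-02T00:00:00+00:00", "status": "running"}]
store_rows = [
    {"run_id": "a", "created_at": "2026-01-01T00:00:00+00:00", "status": "success"},
    {"run_id": "b", "created_at": "2026-01-02T00:00:00+00:00", "status": "pending"},
    {"run_id": "c", "created_at": "2026-01-03T00:00:00+00:00", "status": "success"},
]

merged = merge_runs(memory, store_rows, limit=2)
```

Here run `b` keeps its in-memory `running` status even though the store still says `pending`, and the oldest run `a` is dropped by the limit.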
@@ -177,13 +226,18 @@ class RunManager:
             record.updated_at = _now_iso()
             if error is not None:
                 record.error = error
-        if self._store is not None:
-            try:
-                await self._store.update_status(run_id, status.value, error=error)
-            except Exception:
-                logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
+        await self._persist_status(run_id, status, error=error)
         logger.info("Run %s -> %s", run_id, status.value)

+    async def _persist_model_name(self, run_id: str, model_name: str | None) -> None:
+        """Best-effort persist model_name update to the backing store."""
+        if self._store is None:
+            return
+        try:
+            await self._store.update_model_name(run_id, model_name)
+        except Exception:
+            logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True)
+
     async def update_model_name(self, run_id: str, model_name: str | None) -> None:
         """Update the model name for a run."""
         async with self._lock:
@@ -193,7 +247,7 @@ class RunManager:
                 return
             record.model_name = model_name
             record.updated_at = _now_iso()
-        await self._persist_to_store(record)
+        await self._persist_model_name(run_id, model_name)
         logger.info("Run %s model_name=%s", run_id, model_name)

     async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
@@ -218,6 +272,7 @@ class RunManager:
                 record.task.cancel()
             record.status = RunStatus.interrupted
             record.updated_at = _now_iso()
+        await self._persist_status(run_id, RunStatus.interrupted)
         logger.info("Run %s cancelled (action=%s)", run_id, action)
         return True

@@ -245,6 +300,7 @@ class RunManager:
         now = _now_iso()

         _supported_strategies = ("reject", "interrupt", "rollback")
+        interrupted_run_ids: list[str] = []

         async with self._lock:
             if multitask_strategy not in _supported_strategies:
@@ -263,6 +319,7 @@ class RunManager:
                         r.task.cancel()
                     r.status = RunStatus.interrupted
                     r.updated_at = now
+                    interrupted_run_ids.append(r.run_id)
                 logger.info(
                     "Cancelled %d inflight run(s) on thread %s (strategy=%s)",
                     len(inflight),
@@ -285,6 +342,8 @@ class RunManager:
             )
             self._runs[run_id] = record

+        for interrupted_run_id in interrupted_run_ids:
+            await self._persist_status(interrupted_run_id, RunStatus.interrupted)
         await self._persist_to_store(record)
         logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
         return record
@@ -66,6 +66,15 @@ class RunStore(abc.ABC):
     async def delete(self, run_id: str) -> None:
         pass

+    @abc.abstractmethod
+    async def update_model_name(
+        self,
+        run_id: str,
+        model_name: str | None,
+    ) -> None:
+        """Update the model_name field for an existing run."""
+        pass
+
     @abc.abstractmethod
     async def update_run_completion(
         self,
@@ -66,6 +66,11 @@ class MemoryRunStore(RunStore):
                 self._runs[run_id]["error"] = error
             self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()

+    async def update_model_name(self, run_id, model_name):
+        if run_id in self._runs:
+            self._runs[run_id]["model_name"] = model_name
+            self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
+
     async def delete(self, run_id):
         self._runs.pop(run_id, None)

@@ -4,7 +4,7 @@ import re

 import pytest

-from deerflow.runtime import RunManager, RunStatus
+from deerflow.runtime import DisconnectMode, RunManager, RunStatus
 from deerflow.runtime.runs.store.memory import MemoryRunStore

 ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -34,7 +34,7 @@ async def test_create_and_get(manager: RunManager):
     assert ISO_RE.match(record.created_at)
     assert ISO_RE.match(record.updated_at)

-    fetched = manager.get(record.run_id)
+    fetched = await manager.get(record.run_id)
     assert fetched is record


@@ -64,6 +64,22 @@ async def test_cancel(manager: RunManager):
     assert record.status == RunStatus.interrupted


+@pytest.mark.anyio
+async def test_cancel_persists_interrupted_status_to_store():
+    """Cancel should persist interrupted status to the backing store."""
+    store = MemoryRunStore()
+    manager = RunManager(store=store)
+    record = await manager.create("thread-1")
+    await manager.set_status(record.run_id, RunStatus.running)
+
+    cancelled = await manager.cancel(record.run_id)
+
+    stored = await store.get(record.run_id)
+    assert cancelled is True
+    assert stored is not None
+    assert stored["status"] == "interrupted"
+
+
 @pytest.mark.anyio
 async def test_cancel_not_inflight(manager: RunManager):
     """Cancelling a completed run should return False."""
@@ -83,9 +99,9 @@ async def test_list_by_thread(manager: RunManager):

     runs = await manager.list_by_thread("thread-1")
     assert len(runs) == 2
-    # list_by_thread returns oldest-first (ascending created_at).
-    assert runs[0].run_id == r1.run_id
-    assert runs[1].run_id == r2.run_id
+    # Newest first: r2 was created after r1.
+    assert runs[0].run_id == r2.run_id
+    assert runs[1].run_id == r1.run_id


 @pytest.mark.anyio
@@ -117,7 +133,7 @@ async def test_cleanup(manager: RunManager):
|
|||||||
run_id = record.run_id
|
run_id = record.run_id
|
||||||
|
|
||||||
await manager.cleanup(run_id, delay=0)
|
await manager.cleanup(run_id, delay=0)
|
||||||
assert manager.get(run_id) is None
|
assert await manager.get(run_id) is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -132,7 +148,116 @@ async def test_set_status_with_error(manager: RunManager):
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_nonexistent(manager: RunManager):
|
async def test_get_nonexistent(manager: RunManager):
|
||||||
"""Getting a nonexistent run should return None."""
|
"""Getting a nonexistent run should return None."""
|
||||||
assert manager.get("does-not-exist") is None
|
assert await manager.get("does-not-exist") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_hydrates_store_only_run():
|
||||||
|
"""Store-only runs should be readable after process restart."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
await store.put(
|
||||||
|
"run-store-only",
|
||||||
|
thread_id="thread-1",
|
||||||
|
assistant_id="lead_agent",
|
||||||
|
status="success",
|
||||||
|
multitask_strategy="reject",
|
||||||
|
metadata={"source": "store"},
|
||||||
|
kwargs={"input": "value"},
|
||||||
|
created_at="2026-01-01T00:00:00+00:00",
|
||||||
|
model_name="model-a",
|
||||||
|
)
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
|
||||||
|
record = await manager.get("run-store-only")
|
||||||
|
|
||||||
|
assert record is not None
|
||||||
|
assert record.run_id == "run-store-only"
|
||||||
|
assert record.thread_id == "thread-1"
|
||||||
|
assert record.assistant_id == "lead_agent"
|
||||||
|
assert record.status == RunStatus.success
|
||||||
|
assert record.on_disconnect == DisconnectMode.cancel
|
||||||
|
assert record.metadata == {"source": "store"}
|
||||||
|
assert record.kwargs == {"input": "value"}
|
||||||
|
assert record.model_name == "model-a"
|
||||||
|
assert record.task is None
|
||||||
|
assert record.store_only is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_hydrates_run_with_null_enum_fields():
|
||||||
|
"""Rows with NULL status/on_disconnect must hydrate with safe defaults, not raise."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
# Simulate a SQL row where the nullable status column is NULL
|
||||||
|
await store.put(
|
||||||
|
"run-null-status",
|
||||||
|
thread_id="thread-1",
|
||||||
|
status=None,
|
||||||
|
created_at="2026-01-01T00:00:00+00:00",
|
||||||
|
)
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
|
||||||
|
record = await manager.get("run-null-status")
|
||||||
|
|
||||||
|
assert record is not None
|
||||||
|
assert record.status == RunStatus.pending
|
||||||
|
assert record.on_disconnect == DisconnectMode.cancel
|
||||||
|
assert record.store_only is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_by_thread_hydrates_run_with_null_enum_fields():
|
||||||
|
"""list_by_thread must not skip rows with NULL status; applies safe defaults."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
await store.put(
|
||||||
|
"run-null-status-list",
|
||||||
|
thread_id="thread-null",
|
||||||
|
status=None,
|
||||||
|
created_at="2026-01-01T00:00:00+00:00",
|
||||||
|
)
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
|
||||||
|
runs = await manager.list_by_thread("thread-null")
|
||||||
|
|
||||||
|
assert len(runs) == 1
|
||||||
|
assert runs[0].run_id == "run-null-status-list"
|
||||||
|
assert runs[0].status == RunStatus.pending
|
||||||
|
assert runs[0].on_disconnect == DisconnectMode.cancel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_record_is_not_store_only(manager: RunManager):
|
||||||
|
"""In-memory records created via create() must have store_only=False."""
|
||||||
|
record = await manager.create("thread-1")
|
||||||
|
assert record.store_only is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_prefers_in_memory_record_over_store():
|
||||||
|
"""In-memory records retain task/control state when store has same run."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
record = await manager.create("thread-1")
|
||||||
|
await store.update_status(record.run_id, "success")
|
||||||
|
|
||||||
|
fetched = await manager.get(record.run_id)
|
||||||
|
|
||||||
|
assert fetched is record
|
||||||
|
assert fetched.status == RunStatus.pending
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_by_thread_merges_store_runs_newest_first():
|
||||||
|
"""list_by_thread should merge memory and store rows with memory precedence."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
await store.put("old-store", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:00+00:00")
|
||||||
|
await store.put("other-thread", thread_id="thread-2", status="success", created_at="2026-01-03T00:00:00+00:00")
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
memory_record = await manager.create("thread-1")
|
||||||
|
|
||||||
|
runs = await manager.list_by_thread("thread-1")
|
||||||
|
|
||||||
|
assert [run.run_id for run in runs] == [memory_record.run_id, "old-store"]
|
||||||
|
assert runs[0] is memory_record
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -171,11 +296,45 @@ async def test_model_name_create_or_reject():
|
|||||||
assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0"
|
assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0"
|
||||||
|
|
||||||
# Verify retrieval returns the model_name via in-memory record
|
# Verify retrieval returns the model_name via in-memory record
|
||||||
fetched = mgr.get(record.run_id)
|
fetched = await mgr.get(record.run_id)
|
||||||
assert fetched is not None
|
assert fetched is not None
|
||||||
assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0"
|
assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_or_reject_interrupt_persists_interrupted_status_to_store():
|
||||||
|
"""interrupt strategy should persist interrupted status for old runs."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
old = await manager.create("thread-1")
|
||||||
|
await manager.set_status(old.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
new = await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
|
||||||
|
|
||||||
|
stored_old = await store.get(old.run_id)
|
||||||
|
assert new.run_id != old.run_id
|
||||||
|
assert old.status == RunStatus.interrupted
|
||||||
|
assert stored_old is not None
|
||||||
|
assert stored_old["status"] == "interrupted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
|
||||||
|
"""rollback strategy should persist interrupted status for old runs."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
old = await manager.create("thread-1")
|
||||||
|
await manager.set_status(old.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
new = await manager.create_or_reject("thread-1", multitask_strategy="rollback")
|
||||||
|
|
||||||
|
stored_old = await store.get(old.run_id)
|
||||||
|
assert new.run_id != old.run_id
|
||||||
|
assert old.status == RunStatus.interrupted
|
||||||
|
assert stored_old is not None
|
||||||
|
assert stored_old["status"] == "interrupted"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_model_name_default_is_none():
|
async def test_model_name_default_is_none():
|
||||||
"""create_or_reject without model_name should default to None."""
|
"""create_or_reject without model_name should default to None."""
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import pytest
|
|||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
from deerflow.persistence.run import RunRepository
|
from deerflow.persistence.run import RunRepository
|
||||||
|
from deerflow.runtime import RunManager, RunStatus
|
||||||
|
|
||||||
|
|
||||||
async def _make_repo(tmp_path):
|
async def _make_repo(tmp_path):
|
||||||
@@ -326,3 +327,105 @@ class TestRunRepository:
|
|||||||
assert select_match is not None
|
assert select_match is not None
|
||||||
assert group_by_match is not None
|
assert group_by_match is not None
|
||||||
assert select_match.group(1) == group_by_match.group(1)
|
assert select_match.group(1) == group_by_match.group(1)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_run_manager_hydrates_store_only_run_from_sql(self, tmp_path):
|
||||||
|
"""RunManager should hydrate historical runs from SQL-backed store."""
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
await repo.put(
|
||||||
|
"sql-store-only",
|
||||||
|
thread_id="thread-1",
|
||||||
|
assistant_id="lead_agent",
|
||||||
|
status="success",
|
||||||
|
metadata={"source": "sql"},
|
||||||
|
kwargs={"input": "value"},
|
||||||
|
model_name="model-a",
|
||||||
|
)
|
||||||
|
manager = RunManager(store=repo)
|
||||||
|
|
||||||
|
record = await manager.get("sql-store-only")
|
||||||
|
rows = await manager.list_by_thread("thread-1")
|
||||||
|
|
||||||
|
assert record is not None
|
||||||
|
assert record.run_id == "sql-store-only"
|
||||||
|
assert record.status == RunStatus.success
|
||||||
|
assert record.metadata == {"source": "sql"}
|
||||||
|
assert record.kwargs == {"input": "value"}
|
||||||
|
assert record.model_name == "model-a"
|
||||||
|
assert [run.run_id for run in rows] == ["sql-store-only"]
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_run_manager_cancel_persists_interrupted_status_to_sql(self, tmp_path):
|
||||||
|
"""RunManager.cancel should write interrupted status to SQL-backed store."""
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
manager = RunManager(store=repo)
|
||||||
|
record = await manager.create("thread-1")
|
||||||
|
await manager.set_status(record.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
cancelled = await manager.cancel(record.run_id)
|
||||||
|
row = await repo.get(record.run_id)
|
||||||
|
|
||||||
|
assert cancelled is True
|
||||||
|
assert row is not None
|
||||||
|
assert row["status"] == "interrupted"
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_model_name(self, tmp_path):
|
||||||
|
"""RunRepository.update_model_name should update model_name for existing run."""
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
await repo.put("r1", thread_id="t1", model_name="initial-model")
|
||||||
|
await repo.update_model_name("r1", "updated-model")
|
||||||
|
row = await repo.get("r1")
|
||||||
|
assert row["model_name"] == "updated-model"
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_model_name_normalizes_value(self, tmp_path):
|
||||||
|
"""RunRepository.update_model_name should normalize and truncate model_name."""
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
await repo.put("r1", thread_id="t1")
|
||||||
|
long_name = "a" * 200
|
||||||
|
await repo.update_model_name("r1", long_name)
|
||||||
|
row = await repo.get("r1")
|
||||||
|
assert row["model_name"] == "a" * 128
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_model_name_to_none(self, tmp_path):
|
||||||
|
"""RunRepository.update_model_name should allow setting model_name to None."""
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
await repo.put("r1", thread_id="t1", model_name="initial-model")
|
||||||
|
await repo.update_model_name("r1", None)
|
||||||
|
row = await repo.get("r1")
|
||||||
|
assert row["model_name"] is None
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_run_manager_update_model_name_persists_to_sql(self, tmp_path):
|
||||||
|
"""RunManager.update_model_name should persist to SQL-backed store without integrity error."""
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
manager = RunManager(store=repo)
|
||||||
|
record = await manager.create("thread-1")
|
||||||
|
|
||||||
|
await manager.update_model_name(record.run_id, "gpt-4o")
|
||||||
|
|
||||||
|
row = await repo.get(record.run_id)
|
||||||
|
assert row is not None
|
||||||
|
assert row["model_name"] == "gpt-4o"
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_run_manager_update_model_name_twice(self, tmp_path):
|
||||||
|
"""RunManager.update_model_name should support multiple updates."""
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
manager = RunManager(store=repo)
|
||||||
|
record = await manager.create("thread-1")
|
||||||
|
|
||||||
|
await manager.update_model_name(record.run_id, "model-1")
|
||||||
|
await manager.update_model_name(record.run_id, "model-2")
|
||||||
|
|
||||||
|
row = await repo.get(record.run_id)
|
||||||
|
assert row["model_name"] == "model-2"
|
||||||
|
await _cleanup()
|
||||||
|
|||||||
@@ -88,7 +88,9 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
|
|||||||
|
|
||||||
assert captured["factory_context"]["app_config"] is app_config
|
assert captured["factory_context"]["app_config"] is app_config
|
||||||
assert captured["astream_context"]["app_config"] is app_config
|
assert captured["astream_context"]["app_config"] is app_config
|
||||||
assert run_manager.get(record.run_id).status == RunStatus.success
|
fetched = await run_manager.get(record.run_id)
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.status == RunStatus.success
|
||||||
bridge.publish_end.assert_awaited_once_with(record.run_id)
|
bridge.publish_end.assert_awaited_once_with(record.run_id)
|
||||||
bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)
|
bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)
|
||||||
|
|
||||||
|
|||||||
@@ -2,25 +2,30 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
from _router_auth_helpers import make_authed_test_app
|
from _router_auth_helpers import make_authed_test_app
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from app.gateway.routers import thread_runs
|
from app.gateway.routers import thread_runs
|
||||||
|
from deerflow.runtime import RunManager
|
||||||
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _make_app(event_store=None):
|
def _make_app(event_store=None, run_manager=None):
|
||||||
"""Build a test FastAPI app with stub auth and mocked state."""
|
"""Build a test FastAPI app with stub auth and mocked state."""
|
||||||
app = make_authed_test_app()
|
app = make_authed_test_app()
|
||||||
app.include_router(thread_runs.router)
|
app.include_router(thread_runs.router)
|
||||||
|
|
||||||
if event_store is not None:
|
if event_store is not None:
|
||||||
app.state.run_event_store = event_store
|
app.state.run_event_store = event_store
|
||||||
|
if run_manager is not None:
|
||||||
|
app.state.run_manager = run_manager
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
@@ -36,6 +41,23 @@ def _make_message(seq: int) -> dict:
|
|||||||
return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"}
|
return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_store_only_run_manager() -> RunManager:
|
||||||
|
store = MemoryRunStore()
|
||||||
|
asyncio.run(
|
||||||
|
store.put(
|
||||||
|
"store-only-run",
|
||||||
|
thread_id="thread-store",
|
||||||
|
assistant_id="lead_agent",
|
||||||
|
status="running",
|
||||||
|
multitask_strategy="reject",
|
||||||
|
metadata={},
|
||||||
|
kwargs={},
|
||||||
|
created_at="2026-01-01T00:00:00+00:00",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return RunManager(store=store)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Tests
|
# Tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -128,3 +150,46 @@ def test_empty_data_when_no_messages():
|
|||||||
body = response.json()
|
body = response.json()
|
||||||
assert body["data"] == []
|
assert body["data"] == []
|
||||||
assert body["has_more"] is False
|
assert body["has_more"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_run_hydrates_store_only_run():
|
||||||
|
"""GET /api/threads/{tid}/runs/{rid} should read historical store rows."""
|
||||||
|
app = _make_app(run_manager=_make_store_only_run_manager())
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/threads/thread-store/runs/store-only-run")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["run_id"] == "store-only-run"
|
||||||
|
assert body["thread_id"] == "thread-store"
|
||||||
|
assert body["status"] == "running"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_store_only_run_returns_409():
|
||||||
|
"""Store-only runs are readable but not cancellable by this worker."""
|
||||||
|
app = _make_app(run_manager=_make_store_only_run_manager())
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/api/threads/thread-store/runs/store-only-run/cancel")
|
||||||
|
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert "not active on this worker" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_join_store_only_run_returns_409():
|
||||||
|
"""join endpoint should return 409 for store-only runs (no local stream state)."""
|
||||||
|
app = _make_app(run_manager=_make_store_only_run_manager())
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/threads/thread-store/runs/store-only-run/join")
|
||||||
|
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert "not active on this worker" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_store_only_run_returns_409():
|
||||||
|
"""stream endpoint (action=None) should return 409 for store-only runs."""
|
||||||
|
app = _make_app(run_manager=_make_store_only_run_manager())
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/threads/thread-store/runs/store-only-run/stream")
|
||||||
|
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert "not active on this worker" in response.json()["detail"]
|
||||||
|
|||||||
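The tests above pin down three read-path behaviours: store-only rows hydrate with `store_only=True`, a NULL status falls back to `pending`, and a live in-memory record takes precedence over the store row. The following is a minimal sketch of a `get()` with that shape, not the actual `RunManager` code. `SketchRecord`, `DictStore`, and `SketchRunManager` are hypothetical names introduced only for illustration, and the lock re-check after the store await is an assumption based on the commit description.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass


@dataclass
class SketchRecord:
    """Hypothetical stand-in for RunRecord; only the fields this sketch needs."""

    run_id: str
    status: str = "pending"
    store_only: bool = False


class DictStore:
    """Hypothetical async store backed by a plain dict of row dicts."""

    def __init__(self) -> None:
        self.rows: dict[str, dict] = {}

    async def get(self, run_id: str) -> dict | None:
        return self.rows.get(run_id)


class SketchRunManager:
    """Hypothetical manager showing hydrate-on-miss with a second check."""

    def __init__(self, store: DictStore | None = None) -> None:
        self._runs: dict[str, SketchRecord] = {}
        self._lock = asyncio.Lock()
        self._store = store

    async def create(self, run_id: str) -> SketchRecord:
        async with self._lock:
            record = SketchRecord(run_id)
            self._runs[run_id] = record
            return record

    async def get(self, run_id: str) -> SketchRecord | None:
        # Fast path: a live in-memory record always wins.
        async with self._lock:
            record = self._runs.get(run_id)
        if record is not None or self._store is None:
            return record
        try:
            row = await self._store.get(run_id)
        except Exception:
            # Best-effort read: a store failure degrades to "not found".
            return None
        # Re-check under the lock: a create() may have landed during the await.
        async with self._lock:
            live = self._runs.get(run_id)
        if live is not None:
            return live
        if row is None:
            return None
        # A NULL or missing status falls back to "pending".
        status = row.get("status") or "pending"
        return SketchRecord(run_id, status, store_only=True)
```

Callers can then branch on `store_only` to refuse operations that need local task or stream state, which is what the 409 gateway tests exercise.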