diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 35607c6fd..b951f919c 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -225,6 +225,12 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S | **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific | | **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id | +**RunManager / RunStore contract**: +- `RunManager.get()` is async; direct callers must `await` it. +- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs. `list_by_thread()` returns the merged runs newest-first by `created_at`, capped at `limit` (default 100). +- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions. +- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, the cancel, join, and stream APIs return 409 because this worker can neither stop the task nor attach to its stream. + Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs. 
### Sandbox System (`packages/harness/deerflow/sandbox/`) diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index 3d429fc03..294fa9799 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -22,7 +22,7 @@ from pydantic import BaseModel, Field from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.services import sse_consumer, start_run -from deerflow.runtime import RunRecord, serialize_channel_values +from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/threads", tags=["runs"]) @@ -94,6 +94,12 @@ class ThreadTokenUsageResponse(BaseModel): # --------------------------------------------------------------------------- +def _cancel_conflict_detail(run_id: str, record: RunRecord) -> str: + if record.status in (RunStatus.pending, RunStatus.running): + return f"Run {run_id} is not active on this worker and cannot be cancelled" + return f"Run {run_id} is not cancellable (status: {record.status.value})" + + def _record_to_response(record: RunRecord) -> RunResponse: return RunResponse( run_id=record.run_id, @@ -191,7 +197,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: """Get details of a specific run.""" run_mgr = get_run_manager(request) user_id = await get_current_user(request) - record = await run_mgr.aget(run_id, user_id=user_id) + record = await run_mgr.get(run_id, user_id=user_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") return _record_to_response(record) @@ -214,16 +220,13 @@ async def cancel_run( - wait=false: Return immediately with 202 """ run_mgr = get_run_manager(request) - record = 
run_mgr.get(run_id) + record = await run_mgr.get(run_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") cancelled = await run_mgr.cancel(run_id, action=action) if not cancelled: - raise HTTPException( - status_code=409, - detail=f"Run {run_id} is not cancellable (status: {record.status.value})", - ) + raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record)) if wait and record.task is not None: try: @@ -239,12 +242,14 @@ async def cancel_run( @require_permission("runs", "read", owner_check=True) async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse: """Join an existing run's SSE stream.""" - bridge = get_stream_bridge(request) run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) + record = await run_mgr.get(run_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + if record.store_only: + raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed") + bridge = get_stream_bridge(request) return StreamingResponse( sse_consumer(bridge, record, request, run_mgr), media_type="text/event-stream", @@ -273,14 +278,18 @@ async def stream_existing_run( remaining buffered events so the client observes a clean shutdown. 
""" run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) + record = await run_mgr.get(run_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + if record.store_only and action is None: + raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed") # Cancel if an action was requested (stop-button / interrupt flow) if action is not None: cancelled = await run_mgr.cancel(run_id, action=action) - if cancelled and wait and record.task is not None: + if not cancelled: + raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record)) + if wait and record.task is not None: try: await record.task except (asyncio.CancelledError, Exception): diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 5331451e3..d586a2b13 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -151,6 +151,11 @@ class RunRepository(RunStore): await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.commit() + async def update_model_name(self, run_id, model_name): + async with self._sf() as session: + await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(model_name=self._normalize_model_name(model_name), updated_at=datetime.now(UTC))) + await session.commit() + async def delete( self, run_id, diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 11d6b478e..06731eb91 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -6,7 +6,7 @@ import asyncio import logging import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing 
import TYPE_CHECKING, Any from deerflow.utils.time import now_iso as _now_iso @@ -37,6 +37,7 @@ class RunRecord: abort_action: str = "interrupt" error: str | None = None model_name: str | None = None + store_only: bool = False class RunManager: @@ -71,6 +72,38 @@ class RunManager: except Exception: logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) + async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: + """Best-effort persist a status transition to the backing store.""" + if self._store is None: + return + try: + await self._store.update_status(run_id, status.value, error=error) + except Exception: + logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) + + @staticmethod + def _record_from_store(row: dict[str, Any]) -> RunRecord: + """Build a read-only runtime record from a serialized store row. + + NULL status/on_disconnect columns (e.g. from rows written before those + columns were added) default to ``pending`` and ``cancel`` respectively. 
+ """ + return RunRecord( + run_id=row["run_id"], + thread_id=row["thread_id"], + assistant_id=row.get("assistant_id"), + status=RunStatus(row.get("status") or RunStatus.pending.value), + on_disconnect=DisconnectMode(row.get("on_disconnect") or DisconnectMode.cancel.value), + multitask_strategy=row.get("multitask_strategy") or "reject", + metadata=row.get("metadata") or {}, + kwargs=row.get("kwargs") or {}, + created_at=row.get("created_at") or "", + updated_at=row.get("updated_at") or "", + error=row.get("error"), + model_name=row.get("model_name"), + store_only=True, + ) + async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" if self._store is not None: @@ -110,61 +143,77 @@ class RunManager: logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record - def get(self, run_id: str) -> RunRecord | None: - """Return an in-memory run record by ID, or ``None``.""" - return self._runs.get(run_id) + async def get(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: + """Return a run record by ID, or ``None``. - async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: - """Return a run record by ID, checking the persistent store as fallback.""" - record = self._runs.get(run_id) + Args: + run_id: The run ID to look up. + user_id: Optional user ID for permission filtering when hydrating from store. 
+ """ + async with self._lock: + record = self._runs.get(run_id) if record is not None: return record - if self._store is not None: - try: - d = await self._store.get(run_id, user_id=user_id) - if d is not None: - return self._store_dict_to_record(d) - except Exception: - logger.warning("Failed to query store for run %s", run_id, exc_info=True) - return None - - def _store_dict_to_record(self, d: dict) -> RunRecord: - """Convert a store dict back to a RunRecord for read-only use.""" - return RunRecord( - run_id=d["run_id"], - thread_id=d["thread_id"], - assistant_id=d.get("assistant_id"), - status=RunStatus(d.get("status", RunStatus.error.value)), - on_disconnect=DisconnectMode.cancel, - multitask_strategy=d.get("multitask_strategy", "reject"), - metadata=d.get("metadata", {}), - kwargs=d.get("kwargs", {}), - created_at=d.get("created_at", ""), - updated_at=d.get("updated_at", ""), - model_name=d.get("model_name"), - error=d.get("error"), - ) - - async def list_by_thread(self, thread_id: str, *, user_id: str | None = None) -> list[RunRecord]: - """Return all runs for a given thread, oldest first.""" + if self._store is None: + return None + try: + row = await self._store.get(run_id, user_id=user_id) + except Exception: + logger.warning("Failed to hydrate run %s from store", run_id, exc_info=True) + return None + # Re-check after store await: a concurrent create() may have inserted the + # in-memory record while the store call was in flight. 
async with self._lock: - in_memory = [r for r in self._runs.values() if r.thread_id == thread_id] - in_memory_ids = {r.run_id for r in in_memory} + record = self._runs.get(run_id) + if record is not None: + return record + if row is None: + return None + try: + return self._record_from_store(row) + except Exception: + logger.warning("Failed to map store row for run %s", run_id, exc_info=True) + return None - store_records: list[RunRecord] = [] - if self._store is not None: - try: - store_dicts = await self._store.list_by_thread(thread_id, user_id=user_id) - for d in store_dicts: - if d["run_id"] not in in_memory_ids: - store_records.append(self._store_dict_to_record(d)) - except Exception: - logger.warning("Failed to query store for thread %s runs", thread_id, exc_info=True) + async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: + """Return a run record by ID, checking the persistent store as fallback. - return sorted( - in_memory + store_records, - key=lambda record: record.created_at or "", - ) + Alias for :meth:`get` for backward compatibility. + """ + return await self.get(run_id, user_id=user_id) + + async def list_by_thread(self, thread_id: str, *, user_id: str | None = None, limit: int = 100) -> list[RunRecord]: + """Return runs for a given thread, newest first, at most ``limit`` records. + + In-memory runs take precedence only when the same ``run_id`` exists in both + memory and the backing store. The merged result is then sorted newest-first + by ``created_at`` and trimmed to ``limit`` (default 100). + + Args: + thread_id: The thread ID to filter by. + user_id: Optional user ID for permission filtering when hydrating from store. + limit: Maximum number of runs to return. + """ + async with self._lock: + # Dict insertion order gives deterministic results when timestamps tie. 
+ memory_records = [r for r in self._runs.values() if r.thread_id == thread_id] + if self._store is None: + return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] + records_by_id = {record.run_id: record for record in memory_records} + store_limit = max(0, limit - len(memory_records)) + try: + rows = await self._store.list_by_thread(thread_id, user_id=user_id, limit=store_limit) + except Exception: + logger.warning("Failed to hydrate runs for thread %s from store", thread_id, exc_info=True) + return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] + for row in rows: + run_id = row.get("run_id") + if run_id and run_id not in records_by_id: + try: + records_by_id[run_id] = self._record_from_store(row) + except Exception: + logger.warning("Failed to map store row for run %s", run_id, exc_info=True) + return sorted(records_by_id.values(), key=lambda record: record.created_at, reverse=True)[:limit] async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: """Transition a run to a new status.""" @@ -177,13 +226,18 @@ class RunManager: record.updated_at = _now_iso() if error is not None: record.error = error - if self._store is not None: - try: - await self._store.update_status(run_id, status.value, error=error) - except Exception: - logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) + await self._persist_status(run_id, status, error=error) logger.info("Run %s -> %s", run_id, status.value) + async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: + """Best-effort persist model_name update to the backing store.""" + if self._store is None: + return + try: + await self._store.update_model_name(run_id, model_name) + except Exception: + logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) + async def update_model_name(self, run_id: str, model_name: str | None) -> None: """Update the model name for 
a run.""" async with self._lock: @@ -193,7 +247,7 @@ class RunManager: return record.model_name = model_name record.updated_at = _now_iso() - await self._persist_to_store(record) + await self._persist_model_name(run_id, model_name) logger.info("Run %s model_name=%s", run_id, model_name) async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: @@ -218,6 +272,7 @@ class RunManager: record.task.cancel() record.status = RunStatus.interrupted record.updated_at = _now_iso() + await self._persist_status(run_id, RunStatus.interrupted) logger.info("Run %s cancelled (action=%s)", run_id, action) return True @@ -245,6 +300,7 @@ class RunManager: now = _now_iso() _supported_strategies = ("reject", "interrupt", "rollback") + interrupted_run_ids: list[str] = [] async with self._lock: if multitask_strategy not in _supported_strategies: @@ -263,6 +319,7 @@ class RunManager: r.task.cancel() r.status = RunStatus.interrupted r.updated_at = now + interrupted_run_ids.append(r.run_id) logger.info( "Cancelled %d inflight run(s) on thread %s (strategy=%s)", len(inflight), @@ -285,6 +342,8 @@ class RunManager: ) self._runs[run_id] = record + for interrupted_run_id in interrupted_run_ids: + await self._persist_status(interrupted_run_id, RunStatus.interrupted) await self._persist_to_store(record) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index a742d89ca..10c90d7ea 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -66,6 +66,15 @@ class RunStore(abc.ABC): async def delete(self, run_id: str) -> None: pass + @abc.abstractmethod + async def update_model_name( + self, + run_id: str, + model_name: str | None, + ) -> None: + """Update the model_name field for an existing run.""" + pass + @abc.abstractmethod async 
def update_run_completion( self, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 9db27cacc..56ef02b5b 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -66,6 +66,11 @@ class MemoryRunStore(RunStore): self._runs[run_id]["error"] = error self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def update_model_name(self, run_id, model_name): + if run_id in self._runs: + self._runs[run_id]["model_name"] = model_name + self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def delete(self, run_id): self._runs.pop(run_id, None) diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index de8f66319..e7b5f06f5 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -4,7 +4,7 @@ import re import pytest -from deerflow.runtime import RunManager, RunStatus +from deerflow.runtime import DisconnectMode, RunManager, RunStatus from deerflow.runtime.runs.store.memory import MemoryRunStore ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -34,7 +34,7 @@ async def test_create_and_get(manager: RunManager): assert ISO_RE.match(record.created_at) assert ISO_RE.match(record.updated_at) - fetched = manager.get(record.run_id) + fetched = await manager.get(record.run_id) assert fetched is record @@ -64,6 +64,22 @@ async def test_cancel(manager: RunManager): assert record.status == RunStatus.interrupted +@pytest.mark.anyio +async def test_cancel_persists_interrupted_status_to_store(): + """Cancel should persist interrupted status to the backing store.""" + store = MemoryRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.running) + + cancelled = await manager.cancel(record.run_id) + + stored = await 
store.get(record.run_id) + assert cancelled is True + assert stored is not None + assert stored["status"] == "interrupted" + + @pytest.mark.anyio async def test_cancel_not_inflight(manager: RunManager): """Cancelling a completed run should return False.""" @@ -83,9 +99,9 @@ async def test_list_by_thread(manager: RunManager): runs = await manager.list_by_thread("thread-1") assert len(runs) == 2 - # list_by_thread returns oldest-first (ascending created_at). - assert runs[0].run_id == r1.run_id - assert runs[1].run_id == r2.run_id + # Newest first: r2 was created after r1. + assert runs[0].run_id == r2.run_id + assert runs[1].run_id == r1.run_id @pytest.mark.anyio @@ -117,7 +133,7 @@ async def test_cleanup(manager: RunManager): run_id = record.run_id await manager.cleanup(run_id, delay=0) - assert manager.get(run_id) is None + assert await manager.get(run_id) is None @pytest.mark.anyio @@ -132,7 +148,116 @@ async def test_set_status_with_error(manager: RunManager): @pytest.mark.anyio async def test_get_nonexistent(manager: RunManager): """Getting a nonexistent run should return None.""" - assert manager.get("does-not-exist") is None + assert await manager.get("does-not-exist") is None + + +@pytest.mark.anyio +async def test_get_hydrates_store_only_run(): + """Store-only runs should be readable after process restart.""" + store = MemoryRunStore() + await store.put( + "run-store-only", + thread_id="thread-1", + assistant_id="lead_agent", + status="success", + multitask_strategy="reject", + metadata={"source": "store"}, + kwargs={"input": "value"}, + created_at="2026-01-01T00:00:00+00:00", + model_name="model-a", + ) + manager = RunManager(store=store) + + record = await manager.get("run-store-only") + + assert record is not None + assert record.run_id == "run-store-only" + assert record.thread_id == "thread-1" + assert record.assistant_id == "lead_agent" + assert record.status == RunStatus.success + assert record.on_disconnect == DisconnectMode.cancel + assert 
record.metadata == {"source": "store"} + assert record.kwargs == {"input": "value"} + assert record.model_name == "model-a" + assert record.task is None + assert record.store_only is True + + +@pytest.mark.anyio +async def test_get_hydrates_run_with_null_enum_fields(): + """Rows with NULL status/on_disconnect must hydrate with safe defaults, not raise.""" + store = MemoryRunStore() + # Simulate a SQL row where the nullable status column is NULL + await store.put( + "run-null-status", + thread_id="thread-1", + status=None, + created_at="2026-01-01T00:00:00+00:00", + ) + manager = RunManager(store=store) + + record = await manager.get("run-null-status") + + assert record is not None + assert record.status == RunStatus.pending + assert record.on_disconnect == DisconnectMode.cancel + assert record.store_only is True + + +@pytest.mark.anyio +async def test_list_by_thread_hydrates_run_with_null_enum_fields(): + """list_by_thread must not skip rows with NULL status; applies safe defaults.""" + store = MemoryRunStore() + await store.put( + "run-null-status-list", + thread_id="thread-null", + status=None, + created_at="2026-01-01T00:00:00+00:00", + ) + manager = RunManager(store=store) + + runs = await manager.list_by_thread("thread-null") + + assert len(runs) == 1 + assert runs[0].run_id == "run-null-status-list" + assert runs[0].status == RunStatus.pending + assert runs[0].on_disconnect == DisconnectMode.cancel + + +@pytest.mark.anyio +async def test_create_record_is_not_store_only(manager: RunManager): + """In-memory records created via create() must have store_only=False.""" + record = await manager.create("thread-1") + assert record.store_only is False + + +@pytest.mark.anyio +async def test_get_prefers_in_memory_record_over_store(): + """In-memory records retain task/control state when store has same run.""" + store = MemoryRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await store.update_status(record.run_id, "success") + 
+ fetched = await manager.get(record.run_id) + + assert fetched is record + assert fetched.status == RunStatus.pending + + +@pytest.mark.anyio +async def test_list_by_thread_merges_store_runs_newest_first(): + """list_by_thread should merge memory and store rows with memory precedence.""" + store = MemoryRunStore() + await store.put("old-store", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:00+00:00") + await store.put("other-thread", thread_id="thread-2", status="success", created_at="2026-01-03T00:00:00+00:00") + manager = RunManager(store=store) + memory_record = await manager.create("thread-1") + + runs = await manager.list_by_thread("thread-1") + + assert [run.run_id for run in runs] == [memory_record.run_id, "old-store"] + assert runs[0] is memory_record @pytest.mark.anyio @@ -171,11 +296,45 @@ async def test_model_name_create_or_reject(): assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0" # Verify retrieval returns the model_name via in-memory record - fetched = mgr.get(record.run_id) + fetched = await mgr.get(record.run_id) assert fetched is not None assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" +@pytest.mark.anyio +async def test_create_or_reject_interrupt_persists_interrupted_status_to_store(): + """interrupt strategy should persist interrupted status for old runs.""" + store = MemoryRunStore() + manager = RunManager(store=store) + old = await manager.create("thread-1") + await manager.set_status(old.run_id, RunStatus.running) + + new = await manager.create_or_reject("thread-1", multitask_strategy="interrupt") + + stored_old = await store.get(old.run_id) + assert new.run_id != old.run_id + assert old.status == RunStatus.interrupted + assert stored_old is not None + assert stored_old["status"] == "interrupted" + + +@pytest.mark.anyio +async def test_create_or_reject_rollback_persists_interrupted_status_to_store(): + """rollback strategy should persist interrupted status for old runs.""" 
+ store = MemoryRunStore() + manager = RunManager(store=store) + old = await manager.create("thread-1") + await manager.set_status(old.run_id, RunStatus.running) + + new = await manager.create_or_reject("thread-1", multitask_strategy="rollback") + + stored_old = await store.get(old.run_id) + assert new.run_id != old.run_id + assert old.status == RunStatus.interrupted + assert stored_old is not None + assert stored_old["status"] == "interrupted" + + @pytest.mark.anyio async def test_model_name_default_is_none(): """create_or_reject without model_name should default to None.""" diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 5e230e790..5809db517 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -9,6 +9,7 @@ import pytest from sqlalchemy.dialects import postgresql from deerflow.persistence.run import RunRepository +from deerflow.runtime import RunManager, RunStatus async def _make_repo(tmp_path): @@ -326,3 +327,105 @@ class TestRunRepository: assert select_match is not None assert group_by_match is not None assert select_match.group(1) == group_by_match.group(1) + + @pytest.mark.anyio + async def test_run_manager_hydrates_store_only_run_from_sql(self, tmp_path): + """RunManager should hydrate historical runs from SQL-backed store.""" + repo = await _make_repo(tmp_path) + await repo.put( + "sql-store-only", + thread_id="thread-1", + assistant_id="lead_agent", + status="success", + metadata={"source": "sql"}, + kwargs={"input": "value"}, + model_name="model-a", + ) + manager = RunManager(store=repo) + + record = await manager.get("sql-store-only") + rows = await manager.list_by_thread("thread-1") + + assert record is not None + assert record.run_id == "sql-store-only" + assert record.status == RunStatus.success + assert record.metadata == {"source": "sql"} + assert record.kwargs == {"input": "value"} + assert record.model_name == "model-a" + assert [run.run_id for run in rows] 
== ["sql-store-only"] + await _cleanup() + + @pytest.mark.anyio + async def test_run_manager_cancel_persists_interrupted_status_to_sql(self, tmp_path): + """RunManager.cancel should write interrupted status to SQL-backed store.""" + repo = await _make_repo(tmp_path) + manager = RunManager(store=repo) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.running) + + cancelled = await manager.cancel(record.run_id) + row = await repo.get(record.run_id) + + assert cancelled is True + assert row is not None + assert row["status"] == "interrupted" + await _cleanup() + + @pytest.mark.anyio + async def test_update_model_name(self, tmp_path): + """RunRepository.update_model_name should update model_name for existing run.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", model_name="initial-model") + await repo.update_model_name("r1", "updated-model") + row = await repo.get("r1") + assert row["model_name"] == "updated-model" + await _cleanup() + + @pytest.mark.anyio + async def test_update_model_name_normalizes_value(self, tmp_path): + """RunRepository.update_model_name should normalize and truncate model_name.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1") + long_name = "a" * 200 + await repo.update_model_name("r1", long_name) + row = await repo.get("r1") + assert row["model_name"] == "a" * 128 + await _cleanup() + + @pytest.mark.anyio + async def test_update_model_name_to_none(self, tmp_path): + """RunRepository.update_model_name should allow setting model_name to None.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", model_name="initial-model") + await repo.update_model_name("r1", None) + row = await repo.get("r1") + assert row["model_name"] is None + await _cleanup() + + @pytest.mark.anyio + async def test_run_manager_update_model_name_persists_to_sql(self, tmp_path): + """RunManager.update_model_name should persist to SQL-backed store 
without integrity error.""" + repo = await _make_repo(tmp_path) + manager = RunManager(store=repo) + record = await manager.create("thread-1") + + await manager.update_model_name(record.run_id, "gpt-4o") + + row = await repo.get(record.run_id) + assert row is not None + assert row["model_name"] == "gpt-4o" + await _cleanup() + + @pytest.mark.anyio + async def test_run_manager_update_model_name_twice(self, tmp_path): + """RunManager.update_model_name should support multiple updates.""" + repo = await _make_repo(tmp_path) + manager = RunManager(store=repo) + record = await manager.create("thread-1") + + await manager.update_model_name(record.run_id, "model-1") + await manager.update_model_name(record.run_id, "model-2") + + row = await repo.get(record.run_id) + assert row["model_name"] == "model-2" + await _cleanup() diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py index 0a4421e2f..72e3ac98e 100644 --- a/backend/tests/test_run_worker_rollback.py +++ b/backend/tests/test_run_worker_rollback.py @@ -88,7 +88,9 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory(): assert captured["factory_context"]["app_config"] is app_config assert captured["astream_context"]["app_config"] is app_config - assert run_manager.get(record.run_id).status == RunStatus.success + fetched = await run_manager.get(record.run_id) + assert fetched is not None + assert fetched.status == RunStatus.success bridge.publish_end.assert_awaited_once_with(record.run_id) bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60) diff --git a/backend/tests/test_thread_run_messages_pagination.py b/backend/tests/test_thread_run_messages_pagination.py index 00e354a34..9098e2b73 100644 --- a/backend/tests/test_thread_run_messages_pagination.py +++ b/backend/tests/test_thread_run_messages_pagination.py @@ -2,25 +2,30 @@ from __future__ import annotations +import asyncio from unittest.mock import AsyncMock, MagicMock from 
_router_auth_helpers import make_authed_test_app from fastapi.testclient import TestClient from app.gateway.routers import thread_runs +from deerflow.runtime import RunManager +from deerflow.runtime.runs.store.memory import MemoryRunStore # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _make_app(event_store=None): +def _make_app(event_store=None, run_manager=None): """Build a test FastAPI app with stub auth and mocked state.""" app = make_authed_test_app() app.include_router(thread_runs.router) if event_store is not None: app.state.run_event_store = event_store + if run_manager is not None: + app.state.run_manager = run_manager return app @@ -36,6 +41,23 @@ def _make_message(seq: int) -> dict: return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"} +def _make_store_only_run_manager() -> RunManager: + store = MemoryRunStore() + asyncio.run( + store.put( + "store-only-run", + thread_id="thread-store", + assistant_id="lead_agent", + status="running", + multitask_strategy="reject", + metadata={}, + kwargs={}, + created_at="2026-01-01T00:00:00+00:00", + ) + ) + return RunManager(store=store) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -128,3 +150,46 @@ def test_empty_data_when_no_messages(): body = response.json() assert body["data"] == [] assert body["has_more"] is False + + +def test_get_run_hydrates_store_only_run(): + """GET /api/threads/{tid}/runs/{rid} should read historical store rows.""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.get("/api/threads/thread-store/runs/store-only-run") + + assert response.status_code == 200 + body = response.json() + assert body["run_id"] == "store-only-run" + assert 
body["thread_id"] == "thread-store" + assert body["status"] == "running" + + +def test_cancel_store_only_run_returns_409(): + """Store-only runs are readable but not cancellable by this worker.""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.post("/api/threads/thread-store/runs/store-only-run/cancel") + + assert response.status_code == 409 + assert "not active on this worker" in response.json()["detail"] + + +def test_join_store_only_run_returns_409(): + """join endpoint should return 409 for store-only runs (no local stream state).""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.get("/api/threads/thread-store/runs/store-only-run/join") + + assert response.status_code == 409 + assert "not active on this worker" in response.json()["detail"] + + +def test_stream_store_only_run_returns_409(): + """stream endpoint (action=None) should return 409 for store-only runs.""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.get("/api/threads/thread-store/runs/store-only-run/stream") + + assert response.status_code == 409 + assert "not active on this worker" in response.json()["detail"]