mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 07:01:03 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e60621d519 | |||
| f7a6ca8364 | |||
| 2540acd5f7 | |||
| b2704525a0 | |||
| 00e0e9a49a |
+8
-20
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config
|
||||
|
||||
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||
|
||||
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
||||
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
|
||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||
@@ -216,9 +216,6 @@ FastAPI application on port 8001 with health check at `GET /health`.
|
||||
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
||||
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
||||
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
|
||||
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
||||
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
||||
|
||||
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
||||
|
||||
@@ -232,7 +229,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
|
||||
**Virtual Path System**:
|
||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||
- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
||||
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
||||
|
||||
@@ -272,7 +269,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
|
||||
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
|
||||
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
|
||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||
- `image_search/` - Image search via DuckDuckGo
|
||||
|
||||
### MCP System (`packages/harness/deerflow/mcp/`)
|
||||
@@ -341,27 +338,18 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
|
||||
|
||||
**Components**:
|
||||
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
|
||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
|
||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
|
||||
- `prompt.py` - Prompt templates for memory updates
|
||||
- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
|
||||
|
||||
**Per-User Isolation**:
|
||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||
- Absolute `storage_path` in config opts out of per-user isolation
|
||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||
|
||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||
**Data Structure** (stored in `backend/.deer-flow/memory.json`):
|
||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
|
||||
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
|
||||
|
||||
**Workflow**:
|
||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
|
||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
|
||||
2. Queue debounces (30s default), batches updates, deduplicates per-thread
|
||||
3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
|
||||
3. Background thread invokes LLM to extract context updates and facts
|
||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||
|
||||
@@ -369,7 +357,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
||||
|
||||
**Configuration** (`config.yaml` → `memory`):
|
||||
- `enabled` / `injection_enabled` - Master switches
|
||||
- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
|
||||
- `storage_path` - Path to memory.json
|
||||
- `debounce_seconds` - Wait time before processing (default: 30)
|
||||
- `model_name` - LLM for updates (null = default model)
|
||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.channels.base import Channel
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -345,9 +344,8 @@ class FeishuChannel(Channel):
|
||||
return f"Failed to obtain the [{type}]"
|
||||
|
||||
paths = get_paths()
|
||||
user_id = get_effective_user_id()
|
||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
|
||||
|
||||
ext = "png" if type == "image" else "bin"
|
||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||
|
||||
@@ -17,7 +17,6 @@ from langgraph_sdk.errors import ConflictError
|
||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.store import ChannelStore
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -342,15 +341,14 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
|
||||
|
||||
attachments: list[ResolvedAttachment] = []
|
||||
paths = get_paths()
|
||||
user_id = get_effective_user_id()
|
||||
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
|
||||
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
|
||||
for virtual_path in artifacts:
|
||||
# Security: only allow files from the agent outputs directory
|
||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||
continue
|
||||
try:
|
||||
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
|
||||
actual = paths.resolve_virtual_path(thread_id, virtual_path)
|
||||
# Verify the resolved path is actually under the outputs directory
|
||||
# (guards against path-traversal even after prefix check)
|
||||
try:
|
||||
|
||||
+53
-49
@@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -40,69 +41,77 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||
"""Startup hook: generate init token on first boot; migrate orphan threads otherwise.
|
||||
"""Auto-create the admin user on first boot if no users exist.
|
||||
|
||||
After admin creation, migrate orphan threads from the LangGraph
|
||||
store (metadata.user_id unset) to the admin account. This is the
|
||||
store (metadata.owner_id unset) to the admin account. This is the
|
||||
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
||||
authentication have existing LangGraph thread data that needs an
|
||||
owner assigned.
|
||||
First boot (no admin exists):
|
||||
- Generates a one-time ``init_token`` stored in ``app.state.init_token``
|
||||
- Logs the token to stdout so the operator can copy-paste it into the
|
||||
``/setup`` form to create the first admin account interactively.
|
||||
- Does NOT create any user accounts automatically.
|
||||
|
||||
Subsequent boots (admin already exists):
|
||||
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
||||
existing LangGraph thread metadata that has no owner_id.
|
||||
|
||||
No SQL persistence migration is needed: the four user_id columns
|
||||
No SQL persistence migration is needed: the four owner_id columns
|
||||
(threads_meta, runs, run_events, feedback) only come into existence
|
||||
alongside the auth module via create_all, so freshly created tables
|
||||
never contain NULL-owner rows.
|
||||
never contain NULL-owner rows. "Existing persistence DB + new auth"
|
||||
is not a supported upgrade path — fresh install or wipe-and-retry.
|
||||
|
||||
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve
|
||||
races during admin creation. Only the worker that successfully
|
||||
creates/updates the admin prints the password; losers silently skip.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.gateway.auth.credential_file import write_initial_credentials
|
||||
from app.gateway.deps import get_local_provider
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
def _announce_credentials(email: str, password: str, *, label: str, headline: str) -> None:
|
||||
"""Write the password to a 0600 file and log the path (never the secret)."""
|
||||
cred_path = write_initial_credentials(email, password, label=label)
|
||||
logger.info("=" * 60)
|
||||
logger.info(" %s", headline)
|
||||
logger.info(" Credentials written to: %s (mode 0600)", cred_path)
|
||||
logger.info(" Change it after login: Settings -> Account")
|
||||
logger.info("=" * 60)
|
||||
|
||||
provider = get_local_provider()
|
||||
admin_count = await provider.count_admin_users()
|
||||
user_count = await provider.count_users()
|
||||
|
||||
if admin_count == 0:
|
||||
init_token = secrets.token_urlsafe(32)
|
||||
app.state.init_token = init_token
|
||||
logger.info("=" * 60)
|
||||
logger.info(" First boot detected — no admin account exists.")
|
||||
logger.info(" Use the one-time token below to create the admin account.")
|
||||
logger.info(" Copy it into the /setup form when prompted.")
|
||||
logger.info(" INIT TOKEN: %s", init_token)
|
||||
logger.info(" Visit /setup to complete admin account creation.")
|
||||
logger.info("=" * 60)
|
||||
return
|
||||
admin = None
|
||||
|
||||
# Admin already exists — run orphan thread migration for any
|
||||
# LangGraph thread metadata that pre-dates the auth module.
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
return
|
||||
if user_count == 0:
|
||||
password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
|
||||
except ValueError:
|
||||
return # Another worker already created the admin.
|
||||
_announce_credentials(admin.email, password, label="initial", headline="Admin account created on first boot")
|
||||
else:
|
||||
# Admin exists but setup never completed — reset password so operator
|
||||
# can always find it in the console without needing the CLI.
|
||||
# Multi-worker guard: if admin was created less than 30s ago, another
|
||||
# worker just created it and will print the password — skip reset.
|
||||
admin = await provider.get_user_by_email("admin@deerflow.dev")
|
||||
if admin and admin.needs_setup:
|
||||
import time
|
||||
|
||||
async with sf() as session:
|
||||
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
|
||||
if age >= 30:
|
||||
from app.gateway.auth.password import hash_password_async
|
||||
|
||||
if row is None:
|
||||
return # Should not happen (admin_count > 0 above), but be safe.
|
||||
password = secrets.token_urlsafe(16)
|
||||
admin.password_hash = await hash_password_async(password)
|
||||
admin.token_version += 1
|
||||
await provider.update_user(admin)
|
||||
_announce_credentials(admin.email, password, label="reset", headline="Admin account setup incomplete — password reset")
|
||||
|
||||
admin_id = str(row.id)
|
||||
if admin is None:
|
||||
return # Nothing to bind orphans to.
|
||||
|
||||
admin_id = str(admin.id)
|
||||
|
||||
# LangGraph store orphan migration — non-fatal.
|
||||
# This covers the "no-auth → with-auth" upgrade path for users
|
||||
# whose existing LangGraph thread metadata has no user_id set.
|
||||
# whose existing LangGraph thread metadata has no owner_id set.
|
||||
store = getattr(app.state, "store", None)
|
||||
if store is not None:
|
||||
try:
|
||||
@@ -134,7 +143,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500):
|
||||
|
||||
|
||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
"""Migrate LangGraph store threads with no user_id to the given admin.
|
||||
"""Migrate LangGraph store threads with no owner_id to the given admin.
|
||||
|
||||
Uses cursor pagination so all orphans are migrated regardless of
|
||||
count. Returns the number of rows migrated.
|
||||
@@ -142,8 +151,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||
migrated = 0
|
||||
async for item in _iter_store_items(store, ("threads",)):
|
||||
metadata = item.value.get("metadata", {})
|
||||
if not metadata.get("user_id"):
|
||||
metadata["user_id"] = admin_user_id
|
||||
if not metadata.get("owner_id"):
|
||||
metadata["owner_id"] = admin_user_id
|
||||
item.value["metadata"] = metadata
|
||||
await store.aput(("threads",), item.key, item.value)
|
||||
migrated += 1
|
||||
@@ -365,11 +374,6 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
||||
"""
|
||||
return {"status": "healthy", "service": "deer-flow-gateway"}
|
||||
|
||||
# Ensure init_token always exists on app.state (None until lifespan sets it
|
||||
# if no admin is found). This prevents AttributeError in tests that don't
|
||||
# run the full lifespan.
|
||||
app.state.init_token = None
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -20,8 +20,6 @@ class AuthErrorCode(StrEnum):
|
||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||
NOT_AUTHENTICATED = "not_authenticated"
|
||||
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
||||
INVALID_INIT_TOKEN = "invalid_init_token"
|
||||
|
||||
|
||||
class TokenError(StrEnum):
|
||||
|
||||
@@ -78,10 +78,6 @@ class LocalAuthProvider(AuthProvider):
|
||||
"""Return total number of registered users."""
|
||||
return await self._repo.count_users()
|
||||
|
||||
async def count_admin_users(self) -> int:
|
||||
"""Return number of admin users."""
|
||||
return await self._repo.count_admin_users()
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user."""
|
||||
return await self._repo.update_user(user)
|
||||
|
||||
@@ -83,11 +83,6 @@ class UserRepository(ABC):
|
||||
"""Return total number of registered users."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count_admin_users(self) -> int:
|
||||
"""Return number of users with system_role == 'admin'."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID.
|
||||
|
||||
@@ -114,11 +114,6 @@ class SQLiteUserRepository(UserRepository):
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def count_admin_users(self) -> int:
|
||||
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
||||
async with self._sf() as session:
|
||||
|
||||
@@ -36,7 +36,6 @@ _PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
"/api/v1/auth/initialize",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -233,18 +233,18 @@ def require_permission(
|
||||
# (``threads_meta`` table). We verify ownership via
|
||||
# ``ThreadMetaStore.check_access``: it returns True for
|
||||
# missing rows (untracked legacy thread) and for rows whose
|
||||
# ``user_id`` is NULL (shared / pre-auth data), so this is
|
||||
# ``owner_id`` is NULL (shared / pre-auth data), so this is
|
||||
# strict-deny rather than strict-allow — only an *existing*
|
||||
# row with a *different* user_id triggers 404.
|
||||
# row with a *different* owner_id triggers 404.
|
||||
if owner_check:
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
from app.gateway.deps import get_thread_store
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
thread_store = get_thread_store(request)
|
||||
allowed = await thread_store.check_access(
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
allowed = await thread_meta_repo.check_access(
|
||||
thread_id,
|
||||
str(auth.user.id),
|
||||
require_existing=require_existing,
|
||||
|
||||
@@ -48,7 +48,6 @@ _AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/initialize",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
+10
-16
@@ -1,7 +1,8 @@
|
||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
||||
``None``.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
@@ -19,7 +20,6 @@ from deerflow.runtime import RunContext, RunManager
|
||||
if TYPE_CHECKING:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -31,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime.events.store import make_run_event_store
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
@@ -53,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
if sf is not None:
|
||||
from deerflow.persistence.feedback import FeedbackRepository
|
||||
from deerflow.persistence.run import RunRepository
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
app.state.run_store = RunRepository(sf)
|
||||
app.state.feedback_repo = FeedbackRepository(sf)
|
||||
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
||||
else:
|
||||
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
app.state.run_store = MemoryRunStore()
|
||||
app.state.feedback_repo = None
|
||||
|
||||
from deerflow.persistence.thread_meta import make_thread_store
|
||||
|
||||
app.state.thread_store = make_thread_store(sf, app.state.store)
|
||||
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
|
||||
|
||||
# Run event store (has its own factory with config-driven backend selection)
|
||||
run_events_config = getattr(config, "run_events", None)
|
||||
@@ -80,7 +80,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters – called by routers per-request
|
||||
# Getters -- called by routers per-request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -110,12 +110,7 @@ def get_store(request: Request):
|
||||
return getattr(request.app.state, "store", None)
|
||||
|
||||
|
||||
def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||
"""Return the thread metadata store (SQL or memory-backed)."""
|
||||
val = getattr(request.app.state, "thread_store", None)
|
||||
if val is None:
|
||||
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
||||
return val
|
||||
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
|
||||
|
||||
|
||||
def get_run_context(request: Request) -> RunContext:
|
||||
@@ -133,11 +128,10 @@ def get_run_context(request: Request) -> RunContext:
|
||||
store=get_store(request),
|
||||
event_store=get_run_event_store(request),
|
||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||
thread_store=get_thread_store(request),
|
||||
thread_meta_repo=get_thread_meta_repo(request),
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth helpers (used by authz.py and auth middleware)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -93,14 +93,14 @@ async def authenticate(request):
|
||||
|
||||
@auth.on
|
||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||
"""Inject user_id metadata on writes; filter by user_id on reads.
|
||||
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
||||
|
||||
Gateway stores thread ownership as ``metadata.user_id``.
|
||||
Gateway stores thread ownership as ``metadata.owner_id``.
|
||||
This handler ensures LangGraph Server enforces the same isolation.
|
||||
"""
|
||||
# On create/update: stamp user_id into metadata
|
||||
# On create/update: stamp owner_id into metadata
|
||||
metadata = value.setdefault("metadata", {})
|
||||
metadata["user_id"] = ctx.user.identity
|
||||
metadata["owner_id"] = ctx.user.identity
|
||||
|
||||
# Return filter dict — LangGraph applies it to search/read/delete
|
||||
return {"user_id": ctx.user.identity}
|
||||
return {"owner_id": ctx.user.identity}
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from fastapi import HTTPException
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
|
||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
@@ -23,7 +22,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
HTTPException: If the path is invalid or outside allowed directories.
|
||||
"""
|
||||
try:
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path)
|
||||
except ValueError as e:
|
||||
status = 403 if "traversal" in str(e) else 400
|
||||
raise HTTPException(status_code=status, detail=str(e))
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from ipaddress import ip_address, ip_network
|
||||
|
||||
@@ -379,74 +378,9 @@ async def get_me(request: Request):
|
||||
|
||||
@router.get("/setup-status")
|
||||
async def setup_status():
|
||||
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
|
||||
admin_count = await get_local_provider().count_admin_users()
|
||||
return {"needs_setup": admin_count == 0}
|
||||
|
||||
|
||||
class InitializeAdminRequest(BaseModel):
|
||||
"""Request model for first-boot admin account creation."""
|
||||
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
init_token: str | None = Field(default=None, description="One-time initialization token printed to server logs on first boot")
|
||||
|
||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||
|
||||
|
||||
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
|
||||
"""Create the first admin account on initial system setup.
|
||||
|
||||
Only callable when no admin exists. Returns 409 Conflict if an admin
|
||||
already exists. Requires the one-time ``init_token`` that is logged to
|
||||
stdout at startup whenever the system has no admin account.
|
||||
|
||||
On success the token is consumed (one-time use), the admin account is
|
||||
created with ``needs_setup=False``, and the session cookie is set.
|
||||
"""
|
||||
# Validate the one-time initialization token. The token is generated
|
||||
# at startup and stored in app.state.init_token; it is consumed here on
|
||||
# the first successful call so it cannot be replayed.
|
||||
# Using str | None allows a missing/null token to return 403 (not 422),
|
||||
# giving a consistent error response regardless of whether the token is
|
||||
# absent or incorrect.
|
||||
stored_token: str | None = getattr(request.app.state, "init_token", None)
|
||||
provided_token: str = body.init_token or ""
|
||||
if stored_token is None or not secrets.compare_digest(stored_token, provided_token):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_INIT_TOKEN, message="Invalid or expired initialization token").model_dump(),
|
||||
)
|
||||
|
||||
admin_count = await get_local_provider().count_admin_users()
|
||||
if admin_count > 0:
|
||||
# Do NOT consume the token on this error path — consuming it here
|
||||
# would allow an attacker to exhaust the token by calling with the
|
||||
# correct token when admin already exists (denial-of-service).
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
||||
)
|
||||
|
||||
try:
|
||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
|
||||
except ValueError:
|
||||
# DB unique-constraint race: another concurrent request beat us.
|
||||
# Do NOT consume the token here for the same reason as above.
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
||||
)
|
||||
|
||||
# Consume the token only after successful initialization — this is the
|
||||
# single place where one-time use is enforced.
|
||||
request.app.state.init_token = None
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||
"""Check if admin account exists. Always False after first boot."""
|
||||
user_count = await get_local_provider().count_users()
|
||||
return {"needs_setup": user_count == 0}
|
||||
|
||||
|
||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||
|
||||
@@ -30,16 +30,11 @@ class FeedbackCreateRequest(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||
|
||||
|
||||
class FeedbackUpsertRequest(BaseModel):
|
||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
||||
|
||||
|
||||
class FeedbackResponse(BaseModel):
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
user_id: str | None = None
|
||||
owner_id: str | None = None
|
||||
message_id: str | None = None
|
||||
rating: int
|
||||
comment: str | None = None
|
||||
@@ -58,57 +53,6 @@ class FeedbackStatsResponse(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
async def upsert_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
body: FeedbackUpsertRequest,
|
||||
request: Request,
|
||||
) -> dict[str, Any]:
|
||||
"""Create or update feedback for a run (idempotent)."""
|
||||
if body.rating not in (1, -1):
|
||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||
|
||||
user_id = await get_current_user(request)
|
||||
|
||||
run_store = get_run_store(request)
|
||||
run = await run_store.get(run_id)
|
||||
if run is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
if run.get("thread_id") != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
return await feedback_repo.upsert(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
user_id=user_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
||||
async def delete_run_feedback(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
) -> dict[str, bool]:
|
||||
"""Delete the current user's feedback for a run."""
|
||||
user_id = await get_current_user(request)
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
deleted = await feedback_repo.delete_by_run(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="No feedback found for this run")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
async def create_feedback(
|
||||
@@ -136,7 +80,7 @@ async def create_feedback(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=body.rating,
|
||||
user_id=user_id,
|
||||
owner_id=user_id,
|
||||
message_id=body.message_id,
|
||||
comment=body.comment,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from deerflow.agents.memory.updater import (
|
||||
update_memory_fact,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["memory"])
|
||||
|
||||
@@ -148,7 +147,7 @@ async def get_memory() -> MemoryResponse:
|
||||
}
|
||||
```
|
||||
"""
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
memory_data = get_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@@ -168,7 +167,7 @@ async def reload_memory() -> MemoryResponse:
|
||||
Returns:
|
||||
The reloaded memory data.
|
||||
"""
|
||||
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
||||
memory_data = reload_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@@ -182,7 +181,7 @@ async def reload_memory() -> MemoryResponse:
|
||||
async def clear_memory() -> MemoryResponse:
|
||||
"""Clear all persisted memory data."""
|
||||
try:
|
||||
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
||||
memory_data = clear_memory_data()
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||
|
||||
@@ -203,7 +202,6 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
user_id=get_effective_user_id(),
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
@@ -223,7 +221,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||
"""Delete a single fact from memory by fact id."""
|
||||
try:
|
||||
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
||||
memory_data = delete_memory_fact(fact_id)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
@@ -247,7 +245,6 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
user_id=get_effective_user_id(),
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
@@ -268,7 +265,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
||||
)
|
||||
async def export_memory() -> MemoryResponse:
|
||||
"""Export the current memory data."""
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
memory_data = get_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@@ -282,7 +279,7 @@ async def export_memory() -> MemoryResponse:
|
||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||
"""Import and persist memory data."""
|
||||
try:
|
||||
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
|
||||
memory_data = import_memory_data(request.model_dump())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||
|
||||
@@ -340,7 +337,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
||||
Combined memory configuration and current data.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||
memory_data = get_memory_data()
|
||||
|
||||
return MemoryStatusResponse(
|
||||
config=MemoryConfigResponse(
|
||||
|
||||
@@ -11,11 +11,10 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
@@ -86,57 +85,3 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run-scoped read endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _resolve_run(run_id: str, request: Request) -> dict:
|
||||
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
|
||||
run_store = get_run_store(request)
|
||||
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
return record
|
||||
|
||||
|
||||
@router.get("/{run_id}/messages")
|
||||
@require_permission("runs", "read")
|
||||
async def run_messages(
|
||||
run_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(default=50, le=200, ge=1),
|
||||
before_seq: int | None = Query(default=None),
|
||||
after_seq: int | None = Query(default=None),
|
||||
) -> dict:
|
||||
"""Return paginated messages for a run (cursor-based).
|
||||
|
||||
Pagination:
|
||||
- after_seq: messages with seq > after_seq (forward)
|
||||
- before_seq: messages with seq < before_seq (backward)
|
||||
- neither: latest messages
|
||||
|
||||
Response: { data: [...], has_more: bool }
|
||||
"""
|
||||
run = await _resolve_run(run_id, request)
|
||||
event_store = get_run_event_store(request)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
run["thread_id"], run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
has_more = len(rows) > limit
|
||||
data = rows[:limit] if has_more else rows
|
||||
return {"data": data, "has_more": has_more}
|
||||
|
||||
|
||||
@router.get("/{run_id}/feedback")
|
||||
@require_permission("runs", "read")
|
||||
async def run_feedback(run_id: str, request: Request) -> list[dict]:
|
||||
"""Return all feedback for a run."""
|
||||
run = await _resolve_run(run_id, request)
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
return await feedback_repo.list_by_run(run["thread_id"], run_id)
|
||||
|
||||
@@ -20,7 +20,7 @@ from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||
|
||||
@@ -291,62 +291,17 @@ async def list_thread_messages(
|
||||
before_seq: int | None = Query(default=None),
|
||||
after_seq: int | None = Query(default=None),
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages for a thread (across all runs), with feedback attached."""
|
||||
"""Return displayable messages for a thread (across all runs)."""
|
||||
event_store = get_run_event_store(request)
|
||||
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||
|
||||
# Attach feedback to the last AI message of each run
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
user_id = await get_current_user(request)
|
||||
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
||||
|
||||
# Find the last ai_message per run_id
|
||||
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.get("event_type") == "ai_message":
|
||||
last_ai_per_run[msg["run_id"]] = i
|
||||
|
||||
# Attach feedback field
|
||||
last_ai_indices = set(last_ai_per_run.values())
|
||||
for i, msg in enumerate(messages):
|
||||
if i in last_ai_indices:
|
||||
run_id = msg["run_id"]
|
||||
fb = feedback_map.get(run_id)
|
||||
msg["feedback"] = {
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
} if fb else None
|
||||
else:
|
||||
msg["feedback"] = None
|
||||
|
||||
return messages
|
||||
return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/messages")
|
||||
@require_permission("runs", "read", owner_check=True)
|
||||
async def list_run_messages(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(default=50, le=200, ge=1),
|
||||
before_seq: int | None = Query(default=None),
|
||||
after_seq: int | None = Query(default=None),
|
||||
) -> dict:
|
||||
"""Return paginated messages for a specific run.
|
||||
|
||||
Response: { data: [...], has_more: bool }
|
||||
"""
|
||||
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
|
||||
"""Return displayable messages for a specific run."""
|
||||
event_store = get_run_event_store(request)
|
||||
rows = await event_store.list_messages_by_run(
|
||||
thread_id, run_id,
|
||||
limit=limit + 1,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
has_more = len(rows) > limit
|
||||
data = rows[:limit] if has_more else rows
|
||||
return {"data": data, "has_more": has_more}
|
||||
return await event_store.list_messages_by_run(thread_id, run_id)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/events")
|
||||
|
||||
@@ -13,7 +13,6 @@ matching the LangGraph Platform wire format expected by the
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
@@ -22,11 +21,10 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store
|
||||
from app.gateway.deps import get_checkpointer
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
@@ -36,7 +34,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
||||
# inbound model below so a malicious client cannot reflect a forged
|
||||
# owner identity through the API surface. Defense-in-depth — the
|
||||
# row-level invariant is still ``threads_meta.user_id`` populated from
|
||||
# row-level invariant is still ``threads_meta.owner_id`` populated from
|
||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
||||
|
||||
@@ -144,11 +142,11 @@ class ThreadHistoryRequest(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread."""
|
||||
path_manager = paths or get_paths()
|
||||
try:
|
||||
path_manager.delete_thread_dir(thread_id, user_id=user_id)
|
||||
path_manager.delete_thread_dir(thread_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
except FileNotFoundError:
|
||||
@@ -196,10 +194,10 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
and removes the thread_meta row from the configured ThreadMetaStore
|
||||
(sqlite or memory).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_store
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
# Clean local filesystem
|
||||
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
|
||||
response = _delete_thread_data(thread_id)
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||
@@ -213,8 +211,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
||||
# so the deleted thread no longer appears in /threads/search.
|
||||
try:
|
||||
thread_store = get_thread_store(request)
|
||||
await thread_store.delete(thread_id)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
await thread_meta_repo.delete(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
||||
|
||||
@@ -229,17 +227,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
and an empty checkpoint (so state endpoints work immediately).
|
||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_store
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_store = get_thread_store(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
# Idempotency: return existing record when already present
|
||||
existing_record = await thread_store.get(thread_id)
|
||||
existing_record = await thread_meta_repo.get(thread_id)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
@@ -251,7 +249,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
|
||||
# Write thread_meta so the thread appears in /threads/search immediately
|
||||
try:
|
||||
await thread_store.create(
|
||||
await thread_meta_repo.create(
|
||||
thread_id,
|
||||
assistant_id=getattr(body, "assistant_id", None),
|
||||
metadata=body.metadata,
|
||||
@@ -295,9 +293,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
Delegates to the configured ThreadMetaStore implementation
|
||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_store
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
repo = get_thread_store(request)
|
||||
repo = get_thread_meta_repo(request)
|
||||
rows = await repo.search(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
@@ -322,22 +320,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record."""
|
||||
from app.gateway.deps import get_thread_store
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
thread_store = get_thread_store(request)
|
||||
record = await thread_store.get(thread_id)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
record = await thread_meta_repo.get(thread_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
||||
try:
|
||||
await thread_store.update_metadata(thread_id, body.metadata)
|
||||
await thread_meta_repo.update_metadata(thread_id, body.metadata)
|
||||
except Exception:
|
||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||
|
||||
# Re-read to get the merged metadata + refreshed updated_at
|
||||
record = await thread_store.get(thread_id) or record
|
||||
record = await thread_meta_repo.get(thread_id) or record
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=record.get("status", "idle"),
|
||||
@@ -356,12 +354,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
execution status from the checkpointer. Falls back to the checkpointer
|
||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
||||
"""
|
||||
from app.gateway.deps import get_thread_store
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
thread_store = get_thread_store(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
record: dict | None = await thread_store.get(thread_id)
|
||||
record: dict | None = await thread_meta_repo.get(thread_id)
|
||||
|
||||
# Derive accurate status from the checkpointer
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
@@ -404,165 +402,6 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event-store-backed message loader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LEGACY_CMD_INNER_CONTENT_RE = re.compile(
|
||||
r"ToolMessage\(content=(?P<q>['\"])(?P<inner>.*?)(?P=q)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_legacy_command_repr(content_field: Any) -> Any:
|
||||
"""Recover the inner ToolMessage text from a legacy ``str(Command(...))`` repr.
|
||||
|
||||
Runs captured before the ``on_tool_end`` fix in ``journal.py`` stored
|
||||
``str(Command(update={'messages':[ToolMessage(content='X', ...)]}))`` as the
|
||||
tool_result content. New runs store ``'X'`` directly. For legacy rows, try
|
||||
to extract ``'X'`` defensively; return the original string if extraction
|
||||
fails (still no worse than the checkpoint fallback for summarized threads).
|
||||
"""
|
||||
if not isinstance(content_field, str) or not content_field.startswith("Command(update="):
|
||||
return content_field
|
||||
match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field)
|
||||
return match.group("inner") if match else content_field
|
||||
|
||||
|
||||
async def _get_event_store_messages(request: Request, thread_id: str) -> list[dict] | None:
|
||||
"""Load the full message stream for ``thread_id`` from the event store.
|
||||
|
||||
The event store is append-only and unaffected by summarization — the
|
||||
checkpoint's ``channel_values["messages"]`` is rewritten in-place when the
|
||||
SummarizationMiddleware runs, which drops all pre-summarize messages. The
|
||||
event store retains the full transcript, so callers in Gateway mode should
|
||||
prefer it for rendering the conversation history.
|
||||
|
||||
In addition to the core message content, this helper attaches two extra
|
||||
fields to every returned dict:
|
||||
|
||||
- ``run_id``: the ``run_id`` of the event that produced this message.
|
||||
Always present.
|
||||
- ``feedback``: thumbs-up/down data. Present only on the **final
|
||||
``ai_message`` of each run** (matching the per-run feedback semantics
|
||||
of ``POST /api/threads/{id}/runs/{run_id}/feedback``). The frontend uses
|
||||
the presence of this field to decide whether to render the feedback
|
||||
button, which sidesteps the positional-index mapping bug that an
|
||||
out-of-band ``/messages`` fetch exhibited.
|
||||
|
||||
Behaviour contract:
|
||||
|
||||
- **Full pagination.** ``RunEventStore.list_messages`` returns the newest
|
||||
``limit`` records when no cursor is given, so a fixed limit silently
|
||||
drops older messages on long threads. We size the read from
|
||||
``count_messages()`` and then page forward with ``after_seq`` cursors.
|
||||
- **Copy-on-read.** Each content dict is copied before ``id`` is patched
|
||||
so the live store object is never mutated; ``MemoryRunEventStore``
|
||||
returns live references.
|
||||
- **Stable ids.** Messages with ``id=None`` (human + tool_result) receive
|
||||
a deterministic ``uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")`` so React
|
||||
keys are stable across requests without altering stored data. AI messages
|
||||
retain their LLM-assigned ``lc_run--*`` ids.
|
||||
- **Legacy Command repr.** Rows captured before the ``journal.py``
|
||||
``on_tool_end`` fix stored ``str(Command(update={...}))`` as the tool
|
||||
result content. ``_sanitize_legacy_command_repr`` extracts the inner
|
||||
ToolMessage text.
|
||||
- **User context.** ``DbRunEventStore`` is user-scoped by default via
|
||||
``resolve_user_id(AUTO)`` in ``runtime/user_context.py``. This helper
|
||||
must run inside a request where ``@require_permission`` has populated
|
||||
the user contextvar. Both callers below are decorated appropriately.
|
||||
Do not call this helper from CLI or migration scripts without passing
|
||||
``user_id=None`` explicitly to the underlying store methods.
|
||||
|
||||
Returns ``None`` when the event store is not configured or has no message
|
||||
events for this thread, so callers fall back to checkpoint messages.
|
||||
"""
|
||||
try:
|
||||
event_store = get_run_event_store(request)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
total = await event_store.count_messages(thread_id)
|
||||
except Exception:
|
||||
logger.exception("count_messages failed for thread %s", sanitize_log_param(thread_id))
|
||||
return None
|
||||
if not total:
|
||||
return None
|
||||
|
||||
# Batch by page_size to keep memory bounded for very long threads.
|
||||
page_size = 500
|
||||
collected: list[dict] = []
|
||||
after_seq: int | None = None
|
||||
while True:
|
||||
try:
|
||||
page = await event_store.list_messages(thread_id, limit=page_size, after_seq=after_seq)
|
||||
except Exception:
|
||||
logger.exception("list_messages failed for thread %s", sanitize_log_param(thread_id))
|
||||
return None
|
||||
if not page:
|
||||
break
|
||||
collected.extend(page)
|
||||
if len(page) < page_size:
|
||||
break
|
||||
next_cursor = page[-1].get("seq")
|
||||
if next_cursor is None or (after_seq is not None and next_cursor <= after_seq):
|
||||
break
|
||||
after_seq = next_cursor
|
||||
|
||||
# Build the message list; track the final ``ai_message`` index per run so
|
||||
# feedback can be attached at the right position (matches thread_runs.py).
|
||||
messages: list[dict] = []
|
||||
last_ai_per_run: dict[str, int] = {}
|
||||
for evt in collected:
|
||||
raw = evt.get("content")
|
||||
if not isinstance(raw, dict) or "type" not in raw:
|
||||
continue
|
||||
content = dict(raw)
|
||||
if content.get("id") is None:
|
||||
content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{thread_id}:{evt['seq']}"))
|
||||
if content.get("type") == "tool":
|
||||
content["content"] = _sanitize_legacy_command_repr(content.get("content"))
|
||||
run_id = evt.get("run_id")
|
||||
if run_id:
|
||||
content["run_id"] = run_id
|
||||
if evt.get("event_type") == "ai_message" and run_id:
|
||||
last_ai_per_run[run_id] = len(messages)
|
||||
messages.append(content)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Attach feedback to the final ai_message of each run. If the feedback
|
||||
# subsystem is unavailable, leave the ``feedback`` field absent entirely
|
||||
# so the frontend hides the button rather than showing it over a broken
|
||||
# write path.
|
||||
feedback_available = False
|
||||
feedback_map: dict[str, dict] = {}
|
||||
try:
|
||||
feedback_repo = get_feedback_repo(request)
|
||||
user_id = await get_current_user(request)
|
||||
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
||||
feedback_available = True
|
||||
except Exception:
|
||||
logger.exception("feedback lookup failed for thread %s", sanitize_log_param(thread_id))
|
||||
|
||||
if feedback_available:
|
||||
for run_id, idx in last_ai_per_run.items():
|
||||
fb = feedback_map.get(run_id)
|
||||
messages[idx]["feedback"] = (
|
||||
{
|
||||
"feedback_id": fb["feedback_id"],
|
||||
"rating": fb["rating"],
|
||||
"comment": fb.get("comment"),
|
||||
}
|
||||
if fb
|
||||
else None
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||
@@ -601,15 +440,8 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||
|
||||
values = serialize_channel_values(channel_values)
|
||||
|
||||
# Prefer event-store messages: append-only, immune to summarization.
|
||||
es_messages = await _get_event_store_messages(request, thread_id)
|
||||
if es_messages is not None:
|
||||
values["messages"] = es_messages
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=values,
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=next_tasks,
|
||||
metadata=metadata,
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
@@ -630,10 +462,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
||||
change immediately in both sqlite and memory backends.
|
||||
"""
|
||||
from app.gateway.deps import get_thread_store
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_store = get_thread_store(request)
|
||||
thread_meta_repo = get_thread_meta_repo(request)
|
||||
|
||||
# checkpoint_ns must be present in the config for aput — default to ""
|
||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||
@@ -697,7 +529,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
new_title = body.values["title"]
|
||||
if new_title: # Skip empty strings and None
|
||||
try:
|
||||
await thread_store.update_display_name(thread_id, new_title)
|
||||
await thread_meta_repo.update_display_name(thread_id, new_title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
@@ -727,11 +559,6 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
if body.before:
|
||||
config["configurable"]["checkpoint_id"] = body.before
|
||||
|
||||
# Load the full event-store message stream once; attach to the latest
|
||||
# checkpoint entry only (matching the prior semantics). The event store
|
||||
# is append-only and immune to summarization.
|
||||
es_messages = await _get_event_store_messages(request, thread_id)
|
||||
|
||||
entries: list[HistoryEntry] = []
|
||||
is_latest_checkpoint = True
|
||||
try:
|
||||
@@ -755,17 +582,11 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
if thread_data := channel_values.get("thread_data"):
|
||||
values["thread_data"] = thread_data
|
||||
|
||||
# Attach messages only to the latest checkpoint. Prefer the
|
||||
# event-store stream (complete and unaffected by summarization);
|
||||
# fall back to checkpoint channel_values when the event store is
|
||||
# unavailable or empty.
|
||||
# Attach messages from checkpointer only for the latest checkpoint
|
||||
if is_latest_checkpoint:
|
||||
if es_messages is not None:
|
||||
values["messages"] = es_messages
|
||||
else:
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
is_latest_checkpoint = False
|
||||
|
||||
# Derive next tasks
|
||||
|
||||
@@ -9,7 +9,6 @@ from pydantic import BaseModel
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
@@ -70,7 +69,7 @@ async def upload_files(
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
uploaded_files = []
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
@@ -148,7 +147,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||
enrich_file_listing(result, thread_id)
|
||||
|
||||
# Gateway additionally includes the sandbox-relative path.
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
for f in result["files"]:
|
||||
f["path"] = str(sandbox_uploads / f["filename"])
|
||||
|
||||
|
||||
@@ -229,15 +229,15 @@ async def start_run(
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_store.create(
|
||||
await run_ctx.thread_meta_repo.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
@@ -285,7 +285,7 @@ async def start_run(
|
||||
record.task = task
|
||||
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# title from the checkpoint and calls thread_meta_repo.update_display_name
|
||||
# after the run completes.
|
||||
|
||||
return record
|
||||
|
||||
@@ -124,7 +124,7 @@ title:
|
||||
# checkpointer.py
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
checkpointer = SqliteSaver.from_conn_string("deerflow.db")
|
||||
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
|
||||
```
|
||||
|
||||
```json
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
|
||||
from .factory import create_deerflow_agent
|
||||
from .features import Next, Prev, RuntimeFeatures
|
||||
from .lead_agent import make_lead_agent
|
||||
@@ -17,4 +18,7 @@ __all__ = [
|
||||
"make_lead_agent",
|
||||
"SandboxState",
|
||||
"ThreadState",
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"make_checkpointer",
|
||||
]
|
||||
|
||||
+4
-4
@@ -7,12 +7,12 @@ Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||
|
||||
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
|
||||
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.checkpointer.provider import (
|
||||
from deerflow.agents.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
SQLITE_INSTALL,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
+1
-1
@@ -7,7 +7,7 @@ Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
cp = get_checkpointer()
|
||||
@@ -519,13 +519,12 @@ def _get_memory_context(agent_name: str | None = None) -> str:
|
||||
try:
|
||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
config = get_memory_config()
|
||||
if not config.enabled or not config.injection_enabled:
|
||||
return ""
|
||||
|
||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||
|
||||
if not memory_content.strip():
|
||||
|
||||
@@ -20,7 +20,6 @@ class ConversationContext:
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
agent_name: str | None = None
|
||||
user_id: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
@@ -45,7 +44,6 @@ class MemoryUpdateQueue:
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
@@ -55,9 +53,6 @@ class MemoryUpdateQueue:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
|
||||
survives the threading.Timer boundary (ContextVar does not propagate across
|
||||
raw threads).
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
"""
|
||||
@@ -76,7 +71,6 @@ class MemoryUpdateQueue:
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
@@ -142,7 +136,6 @@ class MemoryUpdateQueue:
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
reinforcement_detected=context.reinforcement_detected,
|
||||
user_id=context.user_id,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
|
||||
@@ -43,17 +43,17 @@ class MemoryStorage(abc.ABC):
|
||||
"""Abstract base class for memory storage providers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Force reload memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Save memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@@ -63,9 +63,9 @@ class FileMemoryStorage(MemoryStorage):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the file memory storage."""
|
||||
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
|
||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||
# Value: (memory_data, file_mtime)
|
||||
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
|
||||
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||
|
||||
def _validate_agent_name(self, agent_name: str) -> None:
|
||||
"""Validate that the agent name is safe to use in filesystem paths.
|
||||
@@ -78,29 +78,21 @@ class FileMemoryStorage(MemoryStorage):
|
||||
if not AGENT_NAME_PATTERN.match(agent_name):
|
||||
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
||||
|
||||
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
|
||||
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
|
||||
"""Get the path to the memory file."""
|
||||
if user_id is not None:
|
||||
if agent_name is not None:
|
||||
self._validate_agent_name(agent_name)
|
||||
return get_paths().user_agent_memory_file(user_id, agent_name)
|
||||
config = get_memory_config()
|
||||
if config.storage_path and Path(config.storage_path).is_absolute():
|
||||
return Path(config.storage_path)
|
||||
return get_paths().user_memory_file(user_id)
|
||||
# Legacy: no user_id
|
||||
if agent_name is not None:
|
||||
self._validate_agent_name(agent_name)
|
||||
return get_paths().agent_memory_file(agent_name)
|
||||
|
||||
config = get_memory_config()
|
||||
if config.storage_path:
|
||||
p = Path(config.storage_path)
|
||||
return p if p.is_absolute() else get_paths().base_dir / p
|
||||
return get_paths().memory_file
|
||||
|
||||
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data from file."""
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
|
||||
if not file_path.exists():
|
||||
return create_empty_memory()
|
||||
@@ -113,42 +105,40 @@ class FileMemoryStorage(MemoryStorage):
|
||||
logger.warning("Failed to load memory file: %s", e)
|
||||
return create_empty_memory()
|
||||
|
||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data (cached with file modification time check)."""
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
|
||||
try:
|
||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
current_mtime = None
|
||||
|
||||
cache_key = (user_id, agent_name)
|
||||
cached = self._memory_cache.get(cache_key)
|
||||
cached = self._memory_cache.get(agent_name)
|
||||
|
||||
if cached is None or cached[1] != current_mtime:
|
||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||
self._memory_cache[cache_key] = (memory_data, current_mtime)
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
self._memory_cache[agent_name] = (memory_data, current_mtime)
|
||||
return memory_data
|
||||
|
||||
return cached[0]
|
||||
|
||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data from file, forcing cache invalidation."""
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
cache_key = (user_id, agent_name)
|
||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
return memory_data
|
||||
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Save memory data to file and update cache."""
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -165,8 +155,7 @@ class FileMemoryStorage(MemoryStorage):
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
cache_key = (user_id, agent_name)
|
||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
logger.info("Memory saved to %s", file_path)
|
||||
return True
|
||||
except OSError as e:
|
||||
|
||||
@@ -27,28 +27,27 @@ def _create_empty_memory() -> dict[str, Any]:
|
||||
return create_empty_memory()
|
||||
|
||||
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
|
||||
return get_memory_storage().save(memory_data, agent_name)
|
||||
|
||||
|
||||
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Get the current memory data via storage provider."""
|
||||
return get_memory_storage().load(agent_name, user_id=user_id)
|
||||
return get_memory_storage().load(agent_name)
|
||||
|
||||
|
||||
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data via storage provider."""
|
||||
return get_memory_storage().reload(agent_name, user_id=user_id)
|
||||
return get_memory_storage().reload(agent_name)
|
||||
|
||||
|
||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Persist imported memory data via storage provider.
|
||||
|
||||
Args:
|
||||
memory_data: Full memory payload to persist.
|
||||
agent_name: If provided, imports into per-agent memory.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
The saved memory data after storage normalization.
|
||||
@@ -57,15 +56,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
|
||||
OSError: If persisting the imported memory fails.
|
||||
"""
|
||||
storage = get_memory_storage()
|
||||
if not storage.save(memory_data, agent_name, user_id=user_id):
|
||||
if not storage.save(memory_data, agent_name):
|
||||
raise OSError("Failed to save imported memory data")
|
||||
return storage.load(agent_name, user_id=user_id)
|
||||
return storage.load(agent_name)
|
||||
|
||||
|
||||
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Clear all stored memory data and persist an empty structure."""
|
||||
cleared_memory = create_empty_memory()
|
||||
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
|
||||
if not _save_memory_to_file(cleared_memory, agent_name):
|
||||
raise OSError("Failed to save cleared memory data")
|
||||
return cleared_memory
|
||||
|
||||
@@ -82,8 +81,6 @@ def create_memory_fact(
|
||||
category: str = "context",
|
||||
confidence: float = 0.5,
|
||||
agent_name: str | None = None,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new fact and persist the updated memory data."""
|
||||
normalized_content = content.strip()
|
||||
@@ -93,7 +90,7 @@ def create_memory_fact(
|
||||
normalized_category = category.strip() or "context"
|
||||
validated_confidence = _validate_confidence(confidence)
|
||||
now = utc_now_iso_z()
|
||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||
memory_data = get_memory_data(agent_name)
|
||||
updated_memory = dict(memory_data)
|
||||
facts = list(memory_data.get("facts", []))
|
||||
facts.append(
|
||||
@@ -108,15 +105,15 @@ def create_memory_fact(
|
||||
)
|
||||
updated_memory["facts"] = facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
raise OSError("Failed to save memory data after creating fact")
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Delete a fact by its id and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||
memory_data = get_memory_data(agent_name)
|
||||
facts = memory_data.get("facts", [])
|
||||
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
||||
if len(updated_facts) == len(facts):
|
||||
@@ -125,7 +122,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id:
|
||||
updated_memory = dict(memory_data)
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
@@ -137,11 +134,9 @@ def update_memory_fact(
|
||||
category: str | None = None,
|
||||
confidence: float | None = None,
|
||||
agent_name: str | None = None,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing fact and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||
memory_data = get_memory_data(agent_name)
|
||||
updated_memory = dict(memory_data)
|
||||
updated_facts: list[dict[str, Any]] = []
|
||||
found = False
|
||||
@@ -168,7 +163,7 @@ def update_memory_fact(
|
||||
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
@@ -281,7 +276,6 @@ class MemoryUpdater:
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
@@ -291,7 +285,6 @@ class MemoryUpdater:
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@@ -305,7 +298,7 @@ class MemoryUpdater:
|
||||
|
||||
try:
|
||||
# Get current memory
|
||||
current_memory = get_memory_data(agent_name, user_id=user_id)
|
||||
current_memory = get_memory_data(agent_name)
|
||||
|
||||
# Format conversation for prompt
|
||||
conversation_text = format_conversation_for_update(messages)
|
||||
@@ -360,7 +353,7 @@ class MemoryUpdater:
|
||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||
|
||||
# Save
|
||||
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
||||
return get_memory_storage().save(updated_memory, agent_name)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||
@@ -462,7 +455,6 @@ def update_memory_from_conversation(
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
@@ -472,10 +464,9 @@ def update_memory_from_conversation(
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
|
||||
@@ -11,7 +11,6 @@ from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -237,16 +236,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
# Capture user_id at enqueue time while the request context is still alive.
|
||||
# threading.Timer fires on a different thread where ContextVar values are not
|
||||
# propagated, so we must store user_id explicitly in ConversationContext.
|
||||
user_id = get_effective_user_id()
|
||||
queue = get_memory_queue()
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,34 +46,32 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
self._lazy_init = lazy_init
|
||||
|
||||
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||
"""Get the paths for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
user_id: Optional user ID for per-user path isolation.
|
||||
|
||||
Returns:
|
||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||
"""
|
||||
return {
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)),
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
|
||||
}
|
||||
|
||||
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
||||
"""Create the thread data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
user_id: Optional user ID for per-user path isolation.
|
||||
|
||||
Returns:
|
||||
Dictionary with the created directory paths.
|
||||
"""
|
||||
self._paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
return self._get_thread_paths(thread_id, user_id=user_id)
|
||||
self._paths.ensure_thread_dirs(thread_id)
|
||||
return self._get_thread_paths(thread_id)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
@@ -87,14 +84,12 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
if self._lazy_init:
|
||||
# Lazy initialization: only compute paths, don't create directories
|
||||
paths = self._get_thread_paths(thread_id, user_id=user_id)
|
||||
paths = self._get_thread_paths(thread_id)
|
||||
else:
|
||||
# Eager initialization: create directories immediately
|
||||
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||
|
||||
return {
|
||||
|
||||
@@ -10,7 +10,6 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -222,7 +221,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
||||
|
||||
@@ -40,7 +40,6 @@ from deerflow.config.app_config import get_app_config, reload_app_config
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.models import create_chat_model
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.skills.installer import install_skill_from_archive
|
||||
from deerflow.uploads.manager import (
|
||||
claim_unique_filename,
|
||||
@@ -241,7 +240,7 @@ class DeerFlowClient:
|
||||
}
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.runtime.checkpointer import get_checkpointer
|
||||
from deerflow.agents.checkpointer import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
if checkpointer is not None:
|
||||
@@ -375,7 +374,7 @@ class DeerFlowClient:
|
||||
"""
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
@@ -430,7 +429,7 @@ class DeerFlowClient:
|
||||
"""
|
||||
checkpointer = self._checkpointer
|
||||
if checkpointer is None:
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
@@ -770,19 +769,19 @@ class DeerFlowClient:
|
||||
"""
|
||||
from deerflow.agents.memory.updater import get_memory_data
|
||||
|
||||
return get_memory_data(user_id=get_effective_user_id())
|
||||
return get_memory_data()
|
||||
|
||||
def export_memory(self) -> dict:
|
||||
"""Export current memory data for backup or transfer."""
|
||||
from deerflow.agents.memory.updater import get_memory_data
|
||||
|
||||
return get_memory_data(user_id=get_effective_user_id())
|
||||
return get_memory_data()
|
||||
|
||||
def import_memory(self, memory_data: dict) -> dict:
|
||||
"""Import and persist full memory data."""
|
||||
from deerflow.agents.memory.updater import import_memory_data
|
||||
|
||||
return import_memory_data(memory_data, user_id=get_effective_user_id())
|
||||
return import_memory_data(memory_data)
|
||||
|
||||
def get_model(self, name: str) -> dict | None:
|
||||
"""Get a specific model's configuration by name.
|
||||
@@ -957,13 +956,13 @@ class DeerFlowClient:
|
||||
"""
|
||||
from deerflow.agents.memory.updater import reload_memory_data
|
||||
|
||||
return reload_memory_data(user_id=get_effective_user_id())
|
||||
return reload_memory_data()
|
||||
|
||||
def clear_memory(self) -> dict:
|
||||
"""Clear all persisted memory data."""
|
||||
from deerflow.agents.memory.updater import clear_memory_data
|
||||
|
||||
return clear_memory_data(user_id=get_effective_user_id())
|
||||
return clear_memory_data()
|
||||
|
||||
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
|
||||
"""Create a single fact manually."""
|
||||
@@ -1180,7 +1179,7 @@ class DeerFlowClient:
|
||||
ValueError: If the path is invalid.
|
||||
"""
|
||||
try:
|
||||
actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id())
|
||||
actual = get_paths().resolve_virtual_path(thread_id, path)
|
||||
except ValueError as exc:
|
||||
if "traversal" in str(exc):
|
||||
from deerflow.uploads.manager import PathTraversalError
|
||||
|
||||
@@ -27,7 +27,6 @@ except ImportError: # pragma: no cover - Windows fallback
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
@@ -261,16 +260,15 @@ class AioSandboxProvider(SandboxProvider):
|
||||
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
|
||||
"""
|
||||
paths = get_paths()
|
||||
user_id = get_effective_user_id()
|
||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
|
||||
return [
|
||||
(paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||
(paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||
(paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
||||
# the ACP subprocess writes from the host side, not from within the container).
|
||||
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
|
||||
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -482,9 +480,8 @@ class AioSandboxProvider(SandboxProvider):
|
||||
across multiple processes, preventing container-name conflicts.
|
||||
"""
|
||||
paths = get_paths()
|
||||
user_id = get_effective_user_id()
|
||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
|
||||
|
||||
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
||||
locked = False
|
||||
|
||||
@@ -4,12 +4,8 @@ Controls BOTH the LangGraph checkpointer and the DeerFlow application
|
||||
persistence layer (runs, threads metadata, users, etc.). The user
|
||||
configures one backend; the system handles physical separation details.
|
||||
|
||||
SQLite mode: checkpointer and app share a single .db file
|
||||
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
|
||||
connection. WAL allows concurrent readers and a single writer without
|
||||
blocking, making a unified file safe for both workloads. Writers
|
||||
that contend for the lock wait via the default 5-second sqlite3
|
||||
busy timeout rather than failing immediately.
|
||||
SQLite mode: checkpointer and app use different .db files in the same
|
||||
directory to avoid write-lock contention. This is automatic.
|
||||
|
||||
Postgres mode: both use the same database URL but maintain independent
|
||||
connection pools with different lifecycles.
|
||||
@@ -44,7 +40,7 @@ class DatabaseConfig(BaseModel):
|
||||
)
|
||||
sqlite_dir: str = Field(
|
||||
default=".deer-flow/data",
|
||||
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
|
||||
description=("Directory for SQLite database files. Checkpointer uses {sqlite_dir}/checkpoints.db, application data uses {sqlite_dir}/app.db."),
|
||||
)
|
||||
postgres_url: str = Field(
|
||||
default="",
|
||||
@@ -73,27 +69,21 @@ class DatabaseConfig(BaseModel):
|
||||
|
||||
return str(Path(self.sqlite_dir).resolve())
|
||||
|
||||
@property
|
||||
def sqlite_path(self) -> str:
|
||||
"""Unified SQLite file path shared by checkpointer and app."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
|
||||
|
||||
# Backward-compatible aliases
|
||||
@property
|
||||
def checkpointer_sqlite_path(self) -> str:
|
||||
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
|
||||
return self.sqlite_path
|
||||
"""SQLite file path for the LangGraph checkpointer."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "checkpoints.db")
|
||||
|
||||
@property
|
||||
def app_sqlite_path(self) -> str:
|
||||
"""SQLite file path for application ORM data (alias for sqlite_path)."""
|
||||
return self.sqlite_path
|
||||
"""SQLite file path for application ORM data."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "app.db")
|
||||
|
||||
@property
|
||||
def app_sqlalchemy_url(self) -> str:
|
||||
"""SQLAlchemy async URL for the application ORM engine."""
|
||||
if self.backend == "sqlite":
|
||||
return f"sqlite+aiosqlite:///{self.sqlite_path}"
|
||||
return f"sqlite+aiosqlite:///{self.app_sqlite_path}"
|
||||
if self.backend == "postgres":
|
||||
url = self.postgres_url
|
||||
if url.startswith("postgresql://"):
|
||||
|
||||
@@ -14,9 +14,8 @@ class MemoryConfig(BaseModel):
|
||||
default="",
|
||||
description=(
|
||||
"Path to store memory data. "
|
||||
"If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. "
|
||||
"Absolute paths are used as-is and opt out of per-user isolation "
|
||||
"(all users share the same file). "
|
||||
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
|
||||
"Absolute paths are used as-is. "
|
||||
"Relative paths are resolved against `Paths.base_dir` "
|
||||
"(not the backend working directory). "
|
||||
"Note: if you previously set this to `.deer-flow/memory.json`, "
|
||||
|
||||
@@ -7,7 +7,6 @@ from pathlib import Path, PureWindowsPath
|
||||
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||
|
||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
|
||||
def _default_local_base_dir() -> Path:
|
||||
@@ -23,13 +22,6 @@ def _validate_thread_id(thread_id: str) -> str:
|
||||
return thread_id
|
||||
|
||||
|
||||
def _validate_user_id(user_id: str) -> str:
|
||||
"""Validate a user ID before using it in filesystem paths."""
|
||||
if not _SAFE_USER_ID_RE.match(user_id):
|
||||
raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
|
||||
return user_id
|
||||
|
||||
|
||||
def _join_host_path(base: str, *parts: str) -> str:
|
||||
"""Join host filesystem path segments while preserving native style.
|
||||
|
||||
@@ -142,63 +134,44 @@ class Paths:
|
||||
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||
return self.agent_dir(name) / "memory.json"
|
||||
|
||||
def user_dir(self, user_id: str) -> Path:
|
||||
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
|
||||
return self.base_dir / "users" / _validate_user_id(user_id)
|
||||
|
||||
def user_memory_file(self, user_id: str) -> Path:
|
||||
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
||||
return self.user_dir(user_id) / "memory.json"
|
||||
|
||||
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
|
||||
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
|
||||
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
|
||||
|
||||
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
def thread_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for a thread's data.
|
||||
|
||||
When *user_id* is provided:
|
||||
`{base_dir}/users/{user_id}/threads/{thread_id}/`
|
||||
Otherwise (legacy layout):
|
||||
`{base_dir}/threads/{thread_id}/`
|
||||
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
|
||||
|
||||
This directory contains a `user-data/` subdirectory that is mounted
|
||||
as `/mnt/user-data/` inside the sandbox.
|
||||
|
||||
Raises:
|
||||
ValueError: If `thread_id` or `user_id` contains unsafe characters (path
|
||||
separators or `..`) that could cause directory traversal.
|
||||
ValueError: If `thread_id` contains unsafe characters (path separators
|
||||
or `..`) that could cause directory traversal.
|
||||
"""
|
||||
if user_id is not None:
|
||||
return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
|
||||
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
||||
|
||||
def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
def sandbox_work_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for the agent's workspace directory.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
|
||||
Sandbox: `/mnt/user-data/workspace/`
|
||||
"""
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"
|
||||
return self.thread_dir(thread_id) / "user-data" / "workspace"
|
||||
|
||||
def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
def sandbox_uploads_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for user-uploaded files.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
|
||||
Sandbox: `/mnt/user-data/uploads/`
|
||||
"""
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads"
|
||||
return self.thread_dir(thread_id) / "user-data" / "uploads"
|
||||
|
||||
def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
def sandbox_outputs_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for agent-generated artifacts.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
|
||||
Sandbox: `/mnt/user-data/outputs/`
|
||||
"""
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs"
|
||||
return self.thread_dir(thread_id) / "user-data" / "outputs"
|
||||
|
||||
def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
def acp_workspace_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for the ACP workspace of a specific thread.
|
||||
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
|
||||
@@ -207,43 +180,41 @@ class Paths:
|
||||
Each thread gets its own isolated ACP workspace so that concurrent
|
||||
sessions cannot read each other's ACP agent outputs.
|
||||
"""
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace"
|
||||
return self.thread_dir(thread_id) / "acp-workspace"
|
||||
|
||||
def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||
def sandbox_user_data_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for the user-data root.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/`
|
||||
Sandbox: `/mnt/user-data/`
|
||||
"""
|
||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data"
|
||||
return self.thread_dir(thread_id) / "user-data"
|
||||
|
||||
def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
def host_thread_dir(self, thread_id: str) -> str:
|
||||
"""Host path for a thread directory, preserving Windows path syntax."""
|
||||
if user_id is not None:
|
||||
return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
|
||||
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
|
||||
|
||||
def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
|
||||
"""Host path for a thread's user-data root."""
|
||||
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data")
|
||||
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
|
||||
|
||||
def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
def host_sandbox_work_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the workspace mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace")
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
|
||||
|
||||
def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the uploads mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads")
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
|
||||
|
||||
def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the outputs mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs")
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
|
||||
|
||||
def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||
def host_acp_workspace_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the ACP workspace mount source."""
|
||||
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace")
|
||||
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
|
||||
|
||||
def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None:
|
||||
def ensure_thread_dirs(self, thread_id: str) -> None:
|
||||
"""Create all standard sandbox directories for a thread.
|
||||
|
||||
Directories are created with mode 0o777 so that sandbox containers
|
||||
@@ -257,24 +228,24 @@ class Paths:
|
||||
ACP agent invocation.
|
||||
"""
|
||||
for d in [
|
||||
self.sandbox_work_dir(thread_id, user_id=user_id),
|
||||
self.sandbox_uploads_dir(thread_id, user_id=user_id),
|
||||
self.sandbox_outputs_dir(thread_id, user_id=user_id),
|
||||
self.acp_workspace_dir(thread_id, user_id=user_id),
|
||||
self.sandbox_work_dir(thread_id),
|
||||
self.sandbox_uploads_dir(thread_id),
|
||||
self.sandbox_outputs_dir(thread_id),
|
||||
self.acp_workspace_dir(thread_id),
|
||||
]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
d.chmod(0o777)
|
||||
|
||||
def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None:
|
||||
def delete_thread_dir(self, thread_id: str) -> None:
|
||||
"""Delete all persisted data for a thread.
|
||||
|
||||
The operation is idempotent: missing thread directories are ignored.
|
||||
"""
|
||||
thread_dir = self.thread_dir(thread_id, user_id=user_id)
|
||||
thread_dir = self.thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
shutil.rmtree(thread_dir)
|
||||
|
||||
def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
||||
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
|
||||
"""Resolve a sandbox virtual path to the actual host filesystem path.
|
||||
|
||||
Args:
|
||||
@@ -282,7 +253,6 @@ class Paths:
|
||||
virtual_path: Virtual path as seen inside the sandbox, e.g.
|
||||
``/mnt/user-data/outputs/report.pdf``.
|
||||
Leading slashes are stripped before matching.
|
||||
user_id: Optional user ID for user-scoped path resolution.
|
||||
|
||||
Returns:
|
||||
The resolved absolute host filesystem path.
|
||||
@@ -300,7 +270,7 @@ class Paths:
|
||||
raise ValueError(f"Path must start with /{prefix}")
|
||||
|
||||
relative = stripped[len(prefix) :].lstrip("/")
|
||||
base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve()
|
||||
base = self.sandbox_user_data_dir(thread_id).resolve()
|
||||
actual = (base / relative).resolve()
|
||||
|
||||
try:
|
||||
|
||||
@@ -98,11 +98,6 @@ async def init_engine(
|
||||
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
||||
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
||||
# at WAL checkpoint boundaries instead of every commit.
|
||||
# Note: we do not set PRAGMA busy_timeout here — Python's sqlite3
|
||||
# driver already defaults to a 5-second busy timeout (see the
|
||||
# ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite /
|
||||
# SQLAlchemy's aiosqlite dialect inherit that default. Setting
|
||||
# it again would be a no-op.
|
||||
@event.listens_for(_engine.sync_engine, "connect")
|
||||
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
||||
cursor = dbapi_conn.cursor()
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DateTime, String, Text, UniqueConstraint
|
||||
from sqlalchemy import DateTime, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@@ -13,14 +13,10 @@ from deerflow.persistence.base import Base
|
||||
class FeedbackRow(Base):
|
||||
__tablename__ = "feedback"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
|
||||
)
|
||||
|
||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||
# message_id is an optional RunEventStore event identifier —
|
||||
# allows feedback to target a specific message or the entire run
|
||||
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
@@ -33,19 +33,19 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a feedback record. rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
user_id=resolved_user_id,
|
||||
owner_id=resolved_owner_id,
|
||||
message_id=message_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
@@ -61,14 +61,14 @@ class FeedbackRepository:
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
@@ -78,12 +78,12 @@ class FeedbackRepository:
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@@ -94,12 +94,12 @@ class FeedbackRepository:
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@@ -109,97 +109,19 @@ class FeedbackRepository:
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return False
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update feedback for (thread_id, run_id, user_id). rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert")
|
||||
async with self._sf() as session:
|
||||
stmt = select(FeedbackRow).where(
|
||||
FeedbackRow.thread_id == thread_id,
|
||||
FeedbackRow.run_id == run_id,
|
||||
FeedbackRow.user_id == resolved_user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is not None:
|
||||
row.rating = rating
|
||||
row.comment = comment
|
||||
row.created_at = datetime.now(UTC)
|
||||
else:
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
user_id=resolved_user_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def delete_by_run(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
"""Delete the current user's feedback for a run. Returns True if a record was deleted."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
stmt = select(FeedbackRow).where(
|
||||
FeedbackRow.thread_id == thread_id,
|
||||
FeedbackRow.run_id == run_id,
|
||||
FeedbackRow.user_id == resolved_user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def list_by_thread_grouped(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict[str, dict]:
|
||||
"""Return feedback grouped by run_id for a thread: {run_id: feedback_dict}."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return {row.run_id: self._row_to_dict(row) for row in result.scalars()}
|
||||
|
||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
||||
"""Aggregate feedback stats for a run using database-side counting."""
|
||||
stmt = select(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
script_location = %(here)s
|
||||
# Default URL for offline mode / autogenerate.
|
||||
# Runtime uses engine from DeerFlow config.
|
||||
sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db
|
||||
sqlalchemy.url = sqlite+aiosqlite:///./data/app.db
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
@@ -19,7 +19,7 @@ class RunEventRow(Base):
|
||||
# Owner of the conversation this event belongs to. Nullable for data
|
||||
# created before auth was introduced; populated by auth middleware on
|
||||
# new writes and by the boot-time orphan migration on existing rows.
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
# "message" | "trace" | "lifecycle"
|
||||
|
||||
@@ -16,7 +16,7 @@ class RunRow(Base):
|
||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class RunRepository(RunStore):
|
||||
@@ -69,7 +69,7 @@ class RunRepository(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -78,13 +78,13 @@ class RunRepository(RunStore):
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=resolved_user_id,
|
||||
owner_id=resolved_owner_id,
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
@@ -102,14 +102,14 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
@@ -117,13 +117,13 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
limit=100,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
@@ -141,14 +141,14 @@ class RunRepository(RunStore):
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@@ -1,38 +1,13 @@
|
||||
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.store.base import BaseStore
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
__all__ = [
|
||||
"MemoryThreadMetaStore",
|
||||
"ThreadMetaRepository",
|
||||
"ThreadMetaRow",
|
||||
"ThreadMetaStore",
|
||||
"make_thread_store",
|
||||
]
|
||||
|
||||
|
||||
def make_thread_store(
|
||||
session_factory: async_sessionmaker[AsyncSession] | None,
|
||||
store: BaseStore | None = None,
|
||||
) -> ThreadMetaStore:
|
||||
"""Create the appropriate ThreadMetaStore based on available backends.
|
||||
|
||||
Returns a SQL-backed repository when a session factory is available,
|
||||
otherwise falls back to the in-memory LangGraph Store implementation.
|
||||
"""
|
||||
if session_factory is not None:
|
||||
return ThreadMetaRepository(session_factory)
|
||||
if store is None:
|
||||
raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)")
|
||||
return MemoryThreadMetaStore(store)
|
||||
|
||||
@@ -3,21 +3,12 @@
|
||||
Implementations:
|
||||
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
|
||||
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
|
||||
|
||||
All mutating and querying methods accept a ``user_id`` parameter with
|
||||
three-state semantics (see :mod:`deerflow.runtime.user_context`):
|
||||
|
||||
- ``AUTO`` (default): resolve from the request-scoped contextvar.
|
||||
- Explicit ``str``: use the provided value verbatim.
|
||||
- Explicit ``None``: bypass owner filtering (migration/CLI only).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel
|
||||
|
||||
|
||||
class ThreadMetaStore(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
@@ -26,14 +17,14 @@ class ThreadMetaStore(abc.ABC):
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||
async def get(self, thread_id: str) -> dict | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -44,33 +35,26 @@ class ThreadMetaStore(abc.ABC):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||
"""Merge ``metadata`` into the thread's metadata field.
|
||||
|
||||
Existing keys are overwritten by the new values; keys absent from
|
||||
``metadata`` are preserved. No-op if the thread does not exist
|
||||
or the owner check fails.
|
||||
``metadata`` are preserved. No-op if the thread does not exist.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
async def delete(self, thread_id: str) -> None:
|
||||
pass
|
||||
|
||||
@@ -13,7 +13,6 @@ from typing import Any
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||
|
||||
@@ -22,37 +21,20 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
def __init__(self, store: BaseStore) -> None:
|
||||
self._store = store
|
||||
|
||||
async def _get_owned_record(
|
||||
self,
|
||||
thread_id: str,
|
||||
user_id: str | None | _AutoSentinel,
|
||||
method_name: str,
|
||||
) -> dict | None:
|
||||
"""Fetch a record and verify ownership. Returns a mutable copy, or None."""
|
||||
resolved = resolve_user_id(user_id, method_name=method_name)
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return None
|
||||
record = dict(item.value)
|
||||
if resolved is not None and record.get("user_id") != resolved:
|
||||
return None
|
||||
return record
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
|
||||
now = time.time()
|
||||
record: dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"user_id": resolved_user_id,
|
||||
"owner_id": owner_id,
|
||||
"display_name": display_name,
|
||||
"status": "idle",
|
||||
"metadata": metadata or {},
|
||||
@@ -63,8 +45,9 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
return record
|
||||
|
||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||
return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get")
|
||||
async def get(self, thread_id: str) -> dict | None:
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
return item.value if item is not None else None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
@@ -73,16 +56,12 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
|
||||
filter_dict: dict[str, Any] = {}
|
||||
if metadata:
|
||||
filter_dict.update(metadata)
|
||||
if status:
|
||||
filter_dict["status"] = status
|
||||
if resolved_user_id is not None:
|
||||
filter_dict["user_id"] = resolved_user_id
|
||||
|
||||
items = await self._store.asearch(
|
||||
THREADS_NS,
|
||||
@@ -92,45 +71,37 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
)
|
||||
return [self._item_to_dict(item) for item in items]
|
||||
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return not require_existing
|
||||
record_user_id = item.value.get("user_id")
|
||||
if record_user_id is None:
|
||||
return True
|
||||
return record_user_id == user_id
|
||||
|
||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name")
|
||||
if record is None:
|
||||
return
|
||||
record = dict(item.value)
|
||||
record["display_name"] = display_name
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status")
|
||||
if record is None:
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return
|
||||
record = dict(item.value)
|
||||
record["status"] = status
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata")
|
||||
if record is None:
|
||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||
"""Merge ``metadata`` into the in-memory record. No-op if absent."""
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return
|
||||
record = dict(item.value)
|
||||
merged = dict(record.get("metadata") or {})
|
||||
merged.update(metadata)
|
||||
record["metadata"] = merged
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||
if record is None:
|
||||
return
|
||||
async def delete(self, thread_id: str) -> None:
|
||||
await self._store.adelete(THREADS_NS, thread_id)
|
||||
|
||||
@staticmethod
|
||||
@@ -140,7 +111,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
return {
|
||||
"thread_id": item.key,
|
||||
"assistant_id": val.get("assistant_id"),
|
||||
"user_id": val.get("user_id"),
|
||||
"owner_id": val.get("owner_id"),
|
||||
"display_name": val.get("display_name"),
|
||||
"status": val.get("status", "idle"),
|
||||
"metadata": val.get("metadata", {}),
|
||||
|
||||
@@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
|
||||
|
||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||
|
||||
|
||||
class ThreadMetaRepository(ThreadMetaStore):
|
||||
@@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
# Auto-resolve user_id from contextvar when AUTO; explicit None
|
||||
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
||||
# creates an orphan row (used by migration scripts).
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
||||
now = datetime.now(UTC)
|
||||
row = ThreadMetaRow(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=resolved_user_id,
|
||||
owner_id=resolved_owner_id,
|
||||
display_name=display_name,
|
||||
metadata_json=metadata or {},
|
||||
created_at=now,
|
||||
@@ -59,34 +59,40 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return None
|
||||
# Enforce owner filter unless explicitly bypassed (user_id=None).
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``.
|
||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``owner_id`` has access to ``thread_id``.
|
||||
|
||||
Two modes — one row, two distinct semantics depending on what
|
||||
the caller is about to do:
|
||||
|
||||
- ``require_existing=False`` (default, permissive):
|
||||
Returns True for: row missing (untracked legacy thread),
|
||||
``row.user_id`` is None (shared / pre-auth data),
|
||||
or ``row.user_id == user_id``. Use for **read-style**
|
||||
``row.owner_id`` is None (shared / pre-auth data),
|
||||
or ``row.owner_id == owner_id``. Use for **read-style**
|
||||
decorators where treating an untracked thread as accessible
|
||||
preserves backward-compat.
|
||||
|
||||
- ``require_existing=True`` (strict):
|
||||
Returns True **only** when the row exists AND
|
||||
(``row.user_id == user_id`` OR ``row.user_id is None``).
|
||||
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
|
||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||
state-update) so a thread that has *already been deleted*
|
||||
cannot be re-targeted by any caller — closing the
|
||||
@@ -97,9 +103,9 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return not require_existing
|
||||
if row.user_id is None:
|
||||
if row.owner_id is None:
|
||||
return True
|
||||
return row.user_id == user_id
|
||||
return row.owner_id == owner_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
@@ -108,17 +114,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
"""Search threads with optional metadata and status filters.
|
||||
|
||||
Owner filter is enforced by default: caller must be in a user
|
||||
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
||||
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
||||
"""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
|
||||
@@ -138,24 +144,24 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||
if resolved_user_id is None:
|
||||
if resolved_owner_id is None:
|
||||
return True # explicit bypass
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return row is not None and row.user_id == resolved_user_id
|
||||
return row is not None and row.owner_id == resolved_owner_id
|
||||
|
||||
async def update_display_name(
|
||||
self,
|
||||
thread_id: str,
|
||||
display_name: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
@@ -165,11 +171,11 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
status: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
@@ -179,20 +185,20 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
thread_id: str,
|
||||
metadata: dict,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Merge ``metadata`` into ``metadata_json``.
|
||||
|
||||
Read-modify-write inside a single session/transaction so concurrent
|
||||
callers see consistent state. No-op if the row does not exist or
|
||||
the user_id check fails.
|
||||
the owner_id check fails.
|
||||
"""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
merged = dict(row.metadata_json or {})
|
||||
merged.update(metadata)
|
||||
@@ -204,14 +210,14 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
@@ -5,18 +5,12 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
|
||||
directly from ``deerflow.runtime``.
|
||||
"""
|
||||
|
||||
from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
|
||||
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
||||
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
||||
from .store import get_store, make_store, reset_store, store_context
|
||||
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
||||
|
||||
__all__ = [
|
||||
# checkpointer
|
||||
"checkpointer_context",
|
||||
"get_checkpointer",
|
||||
"make_checkpointer",
|
||||
"reset_checkpointer",
|
||||
# runs
|
||||
"ConflictError",
|
||||
"DisconnectMode",
|
||||
|
||||
@@ -83,18 +83,8 @@ class RunEventStore(abc.ABC):
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
|
||||
|
||||
Supports bidirectional cursor pagination:
|
||||
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
||||
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
||||
- neither: return the latest ``limit`` records (ascending)
|
||||
"""
|
||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
|
||||
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,22 +55,16 @@ class DbRunEventStore(RunEventStore):
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _user_id_from_context() -> str | None:
|
||||
"""Soft read of user_id from contextvar for write paths.
|
||||
def _owner_from_context() -> str | None:
|
||||
"""Soft read of owner_id from contextvar for write paths.
|
||||
|
||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||
which is the expected case for background worker writes. HTTP
|
||||
request writes will have the contextvar set by auth middleware
|
||||
and get their user_id stamped automatically.
|
||||
|
||||
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
|
||||
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
|
||||
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
|
||||
to a VARCHAR column ("type 'UUID' is not supported") — the
|
||||
INSERT would silently roll back and the worker would hang.
|
||||
"""
|
||||
user = get_current_user()
|
||||
return str(user.id) if user is not None else None
|
||||
return user.id if user is not None else None
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
@@ -87,7 +81,7 @@ class DbRunEventStore(RunEventStore):
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
user_id = self._user_id_from_context()
|
||||
owner_id = self._owner_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
@@ -98,7 +92,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
user_id=user_id,
|
||||
owner_id=owner_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=db_content,
|
||||
@@ -112,7 +106,7 @@ class DbRunEventStore(RunEventStore):
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
user_id = self._user_id_from_context()
|
||||
owner_id = self._owner_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
@@ -136,7 +130,7 @@ class DbRunEventStore(RunEventStore):
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
user_id=e.get("user_id", user_id),
|
||||
owner_id=e.get("owner_id", owner_id),
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=db_content,
|
||||
@@ -155,12 +149,12 @@ class DbRunEventStore(RunEventStore):
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
@@ -187,12 +181,12 @@ class DbRunEventStore(RunEventStore):
|
||||
*,
|
||||
event_types=None,
|
||||
limit=500,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
@@ -205,46 +199,27 @@ class DbRunEventStore(RunEventStore):
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(
|
||||
RunEventRow.thread_id == thread_id,
|
||||
RunEventRow.run_id == run_id,
|
||||
RunEventRow.category == "message",
|
||||
)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def count_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
@@ -252,13 +227,13 @@ class DbRunEventStore(RunEventStore):
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
@@ -271,13 +246,13 @@ class DbRunEventStore(RunEventStore):
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
owner_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
if resolved_owner_id is not None:
|
||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
|
||||
@@ -152,17 +152,9 @@ class JsonlRunEventStore(RunEventStore):
|
||||
events = [e for e in events if e.get("event_type") in event_types]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
filtered = [e for e in events if e.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
return [e for e in events if e.get("category") == "message"]
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
|
||||
@@ -97,17 +97,9 @@ class MemoryRunEventStore(RunEventStore):
|
||||
filtered = [e for e in filtered if e["event_type"] in event_types]
|
||||
return filtered[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
return [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
|
||||
@@ -50,7 +50,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
@@ -246,19 +245,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
# Tools that update graph state return a ``Command`` (e.g.
|
||||
# ``present_files``). LangGraph later unwraps the inner ToolMessage
|
||||
# into checkpoint state, so to stay checkpoint-aligned we must
|
||||
# extract it here rather than storing ``str(Command(...))``.
|
||||
if isinstance(output, Command):
|
||||
update = getattr(output, "update", None) or {}
|
||||
inner_msgs = update.get("messages") if isinstance(update, dict) else None
|
||||
if isinstance(inner_msgs, list):
|
||||
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
|
||||
if inner_tool_msg is not None:
|
||||
output = inner_tool_msg
|
||||
|
||||
# Extract fields from ToolMessage object when LangChain provides one.
|
||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||
@@ -395,10 +381,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
"""
|
||||
if not self._buffer:
|
||||
return
|
||||
# Skip if a flush is already in flight — avoids concurrent writes
|
||||
# to the same SQLite file from multiple fire-and-forget tasks.
|
||||
if self._pending_flush_tasks:
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
@@ -407,7 +389,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
task = loop.create_task(self._flush_async(batch))
|
||||
self._pending_flush_tasks.add(task)
|
||||
task.add_done_callback(self._on_flush_done)
|
||||
|
||||
async def _flush_async(self, batch: list[dict]) -> None:
|
||||
@@ -423,8 +404,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
# Return failed events to buffer for retry on next flush
|
||||
self._buffer = batch + self._buffer
|
||||
|
||||
def _on_flush_done(self, task: asyncio.Task) -> None:
|
||||
self._pending_flush_tasks.discard(task)
|
||||
@staticmethod
|
||||
def _on_flush_done(task: asyncio.Task) -> None:
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
@@ -469,17 +450,10 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._pending_flush_tasks:
|
||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||
|
||||
while self._buffer:
|
||||
batch = self._buffer[: self._flush_threshold]
|
||||
del self._buffer[: self._flush_threshold]
|
||||
try:
|
||||
await self._store.put_batch(batch)
|
||||
except Exception:
|
||||
self._buffer = batch + self._buffer
|
||||
raise
|
||||
if self._buffer:
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
await self._store.put_batch(batch)
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
|
||||
@@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
|
||||
- MemoryRunStore: in-memory dict (development, tests)
|
||||
- Future: RunRepository backed by SQLAlchemy ORM
|
||||
|
||||
All methods accept an optional user_id for user isolation.
|
||||
When user_id is None, no user filtering is applied (single-user mode).
|
||||
All methods accept an optional owner_id for user isolation.
|
||||
When owner_id is None, no user filtering is applied (single-user mode).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -22,7 +22,7 @@ class RunStore(abc.ABC):
|
||||
*,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
status: str = "pending",
|
||||
multitask_strategy: str = "reject",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
@@ -42,7 +42,7 @@ class RunStore(abc.ABC):
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@@ -21,7 +21,7 @@ class MemoryRunStore(RunStore):
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
user_id=None,
|
||||
owner_id=None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -35,7 +35,7 @@ class MemoryRunStore(RunStore):
|
||||
"run_id": run_id,
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"user_id": user_id,
|
||||
"owner_id": owner_id,
|
||||
"status": status,
|
||||
"multitask_strategy": multitask_strategy,
|
||||
"metadata": metadata or {},
|
||||
@@ -49,8 +49,8 @@ class MemoryRunStore(RunStore):
|
||||
async def get(self, run_id):
|
||||
return self._runs.get(run_id)
|
||||
|
||||
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
|
||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
|
||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
|
||||
results.sort(key=lambda r: r["created_at"], reverse=True)
|
||||
return results[:limit]
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class RunContext:
|
||||
store: Any | None = field(default=None)
|
||||
event_store: Any | None = field(default=None)
|
||||
run_events_config: Any | None = field(default=None)
|
||||
thread_store: Any | None = field(default=None)
|
||||
thread_meta_repo: Any | None = field(default=None)
|
||||
follow_up_to_run_id: str | None = field(default=None)
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ async def run_agent(
|
||||
store = ctx.store
|
||||
event_store = ctx.event_store
|
||||
run_events_config = ctx.run_events_config
|
||||
thread_store = ctx.thread_store
|
||||
thread_meta_repo = ctx.thread_meta_repo
|
||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||
|
||||
run_id = record.run_id
|
||||
@@ -85,7 +85,34 @@ async def run_agent(
|
||||
pre_run_snapshot: dict[str, Any] | None = None
|
||||
snapshot_capture_failed = False
|
||||
|
||||
# Initialize RunJournal for event capture
|
||||
journal = None
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
# Write human_message event (model_dump format, aligned with checkpoint)
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
@@ -95,38 +122,6 @@ async def run_agent(
|
||||
)
|
||||
|
||||
try:
|
||||
# Initialize RunJournal + write human_message event.
|
||||
# These are inside the try block so any exception (e.g. a DB
|
||||
# error writing the event) flows through the except/finally
|
||||
# path that publishes an "end" event to the SSE bridge —
|
||||
# otherwise a failure here would leave the stream hanging
|
||||
# with no terminator.
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
@@ -310,15 +305,12 @@ async def run_agent(
|
||||
except Exception:
|
||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||
|
||||
try:
|
||||
# Persist token usage + convenience fields to RunStore
|
||||
completion = journal.get_completion_data()
|
||||
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)
|
||||
# Persist token usage + convenience fields to RunStore
|
||||
completion = journal.get_completion_data()
|
||||
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
||||
|
||||
# Sync title from checkpoint to threads_meta.display_name
|
||||
if checkpointer is not None and thread_store is not None:
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
||||
@@ -326,17 +318,16 @@ async def run_agent(
|
||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||
title = ckpt.get("channel_values", {}).get("title")
|
||||
if title:
|
||||
await thread_store.update_display_name(thread_id, title)
|
||||
await thread_meta_repo.update_display_name(thread_id, title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||
|
||||
# Update threads_meta status based on run outcome
|
||||
if thread_store is not None:
|
||||
try:
|
||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||
await thread_store.update_status(thread_id, final_status)
|
||||
except Exception:
|
||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||
try:
|
||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||
await thread_meta_repo.update_status(thread_id, final_status)
|
||||
except Exception:
|
||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||
|
||||
await bridge.publish_end(run_id)
|
||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||
|
||||
@@ -91,7 +91,7 @@ async def make_store() -> AsyncIterator[BaseStore]:
|
||||
configured checkpointer.
|
||||
|
||||
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
||||
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer` so
|
||||
that both singletons always use the same persistence technology::
|
||||
|
||||
async with make_store() as store:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Async stream bridge factory.
|
||||
|
||||
Provides an **async context manager** aligned with
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer`.
|
||||
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer`.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Request-scoped user context for user-based authorization.
|
||||
"""Request-scoped user context for owner-based authorization.
|
||||
|
||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||
auth middleware sets after a successful authentication. Repository
|
||||
methods read the contextvar via a sentinel default parameter, letting
|
||||
routers stay free of ``user_id`` boilerplate.
|
||||
routers stay free of ``owner_id`` boilerplate.
|
||||
|
||||
Three-state semantics for the repository ``user_id`` parameter (the
|
||||
Three-state semantics for the repository ``owner_id`` parameter (the
|
||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||
|
||||
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||
@@ -91,35 +91,16 @@ def require_current_user() -> CurrentUser:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Effective user_id helpers (filesystem isolation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_USER_ID: Final[str] = "default"
|
||||
|
||||
|
||||
def get_effective_user_id() -> str:
|
||||
"""Return the current user's id as a string, or DEFAULT_USER_ID if unset.
|
||||
|
||||
Unlike :func:`require_current_user` this never raises — it is designed
|
||||
for filesystem-path resolution where a valid user bucket is always needed.
|
||||
"""
|
||||
user = _current_user.get()
|
||||
if user is None:
|
||||
return DEFAULT_USER_ID
|
||||
return str(user.id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sentinel-based user_id resolution
|
||||
# Sentinel-based owner_id resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods accept a ``user_id`` keyword-only argument that
|
||||
# Repository methods accept an ``owner_id`` keyword-only argument that
|
||||
# defaults to ``AUTO``. The three possible values drive distinct
|
||||
# behaviours; see the docstring on :func:`resolve_user_id`.
|
||||
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
||||
|
||||
|
||||
class _AutoSentinel:
|
||||
"""Singleton marker meaning 'resolve user_id from contextvar'."""
|
||||
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
||||
|
||||
_instance: _AutoSentinel | None = None
|
||||
|
||||
@@ -135,12 +116,12 @@ class _AutoSentinel:
|
||||
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||
|
||||
|
||||
def resolve_user_id(
|
||||
def resolve_owner_id(
|
||||
value: str | None | _AutoSentinel,
|
||||
*,
|
||||
method_name: str = "repository method",
|
||||
) -> str | None:
|
||||
"""Resolve the user_id parameter passed to a repository method.
|
||||
"""Resolve the owner_id parameter passed to a repository method.
|
||||
|
||||
Three-state semantics:
|
||||
|
||||
@@ -150,16 +131,16 @@ def resolve_user_id(
|
||||
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||
contextvar value. Useful for tests and admin-override flows.
|
||||
- Explicit ``None``: no filter — the repository should skip the
|
||||
user_id WHERE clause entirely. Reserved for migration scripts
|
||||
owner_id WHERE clause entirely. Reserved for migration scripts
|
||||
and CLI tools that intentionally bypass isolation.
|
||||
"""
|
||||
if isinstance(value, _AutoSentinel):
|
||||
user = _current_user.get()
|
||||
if user is None:
|
||||
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
|
||||
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
|
||||
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
||||
# ``UUID`` for the API surface, but the persistence layer
|
||||
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
|
||||
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
|
||||
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
||||
# not supported"). Honour the documented return type here
|
||||
# rather than ripple a type change through every caller.
|
||||
|
||||
@@ -200,9 +200,8 @@ def _get_acp_workspace_host_path(thread_id: str | None = None) -> str | None:
|
||||
if thread_id is not None:
|
||||
try:
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
host_path = get_paths().acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
||||
host_path = get_paths().acp_workspace_dir(thread_id)
|
||||
if host_path.exists():
|
||||
return str(host_path)
|
||||
except Exception:
|
||||
|
||||
@@ -33,12 +33,11 @@ def _get_work_dir(thread_id: str | None) -> str:
|
||||
An absolute physical filesystem path to use as the working directory.
|
||||
"""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
paths = get_paths()
|
||||
if thread_id:
|
||||
try:
|
||||
work_dir = paths.acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
||||
work_dir = paths.acp_workspace_dir(thread_id)
|
||||
except ValueError:
|
||||
logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id)
|
||||
work_dir = paths.base_dir / "acp-workspace"
|
||||
|
||||
@@ -8,7 +8,6 @@ from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||
|
||||
@@ -48,7 +47,7 @@ def _normalize_presented_filepath(
|
||||
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
||||
|
||||
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
|
||||
actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
|
||||
actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
|
||||
else:
|
||||
actual_path = Path(filepath).expanduser().resolve()
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
|
||||
class PathTraversalError(ValueError):
|
||||
@@ -34,7 +33,7 @@ def validate_thread_id(thread_id: str) -> None:
|
||||
def get_uploads_dir(thread_id: str) -> Path:
|
||||
"""Return the uploads directory path for a thread (no side effects)."""
|
||||
validate_thread_id(thread_id)
|
||||
return get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||
return get_paths().sandbox_uploads_dir(thread_id)
|
||||
|
||||
|
||||
def ensure_uploads_dir(thread_id: str) -> Path:
|
||||
|
||||
@@ -39,13 +39,13 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
ollama = ["langchain-ollama>=0.3.0"]
|
||||
postgres = [
|
||||
"asyncpg>=0.29",
|
||||
"langgraph-checkpoint-postgres>=3.0.5",
|
||||
"psycopg[binary]>=3.3.3",
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
ollama = ["langchain-ollama>=0.3.0"]
|
||||
pymupdf = ["pymupdf4llm>=0.0.17"]
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -23,7 +23,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
postgres = ["deerflow-harness[postgres]"]
|
||||
postgres = [
|
||||
"deerflow-harness[postgres]",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
"""One-time migration: move legacy thread dirs and memory into per-user layout.
|
||||
|
||||
Usage:
|
||||
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
|
||||
|
||||
The script is idempotent — re-running it after a successful migration is a no-op.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_thread_dirs(
|
||||
paths: Paths,
|
||||
thread_owner_map: dict[str, str],
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Move legacy thread directories into per-user layout.
|
||||
|
||||
Args:
|
||||
paths: Paths instance.
|
||||
thread_owner_map: Mapping of thread_id -> user_id from threads_meta table.
|
||||
dry_run: If True, only log what would happen.
|
||||
|
||||
Returns:
|
||||
List of migration report entries.
|
||||
"""
|
||||
report: list[dict] = []
|
||||
legacy_threads = paths.base_dir / "threads"
|
||||
if not legacy_threads.exists():
|
||||
logger.info("No legacy threads directory found — nothing to migrate.")
|
||||
return report
|
||||
|
||||
for thread_dir in sorted(legacy_threads.iterdir()):
|
||||
if not thread_dir.is_dir():
|
||||
continue
|
||||
thread_id = thread_dir.name
|
||||
user_id = thread_owner_map.get(thread_id, "default")
|
||||
dest = paths.base_dir / "users" / user_id / "threads" / thread_id
|
||||
|
||||
entry = {"thread_id": thread_id, "user_id": user_id, "action": ""}
|
||||
|
||||
if dest.exists():
|
||||
conflicts_dir = paths.base_dir / "migration-conflicts" / thread_id
|
||||
entry["action"] = f"conflict -> {conflicts_dir}"
|
||||
if not dry_run:
|
||||
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(thread_dir), str(conflicts_dir))
|
||||
logger.warning("Conflict for thread %s: moved to %s", thread_id, conflicts_dir)
|
||||
else:
|
||||
entry["action"] = f"moved -> {dest}"
|
||||
if not dry_run:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(thread_dir), str(dest))
|
||||
logger.info("Migrated thread %s -> user %s", thread_id, user_id)
|
||||
|
||||
report.append(entry)
|
||||
|
||||
# Clean up empty legacy threads dir
|
||||
if not dry_run and legacy_threads.exists() and not any(legacy_threads.iterdir()):
|
||||
legacy_threads.rmdir()
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def migrate_memory(
|
||||
paths: Paths,
|
||||
user_id: str = "default",
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""Move legacy global memory.json into per-user layout.
|
||||
|
||||
Args:
|
||||
paths: Paths instance.
|
||||
user_id: Target user to receive the legacy memory.
|
||||
dry_run: If True, only log.
|
||||
"""
|
||||
legacy_mem = paths.base_dir / "memory.json"
|
||||
if not legacy_mem.exists():
|
||||
logger.info("No legacy memory.json found — nothing to migrate.")
|
||||
return
|
||||
|
||||
dest = paths.user_memory_file(user_id)
|
||||
if dest.exists():
|
||||
legacy_backup = paths.base_dir / "memory.legacy.json"
|
||||
logger.warning("Destination %s exists; renaming legacy to %s", dest, legacy_backup)
|
||||
if not dry_run:
|
||||
legacy_mem.rename(legacy_backup)
|
||||
return
|
||||
|
||||
logger.info("Migrating memory.json -> %s", dest)
|
||||
if not dry_run:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(legacy_mem), str(dest))
|
||||
|
||||
|
||||
def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
|
||||
"""Query threads_meta table for thread_id -> user_id mapping.
|
||||
|
||||
Uses raw sqlite3 to avoid async dependencies.
|
||||
"""
|
||||
import sqlite3
|
||||
|
||||
db_path = paths.base_dir / "deer-flow.db"
|
||||
if not db_path.exists():
|
||||
logger.info("No database found at %s — using empty owner map.", db_path)
|
||||
return {}
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
try:
|
||||
cursor = conn.execute("SELECT thread_id, user_id FROM threads_meta WHERE user_id IS NOT NULL")
|
||||
return {row[0]: row[1] for row in cursor.fetchall()}
|
||||
except sqlite3.OperationalError as e:
|
||||
logger.warning("Failed to query threads_meta: %s", e)
|
||||
return {}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
paths = get_paths()
|
||||
logger.info("Base directory: %s", paths.base_dir)
|
||||
logger.info("Dry run: %s", args.dry_run)
|
||||
|
||||
owner_map = _build_owner_map_from_db(paths)
|
||||
logger.info("Found %d thread ownership records in DB", len(owner_map))
|
||||
|
||||
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
|
||||
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
|
||||
|
||||
if report:
|
||||
logger.info("Migration report:")
|
||||
for entry in report:
|
||||
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
|
||||
else:
|
||||
logger.info("No threads to migrate.")
|
||||
|
||||
unowned = [e for e in report if e["user_id"] == "default"]
|
||||
if unowned:
|
||||
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
|
||||
for e in unowned:
|
||||
logger.warning(" %s", e["thread_id"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,16 +3,16 @@
|
||||
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
||||
ahead of every router, plus ``@require_permission(owner_check=True)``
|
||||
decorators that read ``request.state.auth`` and call
|
||||
``thread_store.check_access``. Router-level unit tests construct
|
||||
``thread_meta_repo.check_access``. Router-level unit tests construct
|
||||
**bare** FastAPI apps that include only one router — they have neither
|
||||
the auth middleware nor a real thread_store, so the decorators raise
|
||||
the auth middleware nor a real thread_meta_repo, so the decorators raise
|
||||
401 (TestClient path) or ValueError (direct-call path).
|
||||
|
||||
This module provides two surfaces:
|
||||
|
||||
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
||||
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
||||
request, plus a permissive ``thread_store`` mock on
|
||||
request, plus a permissive ``thread_meta_repo`` mock on
|
||||
``app.state``. Use from TestClient-based router tests.
|
||||
|
||||
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
||||
@@ -86,20 +86,20 @@ def make_authed_test_app(
|
||||
user_factory: Callable[[], User] | None = None,
|
||||
owner_check_passes: bool = True,
|
||||
) -> FastAPI:
|
||||
"""Build a FastAPI test app with stub auth + permissive thread_store.
|
||||
"""Build a FastAPI test app with stub auth + permissive thread_meta_repo.
|
||||
|
||||
Args:
|
||||
user_factory: Override the default test user. Must return a fully
|
||||
populated :class:`User`. Useful for cross-user isolation tests
|
||||
that need a stable id across requests.
|
||||
owner_check_passes: When True (default), ``thread_store.check_access``
|
||||
owner_check_passes: When True (default), ``thread_meta_repo.check_access``
|
||||
returns True for every call so ``@require_permission(owner_check=True)``
|
||||
never blocks the route under test. Pass False to verify that
|
||||
permission failures surface correctly.
|
||||
|
||||
Returns:
|
||||
A ``FastAPI`` app with the stub middleware installed and
|
||||
``app.state.thread_store`` set to a permissive mock. The
|
||||
``app.state.thread_meta_repo`` set to a permissive mock. The
|
||||
caller is still responsible for ``app.include_router(...)``.
|
||||
"""
|
||||
factory = user_factory or _make_stub_user
|
||||
@@ -108,7 +108,7 @@ def make_authed_test_app(
|
||||
|
||||
repo = MagicMock()
|
||||
repo.check_access = AsyncMock(return_value=owner_check_passes)
|
||||
app.state.thread_store = repo
|
||||
app.state.thread_meta_repo = repo
|
||||
|
||||
return app
|
||||
|
||||
|
||||
+19
-19
@@ -38,29 +38,11 @@ _executor_mock.get_background_task_result = MagicMock()
|
||||
sys.modules["deerflow.subagents.executor"] = _executor_mock
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def provisioner_module():
|
||||
"""Load docker/provisioner/app.py as an importable test module.
|
||||
|
||||
Shared by test_provisioner_kubeconfig and test_provisioner_pvc_volumes so
|
||||
that any change to the provisioner entry-point path or module name only
|
||||
needs to be updated in one place.
|
||||
"""
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
module_path = repo_root / "docker" / "provisioner" / "app.py"
|
||||
spec = importlib.util.spec_from_file_location("provisioner_app_test", module_path)
|
||||
assert spec is not None
|
||||
assert spec.loader is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-set user context for every test unless marked no_auto_user
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Repository methods read ``user_id`` from a contextvar by default
|
||||
# Repository methods read ``owner_id`` from a contextvar by default
|
||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||
# pre-existing persistence test would raise RuntimeError because the
|
||||
# contextvar is unset. The fixture sets a default test user on every
|
||||
@@ -95,3 +77,21 @@ def _auto_user_context(request):
|
||||
yield
|
||||
finally:
|
||||
reset_current_user(token)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def provisioner_module():
|
||||
"""Load docker/provisioner/app.py as an importable test module.
|
||||
|
||||
Shared by test_provisioner_kubeconfig and test_provisioner_pvc_volumes so
|
||||
that any change to the provisioner entry-point path or module name only
|
||||
needs to be updated in one place.
|
||||
"""
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
module_path = repo_root / "docker" / "provisioner" / "app.py"
|
||||
spec = importlib.util.spec_from_file_location("provisioner_app_test", module_path)
|
||||
assert spec is not None
|
||||
assert spec.loader is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
@@ -57,7 +57,6 @@ def test_get_thread_mounts_includes_acp_workspace(tmp_path, monkeypatch):
|
||||
"""_get_thread_mounts must include /mnt/acp-workspace (read-only) for docker sandbox."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
|
||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-3")
|
||||
|
||||
@@ -96,7 +95,6 @@ def test_get_thread_mounts_preserves_windows_host_path_style(tmp_path, monkeypat
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow")
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
|
||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10")
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ class TestResolveAttachments:
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.sandbox_outputs_dir.return_value = outputs_dir
|
||||
|
||||
def resolve_side_effect(tid, vpath, *, user_id=None):
|
||||
def resolve_side_effect(tid, vpath):
|
||||
if "data.csv" in vpath:
|
||||
return good_file
|
||||
return tmp_path / "missing.txt"
|
||||
|
||||
@@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -78,7 +78,7 @@ class TestGetCheckpointer:
|
||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
cp = get_checkpointer()
|
||||
assert cp is not None
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
@@ -178,7 +178,7 @@ class TestAsyncCheckpointer:
|
||||
@pytest.mark.anyio
|
||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||
"""Async SQLite setup should move mkdir off the event loop."""
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
||||
@@ -195,11 +195,11 @@ class TestAsyncCheckpointer:
|
||||
mock_module.AsyncSqliteSaver = mock_saver_cls
|
||||
|
||||
with (
|
||||
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
||||
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||
patch(
|
||||
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||
"deerflow.agents.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||
return_value="/tmp/resolved/test.db",
|
||||
),
|
||||
):
|
||||
|
||||
@@ -12,14 +12,14 @@ class TestCheckpointerNoneFix:
|
||||
@pytest.mark.anyio
|
||||
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None and database=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
mock_config.database = None
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||
async with make_checkpointer() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
@@ -36,13 +36,13 @@ class TestCheckpointerNoneFix:
|
||||
|
||||
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
||||
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
||||
from deerflow.runtime.checkpointer.provider import checkpointer_context
|
||||
from deerflow.agents.checkpointer.provider import checkpointer_context
|
||||
|
||||
# Mock get_app_config to return a config with checkpointer=None
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = None
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||
with checkpointer_context() as checkpointer:
|
||||
# Should return InMemorySaver, not None
|
||||
assert checkpointer is not None
|
||||
|
||||
@@ -817,7 +817,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._agent_name = "custom-agent"
|
||||
client._available_skills = {"test_skill"}
|
||||
@@ -842,7 +842,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -867,7 +867,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -886,7 +886,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=None),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -1015,7 +1015,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
# No internal checkpointer, should fetch from provider
|
||||
result = client.list_threads()
|
||||
|
||||
@@ -1069,7 +1069,7 @@ class TestThreadQueries:
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.list.return_value = []
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||
result = client.get_thread("t99")
|
||||
|
||||
assert result["thread_id"] == "t99"
|
||||
@@ -1241,10 +1241,7 @@ class TestMemoryManagement:
|
||||
with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import:
|
||||
result = client.import_memory(imported)
|
||||
|
||||
assert mock_import.call_count == 1
|
||||
call_args = mock_import.call_args
|
||||
assert call_args.args == (imported,)
|
||||
assert "user_id" in call_args.kwargs
|
||||
mock_import.assert_called_once_with(imported)
|
||||
assert result == imported
|
||||
|
||||
def test_reload_memory(self, client):
|
||||
@@ -1490,12 +1487,9 @@ class TestUploads:
|
||||
|
||||
class TestArtifacts:
|
||||
def test_get_artifact(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
user_id = get_effective_user_id()
|
||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
||||
outputs = paths.sandbox_outputs_dir("t1")
|
||||
outputs.mkdir(parents=True)
|
||||
(outputs / "result.txt").write_text("artifact content")
|
||||
|
||||
@@ -1506,12 +1500,9 @@ class TestArtifacts:
|
||||
assert "text" in mime
|
||||
|
||||
def test_get_artifact_not_found(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
@@ -1522,12 +1513,9 @@ class TestArtifacts:
|
||||
client.get_artifact("t1", "bad/path/file.txt")
|
||||
|
||||
def test_get_artifact_path_traversal(self, client):
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
with pytest.raises(PathTraversalError):
|
||||
@@ -1711,16 +1699,13 @@ class TestScenarioFileLifecycle:
|
||||
|
||||
def test_upload_then_read_artifact(self, client):
|
||||
"""Upload a file, simulate agent producing artifact, read it back."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
uploads_dir = tmp_path / "uploads"
|
||||
uploads_dir.mkdir()
|
||||
|
||||
paths = Paths(base_dir=tmp_path)
|
||||
user_id = get_effective_user_id()
|
||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact", user_id=user_id)
|
||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact")
|
||||
outputs_dir.mkdir(parents=True)
|
||||
|
||||
# Upload phase
|
||||
@@ -1859,7 +1844,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config_a)
|
||||
first_agent = client._agent
|
||||
@@ -1887,7 +1872,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client._ensure_agent(config)
|
||||
@@ -1912,7 +1897,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client.reset_agent()
|
||||
@@ -1970,14 +1955,11 @@ class TestScenarioThreadIsolation:
|
||||
|
||||
def test_artifacts_isolated_per_thread(self, client):
|
||||
"""Artifacts in thread-A are not accessible from thread-B."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
user_id = get_effective_user_id()
|
||||
outputs_a = paths.sandbox_outputs_dir("thread-a", user_id=user_id)
|
||||
outputs_a = paths.sandbox_outputs_dir("thread-a")
|
||||
outputs_a.mkdir(parents=True)
|
||||
paths.sandbox_outputs_dir("thread-b", user_id=user_id).mkdir(parents=True)
|
||||
paths.sandbox_user_data_dir("thread-b").mkdir(parents=True)
|
||||
(outputs_a / "result.txt").write_text("thread-a artifact")
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
@@ -2882,12 +2864,9 @@ class TestUploadDeleteSymlink:
|
||||
class TestArtifactHardening:
|
||||
def test_artifact_directory_rejected(self, client):
|
||||
"""get_artifact rejects paths that resolve to a directory."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
user_id = get_effective_user_id()
|
||||
subdir = paths.sandbox_outputs_dir("t1", user_id=user_id) / "subdir"
|
||||
subdir = paths.sandbox_outputs_dir("t1") / "subdir"
|
||||
subdir.mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
@@ -2896,12 +2875,9 @@ class TestArtifactHardening:
|
||||
|
||||
def test_artifact_leading_slash_stripped(self, client):
|
||||
"""Paths with leading slash are handled correctly."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
user_id = get_effective_user_id()
|
||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
||||
outputs = paths.sandbox_outputs_dir("t1")
|
||||
outputs.mkdir(parents=True)
|
||||
(outputs / "file.txt").write_text("content")
|
||||
|
||||
@@ -3015,12 +2991,9 @@ class TestBugArtifactPrefixMatchTooLoose:
|
||||
|
||||
def test_exact_prefix_without_subpath_accepted(self, client):
|
||||
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
paths = Paths(base_dir=tmp)
|
||||
user_id = get_effective_user_id()
|
||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||
|
||||
with patch("deerflow.client.get_paths", return_value=paths):
|
||||
# Accepted at prefix check, but fails because it's a directory.
|
||||
|
||||
@@ -262,9 +262,8 @@ class TestFileUploadIntegration:
|
||||
|
||||
# Physically exists
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists()
|
||||
assert (get_paths().sandbox_uploads_dir(tid) / "readme.txt").exists()
|
||||
|
||||
def test_upload_duplicate_rename(self, e2e_env, tmp_path):
|
||||
"""Uploading two files with the same name auto-renames the second."""
|
||||
@@ -473,13 +472,12 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_happy_path(self, e2e_env):
|
||||
"""Write a file to outputs, then read it back via get_artifact()."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
|
||||
# Create an output file in the thread's outputs directory
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||
outputs_dir.mkdir(parents=True, exist_ok=True)
|
||||
(outputs_dir / "result.txt").write_text("hello artifact")
|
||||
|
||||
@@ -490,12 +488,11 @@ class TestArtifactAccess:
|
||||
def test_get_artifact_nested_path(self, e2e_env):
|
||||
"""Artifacts in subdirectories are accessible."""
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||
tid = str(uuid.uuid4())
|
||||
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||
sub = outputs_dir / "charts"
|
||||
sub.mkdir(parents=True, exist_ok=True)
|
||||
(sub / "data.json").write_text('{"x": 1}')
|
||||
|
||||
+134
-116
@@ -1,19 +1,21 @@
|
||||
"""Tests for _ensure_admin_user() in app.py.
|
||||
|
||||
Covers: first-boot no-op (admin creation removed), orphan migration
|
||||
when admin exists, no-op on no admin found, and edge cases.
|
||||
Covers: first-boot admin creation, auto-reset on needs_setup=True,
|
||||
no-op on needs_setup=False, migration, and edge cases.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||
|
||||
@@ -33,90 +35,53 @@ def _make_app_stub(store=None):
|
||||
return app
|
||||
|
||||
|
||||
def _make_provider(admin_count=0):
|
||||
def _make_provider(user_count=0, admin_user=None):
|
||||
p = AsyncMock()
|
||||
p.count_users = AsyncMock(return_value=admin_count)
|
||||
p.count_admin_users = AsyncMock(return_value=admin_count)
|
||||
p.create_user = AsyncMock()
|
||||
p.count_users = AsyncMock(return_value=user_count)
|
||||
p.create_user = AsyncMock(
|
||||
side_effect=lambda **kw: User(
|
||||
email=kw["email"],
|
||||
password_hash="hashed",
|
||||
system_role=kw.get("system_role", "user"),
|
||||
needs_setup=kw.get("needs_setup", False),
|
||||
)
|
||||
)
|
||||
p.get_user_by_email = AsyncMock(return_value=admin_user)
|
||||
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||
return p
|
||||
|
||||
|
||||
def _make_session_factory(admin_row=None):
|
||||
"""Build a mock async session factory that returns a row from execute()."""
|
||||
row_result = MagicMock()
|
||||
row_result.scalar_one_or_none.return_value = admin_row
|
||||
|
||||
execute_result = MagicMock()
|
||||
execute_result.scalar_one_or_none.return_value = admin_row
|
||||
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock(return_value=execute_result)
|
||||
|
||||
# Async context manager
|
||||
session_cm = AsyncMock()
|
||||
session_cm.__aenter__ = AsyncMock(return_value=session)
|
||||
session_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
sf = MagicMock()
|
||||
sf.return_value = session_cm
|
||||
return sf
|
||||
# ── First boot: no users ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── First boot: no admin → generate init_token, return early ─────────────
|
||||
|
||||
|
||||
def test_first_boot_does_not_create_admin():
|
||||
"""admin_count==0 → generate init_token, do NOT create admin automatically."""
|
||||
provider = _make_provider(admin_count=0)
|
||||
def test_first_boot_creates_admin():
|
||||
"""count_users==0 → create admin with needs_setup=True."""
|
||||
provider = _make_provider(user_count=0)
|
||||
app = _make_app_stub()
|
||||
app.state.init_token = None # lifespan sets this
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_not_called()
|
||||
# init_token must have been set on app.state
|
||||
assert app.state.init_token is not None
|
||||
assert len(app.state.init_token) > 10
|
||||
provider.create_user.assert_called_once()
|
||||
call_kwargs = provider.create_user.call_args[1]
|
||||
assert call_kwargs["email"] == "admin@deerflow.dev"
|
||||
assert call_kwargs["system_role"] == "admin"
|
||||
assert call_kwargs["needs_setup"] is True
|
||||
assert len(call_kwargs["password"]) > 10 # random password generated
|
||||
|
||||
|
||||
def test_first_boot_skips_migration():
|
||||
"""No admin → return early before any migration attempt."""
|
||||
provider = _make_provider(admin_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(return_value=[])
|
||||
app = _make_app_stub(store=store)
|
||||
app.state.init_token = None # lifespan sets this
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
store.asearch.assert_not_called()
|
||||
|
||||
|
||||
# ── Admin exists: migration runs when admin row found ────────────────────
|
||||
|
||||
|
||||
def test_admin_exists_triggers_migration():
|
||||
"""Admin exists and admin row found → _migrate_orphaned_threads called."""
|
||||
from uuid import uuid4
|
||||
|
||||
admin_row = MagicMock()
|
||||
admin_row.id = uuid4()
|
||||
|
||||
provider = _make_provider(admin_count=1)
|
||||
sf = _make_session_factory(admin_row=admin_row)
|
||||
def test_first_boot_triggers_migration_if_store_present():
|
||||
"""First boot with store → _migrate_orphaned_threads called."""
|
||||
provider = _make_provider(user_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(return_value=[])
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
@@ -124,87 +89,140 @@ def test_admin_exists_triggers_migration():
|
||||
store.asearch.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_exists_no_admin_row_skips_migration():
|
||||
"""Admin count > 0 but DB row missing (edge case) → skip migration gracefully."""
|
||||
provider = _make_provider(admin_count=2)
|
||||
sf = _make_session_factory(admin_row=None)
|
||||
store = AsyncMock()
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
store.asearch.assert_not_called()
|
||||
|
||||
|
||||
def test_admin_exists_no_store_skips_migration():
|
||||
"""Admin exists, row found, but no store → no crash, no migration."""
|
||||
from uuid import uuid4
|
||||
|
||||
admin_row = MagicMock()
|
||||
admin_row.id = uuid4()
|
||||
|
||||
provider = _make_provider(admin_count=1)
|
||||
sf = _make_session_factory(admin_row=admin_row)
|
||||
def test_first_boot_no_store_skips_migration():
|
||||
"""First boot without store → no crash, migration skipped."""
|
||||
provider = _make_provider(user_count=0)
|
||||
app = _make_app_stub(store=None)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
# No assertion needed — just verify no crash
|
||||
provider.create_user.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_exists_session_factory_none_skips_migration():
|
||||
"""get_session_factory() returns None → return early, no crash."""
|
||||
provider = _make_provider(admin_count=1)
|
||||
store = AsyncMock()
|
||||
app = _make_app_stub(store=store)
|
||||
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_true_resets_password():
|
||||
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="old-hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=0,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=None):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
store.asearch.assert_not_called()
|
||||
# Password was reset
|
||||
provider.update_user.assert_called_once()
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.password_hash == "new-hash"
|
||||
assert updated.token_version == 1
|
||||
|
||||
|
||||
def test_needs_setup_true_consecutive_resets_increment_version():
|
||||
"""Two boots with needs_setup=True → token_version increments each time."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.token_version == 4
|
||||
|
||||
|
||||
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_false_no_reset():
|
||||
"""Admin with needs_setup=False → no password reset, no update."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="stable-hash",
|
||||
system_role="admin",
|
||||
needs_setup=False,
|
||||
token_version=2,
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
assert admin.password_hash == "stable-hash"
|
||||
assert admin.token_version == 2
|
||||
|
||||
|
||||
# ── Edge cases ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_admin_email_found_no_crash():
|
||||
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
|
||||
provider = _make_provider(user_count=3, admin_user=None)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
provider.create_user.assert_not_called()
|
||||
|
||||
|
||||
def test_migration_failure_is_non_fatal():
|
||||
"""_migrate_orphaned_threads exception is caught and logged."""
|
||||
from uuid import uuid4
|
||||
|
||||
admin_row = MagicMock()
|
||||
admin_row.id = uuid4()
|
||||
|
||||
provider = _make_provider(admin_count=1)
|
||||
sf = _make_session_factory(admin_row=admin_row)
|
||||
provider = _make_provider(user_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
# Should not raise
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
|
||||
|
||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||
|
||||
|
||||
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
||||
def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||
"""First boot finds Store-only legacy threads → stamps admin's id.
|
||||
|
||||
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
||||
(no auth) accumulates threads in the LangGraph Store namespace
|
||||
``("threads",)`` with no ``metadata.user_id``. After upgrading to
|
||||
``("threads",)`` with no ``metadata.owner_id``. After upgrading to
|
||||
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
||||
rewrite each unowned item with the freshly created admin's id.
|
||||
"""
|
||||
@@ -215,7 +233,7 @@ def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
||||
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
||||
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
||||
SimpleNamespace(key="t3", value={"metadata": {}}),
|
||||
SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
|
||||
SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}),
|
||||
]
|
||||
store = AsyncMock()
|
||||
# asearch returns the entire batch on first call, then an empty page
|
||||
@@ -235,11 +253,11 @@ def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
||||
assert len(aput_calls) == 3
|
||||
rewritten_keys = {call[1] for call in aput_calls}
|
||||
assert rewritten_keys == {"t1", "t2", "t3"}
|
||||
# Each rewrite carries the new user_id; titles preserved where present.
|
||||
# Each rewrite carries the new owner_id; titles preserved where present.
|
||||
by_key = {call[1]: call[2] for call in aput_calls}
|
||||
assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
|
||||
assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42"
|
||||
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
||||
assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
|
||||
assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42"
|
||||
# The pre-owned item must NOT have been rewritten.
|
||||
assert "t4" not in rewritten_keys
|
||||
|
||||
|
||||
@@ -60,8 +60,8 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_owner(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
assert record["user_id"] == "user-1"
|
||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
|
||||
assert record["owner_id"] == "user-1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -97,10 +97,10 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
|
||||
results = await repo.list_by_run("t1", "r1", user_id=None)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=1)
|
||||
results = await repo.list_by_run("t1", "r1")
|
||||
assert len(results) == 2
|
||||
assert all(r["run_id"] == "r1" for r in results)
|
||||
await _cleanup()
|
||||
@@ -135,9 +135,9 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||
stats = await repo.aggregate_by_run("t1", "r1")
|
||||
assert stats["total"] == 3
|
||||
assert stats["positive"] == 2
|
||||
@@ -154,80 +154,6 @@ class TestFeedbackRepository:
|
||||
assert stats["negative"] == 0
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_creates_new(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
assert record["rating"] == 1
|
||||
assert record["feedback_id"]
|
||||
assert record["user_id"] == "u1"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_updates_existing(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
|
||||
assert second["feedback_id"] == first["feedback_id"]
|
||||
assert second["rating"] == -1
|
||||
assert second["comment"] == "changed my mind"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_different_users_separate(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
|
||||
assert r1["feedback_id"] != r2["feedback_id"]
|
||||
assert r1["rating"] == 1
|
||||
assert r2["rating"] == -1
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_invalid_rating(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
with pytest.raises(ValueError):
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||
assert deleted is True
|
||||
results = await repo.list_by_run("t1", "r1", user_id="u1")
|
||||
assert len(results) == 0
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run_nonexistent(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||
assert deleted is False
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_grouped(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||
await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
|
||||
await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
|
||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||
assert "r1" in grouped
|
||||
assert "r2" in grouped
|
||||
assert "r3" not in grouped
|
||||
assert grouped["r1"]["rating"] == 1
|
||||
assert grouped["r2"]["rating"] == -1
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_grouped_empty(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||
assert grouped == {}
|
||||
await _cleanup()
|
||||
|
||||
|
||||
# -- Follow-up association --
|
||||
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
"""Tests for the POST /api/v1/auth/initialize endpoint.
|
||||
|
||||
Covers: first-boot admin creation, rejection when system already
|
||||
initialized, invalid/missing init_token, password strength validation,
|
||||
and public accessibility (no auth cookie required).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-initialize-admin-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
|
||||
_TEST_SECRET = "test-secret-key-initialize-admin-min-32"
|
||||
_INIT_TOKEN = "test-init-token-for-initialization-tests"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth(tmp_path):
|
||||
"""Fresh SQLite engine + auth config per test."""
|
||||
from app.gateway import deps
|
||||
from deerflow.persistence.engine import close_engine, init_engine
|
||||
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
url = f"sqlite+aiosqlite:///{tmp_path}/init_admin.db"
|
||||
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
asyncio.run(close_engine())
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(_setup_auth):
|
||||
from app.gateway.app import create_app
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
app = create_app()
|
||||
# Pre-set the init token on app.state (normally done by the lifespan on
|
||||
# first boot; tests don't run the lifespan because it requires config.yaml).
|
||||
app.state.init_token = _INIT_TOKEN
|
||||
# Do NOT use TestClient as a context manager — that would trigger the
|
||||
# full lifespan which requires config.yaml. The auth endpoints work
|
||||
# without the lifespan (persistence engine is set up by _setup_auth).
|
||||
yield TestClient(app)
|
||||
|
||||
|
||||
def _init_payload(**extra):
|
||||
"""Build a valid /initialize payload with the test init_token."""
|
||||
return {
|
||||
"email": "admin@example.com",
|
||||
"password": "Str0ng!Pass99",
|
||||
"init_token": _INIT_TOKEN,
|
||||
**extra,
|
||||
}
|
||||
|
||||
|
||||
# ── Happy path ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_creates_admin_and_sets_cookie(client):
|
||||
"""POST /initialize when no admin exists → 201, session cookie set."""
|
||||
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["email"] == "admin@example.com"
|
||||
assert data["system_role"] == "admin"
|
||||
assert "access_token" in resp.cookies
|
||||
|
||||
|
||||
def test_initialize_needs_setup_false(client):
|
||||
"""Newly created admin via /initialize has needs_setup=False."""
|
||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
me = client.get("/api/v1/auth/me")
|
||||
assert me.status_code == 200
|
||||
assert me.json()["needs_setup"] is False
|
||||
|
||||
|
||||
# ── Token validation ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_rejects_wrong_token(client):
|
||||
"""Wrong init_token → 403 invalid_init_token."""
|
||||
resp = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json={**_init_payload(), "init_token": "wrong-token"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert resp.json()["detail"]["code"] == "invalid_init_token"
|
||||
|
||||
|
||||
def test_initialize_rejects_empty_token(client):
|
||||
"""Empty init_token → 403 (fails constant-time comparison against stored token)."""
|
||||
resp = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json={**_init_payload(), "init_token": ""},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_initialize_token_consumed_after_success(client):
|
||||
"""After a successful /initialize the token is consumed and cannot be reused."""
|
||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
# The token is now None; any subsequent call with the old token must be rejected (403)
|
||||
resp2 = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json={**_init_payload(), "email": "other@example.com"},
|
||||
)
|
||||
assert resp2.status_code == 403
|
||||
|
||||
|
||||
# ── Rejection when already initialized ───────────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_rejected_when_admin_exists(client):
|
||||
"""Second call to /initialize after admin exists → 409 system_already_initialized.
|
||||
|
||||
The first call consumes the token. Re-setting it on app.state simulates
|
||||
what would happen if the operator somehow restarted or manually refreshed
|
||||
the token (e.g., in testing).
|
||||
"""
|
||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
# Re-set the token so the second attempt can pass token validation
|
||||
# and reach the admin-exists check.
|
||||
client.app.state.init_token = _INIT_TOKEN
|
||||
resp2 = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json={**_init_payload(), "email": "other@example.com"},
|
||||
)
|
||||
assert resp2.status_code == 409
|
||||
body = resp2.json()
|
||||
assert body["detail"]["code"] == "system_already_initialized"
|
||||
|
||||
|
||||
def test_initialize_token_not_consumed_on_admin_exists(client):
|
||||
"""Token is NOT consumed when the admin-exists guard rejects the request.
|
||||
|
||||
This prevents a DoS where an attacker calls with the correct token when
|
||||
admin already exists and permanently burns the init token.
|
||||
"""
|
||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
# Token consumed by success above; re-simulate the scenario:
|
||||
# admin exists, token is still valid (re-set), call should 409 and NOT consume token.
|
||||
client.app.state.init_token = _INIT_TOKEN
|
||||
client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json={**_init_payload(), "email": "other@example.com"},
|
||||
)
|
||||
# Token must still be set (not consumed) after the 409 rejection.
|
||||
assert client.app.state.init_token == _INIT_TOKEN
|
||||
|
||||
|
||||
def test_initialize_register_does_not_block_initialization(client):
|
||||
"""/register creating a user before /initialize doesn't block admin creation."""
|
||||
# Register a regular user first
|
||||
client.post("/api/v1/auth/register", json={"email": "regular@example.com", "password": "Tr0ub4dor3a"})
|
||||
# /initialize should still succeed (checks admin_count, not total user_count)
|
||||
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["system_role"] == "admin"
|
||||
|
||||
|
||||
# ── Endpoint is public (no cookie required) ───────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_accessible_without_cookie(client):
|
||||
"""No access_token cookie needed for /initialize."""
|
||||
resp = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json=_init_payload(),
|
||||
cookies={},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
|
||||
# ── Password validation ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_initialize_rejects_short_password(client):
|
||||
"""Password shorter than 8 chars → 422."""
|
||||
resp = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json={**_init_payload(), "password": "short"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_initialize_rejects_common_password(client):
|
||||
"""Common password → 422."""
|
||||
resp = client.post(
|
||||
"/api/v1/auth/initialize",
|
||||
json={**_init_payload(), "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ── setup-status reflects initialization ─────────────────────────────────
|
||||
|
||||
|
||||
def test_setup_status_before_initialization(client):
|
||||
"""setup-status returns needs_setup=True before /initialize is called."""
|
||||
resp = client.get("/api/v1/auth/setup-status")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["needs_setup"] is True
|
||||
|
||||
|
||||
def test_setup_status_after_initialization(client):
|
||||
"""setup-status returns needs_setup=False after /initialize succeeds."""
|
||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||
resp = client.get("/api/v1/auth/setup-status")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["needs_setup"] is False
|
||||
|
||||
|
||||
def test_setup_status_false_when_only_regular_user_exists(client):
|
||||
"""setup-status returns needs_setup=True even when regular users exist (no admin)."""
|
||||
client.post("/api/v1/auth/register", json={"email": "regular@example.com", "password": "Tr0ub4dor3a"})
|
||||
resp = client.get("/api/v1/auth/setup-status")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["needs_setup"] is True
|
||||
@@ -152,10 +152,8 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path):
|
||||
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
||||
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
result = _get_work_dir("thread-abc-123")
|
||||
expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace"
|
||||
assert result == str(expected)
|
||||
@@ -312,10 +310,8 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
|
||||
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
||||
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.runtime import user_context as uc_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
|
||||
@@ -175,46 +175,46 @@ def _make_ctx(user_id):
|
||||
def test_filter_injects_user_id():
|
||||
value = {}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
|
||||
|
||||
def test_filter_preserves_existing_metadata():
|
||||
value = {"metadata": {"title": "hello"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
assert value["metadata"]["owner_id"] == "user-a"
|
||||
assert value["metadata"]["title"] == "hello"
|
||||
|
||||
|
||||
def test_filter_returns_user_id_dict():
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||
assert result == {"user_id": "user-x"}
|
||||
assert result == {"owner_id": "user-x"}
|
||||
|
||||
|
||||
def test_filter_read_write_consistency():
|
||||
value = {}
|
||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||
assert value["metadata"]["user_id"] == filter_dict["user_id"]
|
||||
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
||||
|
||||
|
||||
def test_different_users_different_filters():
|
||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||
assert f_a["user_id"] != f_b["user_id"]
|
||||
assert f_a["owner_id"] != f_b["owner_id"]
|
||||
|
||||
|
||||
def test_filter_overrides_conflicting_user_id():
|
||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||
value = {"metadata": {"user_id": "attacker"}}
|
||||
value = {"metadata": {"owner_id": "attacker"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||
assert value["metadata"]["user_id"] == "real-owner"
|
||||
assert value["metadata"]["owner_id"] == "real-owner"
|
||||
|
||||
|
||||
def test_filter_with_empty_metadata():
|
||||
"""Explicit empty metadata dict is fine."""
|
||||
value = {"metadata": {}}
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||
assert value["metadata"]["user_id"] == "user-z"
|
||||
assert result == {"user_id": "user-z"}
|
||||
assert value["metadata"]["owner_id"] == "user-z"
|
||||
assert result == {"owner_id": "user-z"}
|
||||
|
||||
|
||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||
|
||||
@@ -48,7 +48,6 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
reinforcement_detected=False,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -89,5 +88,4 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
agent_name="lead_agent",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Tests for user_id propagation through memory queue."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||
|
||||
|
||||
def test_conversation_context_has_user_id():
|
||||
ctx = ConversationContext(thread_id="t1", messages=[], user_id="alice")
|
||||
assert ctx.user_id == "alice"
|
||||
|
||||
|
||||
def test_conversation_context_user_id_default_none():
|
||||
ctx = ConversationContext(thread_id="t1", messages=[])
|
||||
assert ctx.user_id is None
|
||||
|
||||
|
||||
def test_queue_add_stores_user_id():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
assert len(q._queue) == 1
|
||||
assert q._queue[0].user_id == "alice"
|
||||
q.clear()
|
||||
|
||||
|
||||
def test_queue_process_passes_user_id_to_updater():
|
||||
q = MemoryUpdateQueue()
|
||||
with patch.object(q, "_reset_timer"):
|
||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.update_memory.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||
q._process_queue()
|
||||
|
||||
mock_updater.update_memory.assert_called_once()
|
||||
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
||||
assert call_kwargs["user_id"] == "alice"
|
||||
@@ -258,13 +258,12 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert update_fact.call_count == 1
|
||||
call_kwargs = update_fact.call_args.kwargs
|
||||
assert call_kwargs.get("fact_id") == "fact_edit"
|
||||
assert call_kwargs.get("content") == "User prefers spaces"
|
||||
assert call_kwargs.get("category") is None
|
||||
assert call_kwargs.get("confidence") is None
|
||||
assert "user_id" in call_kwargs
|
||||
update_fact.assert_called_once_with(
|
||||
fact_id="fact_edit",
|
||||
content="User prefers spaces",
|
||||
category=None,
|
||||
confidence=None,
|
||||
)
|
||||
assert response.json()["facts"] == updated_memory["facts"]
|
||||
|
||||
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
"""Tests for per-user memory storage isolation."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_dir(tmp_path: Path) -> Path:
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage() -> FileMemoryStorage:
|
||||
return FileMemoryStorage()
|
||||
|
||||
|
||||
class TestUserIsolatedStorage:
|
||||
def test_save_and_load_per_user(self, storage: FileMemoryStorage, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
memory_a = create_empty_memory()
|
||||
memory_a["user"]["workContext"]["summary"] = "User A context"
|
||||
storage.save(memory_a, user_id="alice")
|
||||
|
||||
memory_b = create_empty_memory()
|
||||
memory_b["user"]["workContext"]["summary"] = "User B context"
|
||||
storage.save(memory_b, user_id="bob")
|
||||
|
||||
loaded_a = storage.load(user_id="alice")
|
||||
loaded_b = storage.load(user_id="bob")
|
||||
|
||||
assert loaded_a["user"]["workContext"]["summary"] == "User A context"
|
||||
assert loaded_b["user"]["workContext"]["summary"] == "User B context"
|
||||
|
||||
def test_user_memory_file_location(self, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id="alice")
|
||||
expected_path = base_dir / "users" / "alice" / "memory.json"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_cache_isolated_per_user(self, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory_a = create_empty_memory()
|
||||
memory_a["user"]["workContext"]["summary"] = "A"
|
||||
s.save(memory_a, user_id="alice")
|
||||
|
||||
memory_b = create_empty_memory()
|
||||
memory_b["user"]["workContext"]["summary"] = "B"
|
||||
s.save(memory_b, user_id="bob")
|
||||
|
||||
loaded_a = s.load(user_id="alice")
|
||||
assert loaded_a["user"]["workContext"]["summary"] == "A"
|
||||
|
||||
def test_no_user_id_uses_legacy_path(self, base_dir: Path):
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id=None)
|
||||
expected_path = base_dir / "memory.json"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
|
||||
"""user_id=None (legacy) and user_id='alice' must use different files and caches."""
|
||||
from deerflow.config.paths import Paths
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||
s = FileMemoryStorage()
|
||||
|
||||
legacy_mem = create_empty_memory()
|
||||
legacy_mem["user"]["workContext"]["summary"] = "legacy"
|
||||
s.save(legacy_mem, user_id=None)
|
||||
|
||||
user_mem = create_empty_memory()
|
||||
user_mem["user"]["workContext"]["summary"] = "alice"
|
||||
s.save(user_mem, user_id="alice")
|
||||
|
||||
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
|
||||
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
|
||||
|
||||
def test_user_agent_memory_file_location(self, base_dir: Path):
|
||||
"""Per-user per-agent memory uses the user_agent_memory_file path."""
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
memory["user"]["workContext"]["summary"] = "agent scoped"
|
||||
s.save(memory, "test-agent", user_id="alice")
|
||||
expected_path = base_dir / "users" / "alice" / "agents" / "test-agent" / "memory.json"
|
||||
assert expected_path.exists()
|
||||
|
||||
def test_cache_key_is_user_agent_tuple(self, base_dir: Path):
|
||||
"""Cache keys must be (user_id, agent_name) tuples, not bare agent names."""
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
s.save(memory, user_id="alice")
|
||||
# After save, cache should have tuple key
|
||||
assert ("alice", None) in s._memory_cache
|
||||
|
||||
def test_reload_with_user_id(self, base_dir: Path):
|
||||
"""reload() with user_id should force re-read from the user-scoped file."""
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
paths = Paths(base_dir)
|
||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||
s = FileMemoryStorage()
|
||||
memory = create_empty_memory()
|
||||
memory["user"]["workContext"]["summary"] = "initial"
|
||||
s.save(memory, user_id="alice")
|
||||
|
||||
# Load once to prime cache
|
||||
s.load(user_id="alice")
|
||||
|
||||
# Write updated content directly to file
|
||||
user_file = base_dir / "users" / "alice" / "memory.json"
|
||||
import json
|
||||
|
||||
updated = create_empty_memory()
|
||||
updated["user"]["workContext"]["summary"] = "updated"
|
||||
user_file.write_text(json.dumps(updated))
|
||||
|
||||
# reload should pick up the new content
|
||||
reloaded = s.reload(user_id="alice")
|
||||
assert reloaded["user"]["workContext"]["summary"] == "updated"
|
||||
@@ -1,156 +0,0 @@
|
||||
"""Owner isolation tests for MemoryThreadMetaStore.
|
||||
|
||||
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
|
||||
the in-memory LangGraph Store backend used when database.backend=memory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
||||
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
||||
|
||||
|
||||
def _as_user(user):
|
||||
class _Ctx:
|
||||
def __enter__(self):
|
||||
self._token = set_current_user(user)
|
||||
return user
|
||||
|
||||
def __exit__(self, *exc):
|
||||
reset_current_user(self._token)
|
||||
|
||||
return _Ctx()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store():
|
||||
return MemoryThreadMetaStore(InMemoryStore())
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_search_isolation(store):
|
||||
"""search() returns only threads owned by the current user."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", display_name="A's thread")
|
||||
with _as_user(USER_B):
|
||||
await store.create("t-beta", display_name="B's thread")
|
||||
|
||||
with _as_user(USER_A):
|
||||
results = await store.search()
|
||||
assert [r["thread_id"] for r in results] == ["t-alpha"]
|
||||
|
||||
with _as_user(USER_B):
|
||||
results = await store.search()
|
||||
assert [r["thread_id"] for r in results] == ["t-beta"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_get_isolation(store):
|
||||
"""get() returns None for threads owned by another user."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", display_name="A's thread")
|
||||
|
||||
with _as_user(USER_B):
|
||||
assert await store.get("t-alpha") is None
|
||||
|
||||
with _as_user(USER_A):
|
||||
result = await store.get("t-alpha")
|
||||
assert result is not None
|
||||
assert result["display_name"] == "A's thread"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_update_display_name_denied(store):
|
||||
"""User B cannot rename User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", display_name="original")
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.update_display_name("t-alpha", "hacked")
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
assert row["display_name"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_update_status_denied(store):
|
||||
"""User B cannot change status of User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha")
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.update_status("t-alpha", "error")
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
assert row["status"] == "idle"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_update_metadata_denied(store):
|
||||
"""User B cannot modify metadata of User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha", metadata={"key": "original"})
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.update_metadata("t-alpha", {"key": "hacked"})
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
assert row["metadata"]["key"] == "original"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_delete_denied(store):
|
||||
"""User B cannot delete User A's thread."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha")
|
||||
|
||||
with _as_user(USER_B):
|
||||
await store.delete("t-alpha")
|
||||
|
||||
with _as_user(USER_A):
|
||||
row = await store.get("t-alpha")
|
||||
assert row is not None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_no_context_raises(store):
|
||||
"""Calling methods without user context raises RuntimeError."""
|
||||
with pytest.raises(RuntimeError, match="no user context is set"):
|
||||
await store.search()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_explicit_none_bypasses_filter(store):
|
||||
"""user_id=None bypasses isolation (migration/CLI escape hatch)."""
|
||||
with _as_user(USER_A):
|
||||
await store.create("t-alpha")
|
||||
with _as_user(USER_B):
|
||||
await store.create("t-beta")
|
||||
|
||||
all_rows = await store.search(user_id=None)
|
||||
assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"}
|
||||
|
||||
row = await store.get("t-alpha", user_id=None)
|
||||
assert row is not None
|
||||
@@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
result = import_memory_data(imported_memory)
|
||||
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
|
||||
mock_storage.load.assert_called_once_with(None, user_id=None)
|
||||
mock_storage.save.assert_called_once_with(imported_memory, None)
|
||||
mock_storage.load.assert_called_once_with(None)
|
||||
assert result == imported_memory
|
||||
|
||||
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Tests for user_id propagation in memory updater."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
|
||||
|
||||
|
||||
def test_get_memory_data_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.load.return_value = {"version": "1.0"}
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
get_memory_data(user_id="alice")
|
||||
mock_storage.load.assert_called_once_with(None, user_id="alice")
|
||||
|
||||
|
||||
def test_save_memory_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
_save_memory_to_file({"version": "1.0"}, user_id="bob")
|
||||
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
|
||||
|
||||
|
||||
def test_clear_memory_data_passes_user_id():
|
||||
mock_storage = MagicMock()
|
||||
mock_storage.save.return_value = True
|
||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||
clear_memory_data(user_id="charlie")
|
||||
# Verify save was called with user_id
|
||||
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
|
||||
@@ -1,116 +0,0 @@
|
||||
"""Tests for per-user data migration."""
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_dir(tmp_path: Path) -> Path:
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def paths(base_dir: Path) -> Paths:
|
||||
return Paths(base_dir)
|
||||
|
||||
|
||||
class TestMigrateThreadDirs:
|
||||
def test_moves_thread_to_user_dir(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||
legacy.mkdir(parents=True)
|
||||
(legacy / "file.txt").write_text("hello")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
|
||||
expected = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace" / "file.txt"
|
||||
assert expected.exists()
|
||||
assert expected.read_text() == "hello"
|
||||
assert not (base_dir / "threads" / "t1").exists()
|
||||
|
||||
def test_unowned_thread_goes_to_default(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t2" / "user-data" / "workspace"
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={})
|
||||
|
||||
expected = base_dir / "users" / "default" / "threads" / "t2"
|
||||
assert expected.exists()
|
||||
|
||||
def test_idempotent_skip_already_migrated(self, base_dir: Path, paths: Paths):
|
||||
new_dir = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
||||
new_dir.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
assert new_dir.exists()
|
||||
|
||||
def test_conflict_preserved(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||
legacy.mkdir(parents=True)
|
||||
(legacy / "old.txt").write_text("old")
|
||||
|
||||
dest = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
||||
dest.mkdir(parents=True)
|
||||
(dest / "new.txt").write_text("new")
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||
|
||||
assert (dest / "new.txt").read_text() == "new"
|
||||
conflicts = base_dir / "migration-conflicts" / "t1"
|
||||
assert conflicts.exists()
|
||||
|
||||
def test_cleans_up_empty_legacy_dir(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data"
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
migrate_thread_dirs(paths, thread_owner_map={})
|
||||
|
||||
assert not (base_dir / "threads").exists()
|
||||
|
||||
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
|
||||
legacy = base_dir / "threads" / "t1" / "user-data"
|
||||
legacy.mkdir(parents=True)
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||
report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)
|
||||
|
||||
assert len(report) == 1
|
||||
assert (base_dir / "threads" / "t1").exists() # not moved
|
||||
assert not (base_dir / "users" / "alice" / "threads" / "t1").exists()
|
||||
|
||||
|
||||
class TestMigrateMemory:
|
||||
def test_moves_global_memory(self, base_dir: Path, paths: Paths):
|
||||
legacy_mem = base_dir / "memory.json"
|
||||
legacy_mem.write_text(json.dumps({"version": "1.0", "facts": []}))
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
migrate_memory(paths, user_id="default")
|
||||
|
||||
expected = base_dir / "users" / "default" / "memory.json"
|
||||
assert expected.exists()
|
||||
assert not legacy_mem.exists()
|
||||
|
||||
def test_skips_if_destination_exists(self, base_dir: Path, paths: Paths):
|
||||
legacy_mem = base_dir / "memory.json"
|
||||
legacy_mem.write_text(json.dumps({"version": "old"}))
|
||||
|
||||
dest = base_dir / "users" / "default" / "memory.json"
|
||||
dest.parent.mkdir(parents=True)
|
||||
dest.write_text(json.dumps({"version": "new"}))
|
||||
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
migrate_memory(paths, user_id="default")
|
||||
|
||||
assert json.loads(dest.read_text())["version"] == "new"
|
||||
assert (base_dir / "memory.legacy.json").exists()
|
||||
|
||||
def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
|
||||
from scripts.migrate_user_isolation import migrate_memory
|
||||
migrate_memory(paths, user_id="default") # should not raise
|
||||
@@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer
|
||||
owner filter directly by switching the ``user_context`` contextvar
|
||||
between two users. The safety property under test is:
|
||||
|
||||
After a repository write with user_id=A, a subsequent read with
|
||||
user_id=B must not return the row, and vice versa.
|
||||
After a repository write with owner_id=A, a subsequent read with
|
||||
owner_id=B must not return the row, and vice versa.
|
||||
|
||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||
that a request cookie reaches the ``set_current_user`` call. Together
|
||||
@@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path):
|
||||
await cleanup()
|
||||
|
||||
|
||||
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
||||
# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ──
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.no_auto_user
|
||||
async def test_explicit_none_bypasses_filter(tmp_path):
|
||||
"""Migration scripts pass user_id=None to see all rows regardless of owner."""
|
||||
"""Migration scripts pass owner_id=None to see all rows regardless of owner."""
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||
|
||||
@@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path):
|
||||
await repo.create("t-beta")
|
||||
|
||||
# Migration-style read: no contextvar, explicit None bypass.
|
||||
all_rows = await repo.search(user_id=None)
|
||||
all_rows = await repo.search(owner_id=None)
|
||||
thread_ids = {r["thread_id"] for r in all_rows}
|
||||
assert thread_ids == {"t-alpha", "t-beta"}
|
||||
|
||||
# Explicit get with None does not apply the filter either.
|
||||
row_a = await repo.get("t-alpha", user_id=None)
|
||||
row_a = await repo.get("t-alpha", owner_id=None)
|
||||
assert row_a is not None
|
||||
row_b = await repo.get("t-beta", user_id=None)
|
||||
row_b = await repo.get("t-beta", owner_id=None)
|
||||
assert row_b is not None
|
||||
finally:
|
||||
await cleanup()
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
"""Tests for user-scoped path resolution in Paths."""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def paths(tmp_path: Path) -> Paths:
|
||||
return Paths(tmp_path)
|
||||
|
||||
|
||||
class TestValidateUserId:
|
||||
def test_valid_user_id(self, paths: Paths):
|
||||
d = paths.user_dir("u-abc-123")
|
||||
assert d == paths.base_dir / "users" / "u-abc-123"
|
||||
|
||||
def test_rejects_path_traversal(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_dir("../escape")
|
||||
|
||||
def test_rejects_slash(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_dir("foo/bar")
|
||||
|
||||
def test_rejects_empty(self, paths: Paths):
|
||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||
paths.user_dir("")
|
||||
|
||||
|
||||
class TestUserDir:
|
||||
def test_user_dir(self, paths: Paths):
|
||||
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
|
||||
|
||||
|
||||
class TestUserMemoryFile:
|
||||
def test_user_memory_file(self, paths: Paths):
|
||||
assert paths.user_memory_file("bob") == paths.base_dir / "users" / "bob" / "memory.json"
|
||||
|
||||
|
||||
class TestUserAgentMemoryFile:
|
||||
def test_user_agent_memory_file(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
||||
assert paths.user_agent_memory_file("bob", "myagent") == expected
|
||||
|
||||
def test_user_agent_memory_file_lowercases_name(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
||||
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
|
||||
|
||||
|
||||
class TestUserThreadDir:
|
||||
def test_user_thread_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
|
||||
assert paths.thread_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_thread_dir_no_user_id_falls_back_to_legacy(self, paths: Paths):
|
||||
expected = paths.base_dir / "threads" / "t1"
|
||||
assert paths.thread_dir("t1") == expected
|
||||
|
||||
|
||||
class TestUserSandboxDirs:
|
||||
def test_sandbox_work_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "workspace"
|
||||
assert paths.sandbox_work_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_sandbox_uploads_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "uploads"
|
||||
assert paths.sandbox_uploads_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_sandbox_outputs_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "outputs"
|
||||
assert paths.sandbox_outputs_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_sandbox_user_data_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data"
|
||||
assert paths.sandbox_user_data_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_acp_workspace_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "acp-workspace"
|
||||
assert paths.acp_workspace_dir("t1", user_id="u1") == expected
|
||||
|
||||
def test_legacy_sandbox_work_dir(self, paths: Paths):
|
||||
expected = paths.base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||
assert paths.sandbox_work_dir("t1") == expected
|
||||
|
||||
|
||||
class TestHostPathsWithUserId:
|
||||
def test_host_thread_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_thread_dir("t1", user_id="u1")
|
||||
assert "users" in result
|
||||
assert "u1" in result
|
||||
assert "threads" in result
|
||||
assert "t1" in result
|
||||
|
||||
def test_host_thread_dir_legacy(self, paths: Paths):
|
||||
result = paths.host_thread_dir("t1")
|
||||
assert "threads" in result
|
||||
assert "t1" in result
|
||||
assert "users" not in result
|
||||
|
||||
def test_host_sandbox_user_data_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_user_data_dir("t1", user_id="u1")
|
||||
assert "users" in result
|
||||
assert "user-data" in result
|
||||
|
||||
def test_host_sandbox_work_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_work_dir("t1", user_id="u1")
|
||||
assert "workspace" in result
|
||||
|
||||
def test_host_sandbox_uploads_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_uploads_dir("t1", user_id="u1")
|
||||
assert "uploads" in result
|
||||
|
||||
def test_host_sandbox_outputs_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_sandbox_outputs_dir("t1", user_id="u1")
|
||||
assert "outputs" in result
|
||||
|
||||
def test_host_acp_workspace_dir_with_user_id(self, paths: Paths):
|
||||
result = paths.host_acp_workspace_dir("t1", user_id="u1")
|
||||
assert "acp-workspace" in result
|
||||
|
||||
|
||||
class TestEnsureAndDeleteWithUserId:
|
||||
def test_ensure_thread_dirs_creates_user_scoped(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
assert paths.sandbox_work_dir("t1", user_id="u1").is_dir()
|
||||
assert paths.sandbox_uploads_dir("t1", user_id="u1").is_dir()
|
||||
assert paths.sandbox_outputs_dir("t1", user_id="u1").is_dir()
|
||||
assert paths.acp_workspace_dir("t1", user_id="u1").is_dir()
|
||||
|
||||
def test_delete_thread_dir_removes_user_scoped(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
assert paths.thread_dir("t1", user_id="u1").exists()
|
||||
paths.delete_thread_dir("t1", user_id="u1")
|
||||
assert not paths.thread_dir("t1", user_id="u1").exists()
|
||||
|
||||
def test_delete_thread_dir_idempotent(self, paths: Paths):
|
||||
paths.delete_thread_dir("nonexistent", user_id="u1") # should not raise
|
||||
|
||||
def test_ensure_thread_dirs_legacy_still_works(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1")
|
||||
assert paths.sandbox_work_dir("t1").is_dir()
|
||||
|
||||
def test_user_scoped_and_legacy_are_independent(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
paths.ensure_thread_dirs("t1")
|
||||
# Both exist independently
|
||||
assert paths.thread_dir("t1", user_id="u1").exists()
|
||||
assert paths.thread_dir("t1").exists()
|
||||
# Delete one doesn't affect the other
|
||||
paths.delete_thread_dir("t1", user_id="u1")
|
||||
assert not paths.thread_dir("t1", user_id="u1").exists()
|
||||
assert paths.thread_dir("t1").exists()
|
||||
|
||||
|
||||
class TestResolveVirtualPathWithUserId:
|
||||
def test_resolve_virtual_path_with_user_id(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt", user_id="u1")
|
||||
expected_base = paths.sandbox_user_data_dir("t1", user_id="u1").resolve()
|
||||
assert str(result).startswith(str(expected_base))
|
||||
|
||||
def test_resolve_virtual_path_legacy(self, paths: Paths):
|
||||
paths.ensure_thread_dirs("t1")
|
||||
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt")
|
||||
expected_base = paths.sandbox_user_data_dir("t1").resolve()
|
||||
assert str(result).startswith(str(expected_base))
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Tests:
|
||||
1. DatabaseConfig property derivation (paths, URLs)
|
||||
2. MemoryRunStore CRUD + user_id filtering
|
||||
2. MemoryRunStore CRUD + owner_id filtering
|
||||
3. Base.to_dict() via inspect mixin
|
||||
4. Engine init/close lifecycle (memory + SQLite)
|
||||
5. Postgres missing-dep error message
|
||||
@@ -24,19 +24,18 @@ class TestDatabaseConfig:
|
||||
assert c.backend == "memory"
|
||||
assert c.pool_size == 5
|
||||
|
||||
def test_sqlite_paths_unified(self):
|
||||
def test_sqlite_paths_are_different(self):
|
||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
|
||||
assert c.sqlite_path.endswith("deerflow.db")
|
||||
assert "mydata" in c.sqlite_path
|
||||
# Backward-compatible aliases point to the same file
|
||||
assert c.checkpointer_sqlite_path == c.sqlite_path
|
||||
assert c.app_sqlite_path == c.sqlite_path
|
||||
assert c.checkpointer_sqlite_path.endswith("checkpoints.db")
|
||||
assert c.app_sqlite_path.endswith("app.db")
|
||||
assert "mydata" in c.checkpointer_sqlite_path
|
||||
assert c.checkpointer_sqlite_path != c.app_sqlite_path
|
||||
|
||||
def test_app_sqlalchemy_url_sqlite(self):
|
||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
|
||||
url = c.app_sqlalchemy_url
|
||||
assert url.startswith("sqlite+aiosqlite:///")
|
||||
assert "deerflow.db" in url
|
||||
assert "app.db" in url
|
||||
|
||||
def test_app_sqlalchemy_url_postgres(self):
|
||||
c = DatabaseConfig(
|
||||
@@ -106,17 +105,17 @@ class TestMemoryRunStore:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, store):
|
||||
await store.put("r1", thread_id="t1", user_id="alice")
|
||||
await store.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await store.list_by_thread("t1", user_id="alice")
|
||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await store.list_by_thread("t1", owner_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["user_id"] == "alice"
|
||||
assert rows[0]["owner_id"] == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_owner_none_returns_all(self, store):
|
||||
await store.put("r1", thread_id="t1", user_id="alice")
|
||||
await store.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await store.list_by_thread("t1", user_id=None)
|
||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await store.list_by_thread("t1", owner_id=None)
|
||||
assert len(rows) == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -38,7 +38,7 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
present_file_tool_module,
|
||||
"get_paths",
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path, *, user_id=None: artifact_path),
|
||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
||||
)
|
||||
|
||||
result = present_file_tool_module.present_file_tool.func(
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_store():
|
||||
return MemoryRunEventStore()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_default_returns_all(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
assert len(msgs) == 7
|
||||
assert all(m["category"] == "message" for m in msgs)
|
||||
assert all(m["run_id"] == "run-a" for m in msgs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_with_limit(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
|
||||
assert len(msgs) == 3
|
||||
seqs = [m["seq"] for m in msgs]
|
||||
assert seqs == sorted(seqs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_after_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
cursor_seq = all_msgs[2]["seq"]
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)
|
||||
assert all(m["seq"] > cursor_seq for m in msgs)
|
||||
assert len(msgs) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_before_seq(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||
category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
|
||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||
cursor_seq = all_msgs[4]["seq"]
|
||||
msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)
|
||||
assert all(m["seq"] < cursor_seq for m in msgs)
|
||||
assert len(msgs) == 4
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_does_not_include_other_run(base_store):
|
||||
store = base_store
|
||||
for i in range(7):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-a",
|
||||
event_type="human_message", category="message", content=f"msg-a-{i}",
|
||||
)
|
||||
for i in range(3):
|
||||
await store.put(
|
||||
thread_id="t1", run_id="run-b",
|
||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||
)
|
||||
|
||||
msgs = await store.list_messages_by_run("t1", "run-b")
|
||||
assert len(msgs) == 3
|
||||
assert all(m["run_id"] == "run-b" for m in msgs)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_messages_by_run_empty_run(base_store):
|
||||
store = base_store
|
||||
msgs = await store.list_messages_by_run("t1", "nonexistent")
|
||||
assert msgs == []
|
||||
@@ -709,81 +709,6 @@ class TestToolResultMessage:
|
||||
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
|
||||
assert tool_end["metadata"]["tool_name"] == "web_search"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_invoke_end_to_end_unwraps_command(self, journal_setup):
|
||||
"""End-to-end: invoke a real LangChain tool that returns Command(update={'messages':[ToolMessage]}).
|
||||
|
||||
This goes through the real LangChain callback path (tool.invoke -> CallbackManager
|
||||
-> on_tool_start/on_tool_end), which is what the production agent uses. Mirrors
|
||||
the ``present_files`` tool shape exactly.
|
||||
"""
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
|
||||
j, store = journal_setup
|
||||
|
||||
@tool
|
||||
def fake_present_files(filepaths: list[str]) -> Command:
|
||||
"""Fake present_files that returns a Command with an inner ToolMessage."""
|
||||
return Command(
|
||||
update={
|
||||
"artifacts": filepaths,
|
||||
"messages": [ToolMessage("Successfully presented files", tool_call_id="tc_123")],
|
||||
},
|
||||
)
|
||||
|
||||
# Real LangChain callback dispatch (matches production agent path)
|
||||
cm = CallbackManager(handlers=[j])
|
||||
fake_present_files.invoke(
|
||||
{"filepaths": ["/mnt/user-data/outputs/report.md"]},
|
||||
config={"callbacks": cm, "run_id": uuid4()},
|
||||
)
|
||||
await j.flush()
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1, f"expected 1 message event, got {len(messages)}: {messages}"
|
||||
content = messages[0]["content"]
|
||||
assert content["type"] == "tool"
|
||||
# CRITICAL: must be the inner ToolMessage text, not str(Command(...))
|
||||
assert content["content"] == "Successfully presented files", (
|
||||
f"Command unwrap failed; stored content = {content['content']!r}"
|
||||
)
|
||||
assert "Command(update=" not in str(content["content"])
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_end_unwraps_command_with_inner_tool_message(self, journal_setup):
|
||||
"""Tools like ``present_files`` return Command(update={'messages': [ToolMessage(...)]}).
|
||||
|
||||
LangGraph unwraps the inner ToolMessage into checkpoint state, so the
|
||||
event store must do the same — otherwise it captures ``str(Command(...))``
|
||||
and the /history response diverges from the real rendered message.
|
||||
"""
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
inner = ToolMessage(
|
||||
content="Successfully presented files",
|
||||
tool_call_id="call_present",
|
||||
name="present_files",
|
||||
status="success",
|
||||
)
|
||||
cmd = Command(update={"artifacts": ["/mnt/user-data/outputs/report.md"], "messages": [inner]})
|
||||
j.on_tool_end(cmd, run_id=run_id)
|
||||
await j.flush()
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
content = messages[0]["content"]
|
||||
assert content["type"] == "tool"
|
||||
assert content["content"] == "Successfully presented files"
|
||||
assert content["tool_call_id"] == "call_present"
|
||||
assert content["name"] == "present_files"
|
||||
assert "Command(update=" not in str(content["content"])
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
|
||||
"""ToolMessage object fields take priority over kwargs."""
|
||||
|
||||
@@ -73,11 +73,11 @@ class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id="alice")
|
||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await repo.list_by_thread("t1", owner_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["user_id"] == "alice"
|
||||
assert rows[0]["owner_id"] == "alice"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -189,8 +189,8 @@ class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_owner_none_returns_all(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||
rows = await repo.list_by_thread("t1", user_id=None)
|
||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await repo.list_by_thread("t1", owner_id=None)
|
||||
assert len(rows) == 2
|
||||
await _cleanup()
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import runs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(run_store=None, event_store=None, feedback_repo=None):
|
||||
"""Build a test FastAPI app with stub auth and mocked state."""
|
||||
app = make_authed_test_app()
|
||||
app.include_router(runs.router)
|
||||
|
||||
if run_store is not None:
|
||||
app.state.run_store = run_store
|
||||
if event_store is not None:
|
||||
app.state.run_event_store = event_store
|
||||
if feedback_repo is not None:
|
||||
app.state.feedback_repo = feedback_repo
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _make_run_store(run_record: dict | None):
|
||||
"""Return an AsyncMock run store whose get() returns run_record."""
|
||||
store = MagicMock()
|
||||
store.get = AsyncMock(return_value=run_record)
|
||||
return store
|
||||
|
||||
|
||||
def _make_event_store(rows: list[dict]):
|
||||
"""Return an AsyncMock event store whose list_messages_by_run() returns rows."""
|
||||
store = MagicMock()
|
||||
store.list_messages_by_run = AsyncMock(return_value=rows)
|
||||
return store
|
||||
|
||||
|
||||
def _make_message(seq: int) -> dict:
|
||||
return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_messages_returns_envelope():
|
||||
"""GET /api/runs/{run_id}/messages returns {data: [...], has_more: bool}."""
|
||||
rows = [_make_message(i) for i in range(1, 4)]
|
||||
run_record = {"run_id": "run-1", "thread_id": "thread-1"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-1/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "data" in body
|
||||
assert "has_more" in body
|
||||
assert body["has_more"] is False
|
||||
assert len(body["data"]) == 3
|
||||
|
||||
|
||||
def test_run_messages_404_when_run_not_found():
|
||||
"""Returns 404 when the run store returns None."""
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(None),
|
||||
event_store=_make_event_store([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/missing-run/messages")
|
||||
assert response.status_code == 404
|
||||
assert "missing-run" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_run_messages_has_more_true_when_extra_row_returned():
|
||||
"""has_more=True when event store returns limit+1 rows."""
|
||||
# Default limit is 50; provide 51 rows
|
||||
rows = [_make_message(i) for i in range(1, 52)] # 51 rows
|
||||
run_record = {"run_id": "run-2", "thread_id": "thread-2"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-2/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["has_more"] is True
|
||||
assert len(body["data"]) == 50 # trimmed to limit
|
||||
|
||||
|
||||
def test_run_messages_passes_after_seq_to_event_store():
|
||||
"""after_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(10)]
|
||||
run_record = {"run_id": "run-3", "thread_id": "thread-3"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-3/messages?after_seq=5")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-3", "run-3",
|
||||
limit=51, # default limit(50) + 1
|
||||
before_seq=None,
|
||||
after_seq=5,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_respects_custom_limit():
|
||||
"""Custom limit is respected and capped at 200."""
|
||||
rows = [_make_message(i) for i in range(1, 6)]
|
||||
run_record = {"run_id": "run-4", "thread_id": "thread-4"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-4/messages?limit=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-4", "run-4",
|
||||
limit=11, # 10 + 1
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_passes_before_seq_to_event_store():
|
||||
"""before_seq query param is forwarded to event_store.list_messages_by_run."""
|
||||
rows = [_make_message(3)]
|
||||
run_record = {"run_id": "run-5", "thread_id": "thread-5"}
|
||||
event_store = _make_event_store(rows)
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=event_store,
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-5/messages?before_seq=10")
|
||||
assert response.status_code == 200
|
||||
event_store.list_messages_by_run.assert_awaited_once_with(
|
||||
"thread-5", "run-5",
|
||||
limit=51,
|
||||
before_seq=10,
|
||||
after_seq=None,
|
||||
)
|
||||
|
||||
|
||||
def test_run_messages_empty_data():
|
||||
"""Returns empty data list when no messages exist."""
|
||||
run_record = {"run_id": "run-6", "thread_id": "thread-6"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
event_store=_make_event_store([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-6/messages")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["data"] == []
|
||||
assert body["has_more"] is False
|
||||
|
||||
|
||||
def _make_feedback_repo(rows: list[dict]):
|
||||
"""Return an AsyncMock feedback repo whose list_by_run() returns rows."""
|
||||
repo = MagicMock()
|
||||
repo.list_by_run = AsyncMock(return_value=rows)
|
||||
return repo
|
||||
|
||||
|
||||
def _make_feedback(run_id: str, idx: int) -> dict:
|
||||
return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestRunFeedback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunFeedback:
|
||||
def test_returns_list_of_feedback_dicts(self):
|
||||
"""GET /api/runs/{run_id}/feedback returns a list of feedback dicts."""
|
||||
run_record = {"run_id": "run-fb-1", "thread_id": "thread-fb-1"}
|
||||
rows = [_make_feedback("run-fb-1", i) for i in range(3)]
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
feedback_repo=_make_feedback_repo(rows),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-1/feedback")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert isinstance(body, list)
|
||||
assert len(body) == 3
|
||||
|
||||
def test_404_when_run_not_found(self):
|
||||
"""Returns 404 when run store returns None."""
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(None),
|
||||
feedback_repo=_make_feedback_repo([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/missing-run/feedback")
|
||||
assert response.status_code == 404
|
||||
assert "missing-run" in response.json()["detail"]
|
||||
|
||||
def test_empty_list_when_no_feedback(self):
|
||||
"""Returns empty list when no feedback exists for the run."""
|
||||
run_record = {"run_id": "run-fb-2", "thread_id": "thread-fb-2"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
feedback_repo=_make_feedback_repo([]),
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-2/feedback")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_503_when_feedback_repo_not_configured(self):
|
||||
"""Returns 503 when feedback_repo is None (no DB configured)."""
|
||||
run_record = {"run_id": "run-fb-3", "thread_id": "thread-fb-3"}
|
||||
app = _make_app(
|
||||
run_store=_make_run_store(run_record),
|
||||
)
|
||||
# Explicitly set feedback_repo to None to simulate missing DB
|
||||
app.state.feedback_repo = None
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/runs/run-fb-3/feedback")
|
||||
assert response.status_code == 503
|
||||
@@ -47,7 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2", "Q3"]
|
||||
@@ -67,7 +67,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
@@ -87,7 +87,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == ["Q1", "Q2"]
|
||||
@@ -104,7 +104,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
# Bypass the require_permission decorator (which needs request +
|
||||
# thread_store) — these tests cover the parsing logic.
|
||||
# thread_meta_repo) — these tests cover the parsing logic.
|
||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||
|
||||
assert result.suggestions == []
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user