mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
Compare commits
41 Commits
| SHA1 |
|---|
| d1bcae69b9 |
| 597fb0e5f9 |
| c38b3a9280 |
| cbbe39d28c |
| 82374eb18c |
| a36186cf54 |
| 9f28115889 |
| 7ce9333200 |
| 9af2f3e73c |
| dfa9fc47b3 |
| 3877aabcfd |
| e8f087cb37 |
| 3540e157f1 |
| 8f7eb28c0d |
| 500cdfc8e4 |
| 3580897c56 |
| 229c8095be |
| ce24424449 |
| 4810898cfa |
| 10cc651578 |
| 20f64bbf4f |
| e1cb78fecf |
| 6476eabdf5 |
| 95d5c156a1 |
| 18393b55d1 |
| 77491f2801 |
| 8d3cb6da72 |
| d1cf3f09b2 |
| 0d5b3a0ece |
| 4184d5ed2c |
| 60a5ad7279 |
| b2ec1f99b9 |
| 8da1903168 |
| 03952eca53 |
| 9197000690 |
| 36fb1c7804 |
| b61ce3527b |
| 2d5f6f1b3d |
| 69bf3dafd8 |
| 6cbec13495 |
| 31e5b586a1 |
+20
-8
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config

 Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:

-1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
+1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
 2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
 3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
 4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
@@ -216,6 +216,9 @@ FastAPI application on port 8001 with health check at `GET /health`.
 | **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
 | **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
 | **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
+| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
+| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
+| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |

 Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.

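The download rule stated for the Artifacts endpoint above can be sketched as a small decision function. The function name `content_disposition` and its parameters are hypothetical; only the three active content types and the `?download=true` behavior are taken from the table.

```python
ACTIVE_CONTENT_TYPES = frozenset({"text/html", "application/xhtml+xml", "image/svg+xml"})


def content_disposition(filename: str, content_type: str, download: bool = False) -> str:
    """Pick a Content-Disposition value for a served artifact."""
    # Drop parameters such as "; charset=utf-8" before comparing.
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type in ACTIVE_CONTENT_TYPES or download:
        # Illustration only: production code should also escape or encode the filename.
        return f'attachment; filename="{filename}"'
    return "inline"
```

Forcing `attachment` for script-capable types means a browser saves the file instead of rendering it in the Gateway's origin, which is the XSS concern the table mentions.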
@@ -229,7 +232,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →

 **Virtual Path System**:
 - Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
-- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
+- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
 - Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
 - Detection: `is_local_sandbox()` checks `sandbox_id == "local"`

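The virtual-to-physical mapping listed above can be sketched as follows. This is not the project's `replace_virtual_path()`; `to_physical_path` is a hypothetical helper written only to show the mapping and a traversal guard, and it covers the `/mnt/user-data` prefix alone.

```python
from pathlib import Path

VIRTUAL_PREFIX = "/mnt/user-data"


def to_physical_path(virtual_path: str, base_dir: Path, user_id: str, thread_id: str) -> Path:
    """Map an agent-visible /mnt/user-data/... path to its on-disk location."""
    if virtual_path != VIRTUAL_PREFIX and not virtual_path.startswith(VIRTUAL_PREFIX + "/"):
        raise ValueError(f"not a user-data virtual path: {virtual_path}")
    root = (base_dir / "users" / user_id / "threads" / thread_id / "user-data").resolve()
    relative = virtual_path[len(VIRTUAL_PREFIX):].lstrip("/")
    physical = (root / relative).resolve()
    # Reject "../" escapes that would leave this thread's user-data root.
    if physical != root and root not in physical.parents:
        raise ValueError(f"path escapes user-data root: {virtual_path}")
    return physical
```

Resolving both the root and the candidate before comparing them is what makes the check robust against `..` segments.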
@@ -269,7 +272,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
 - `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
 - ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
 - Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
-- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
+- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
 - `image_search/` - Image search via DuckDuckGo

 ### MCP System (`packages/harness/deerflow/mcp/`)
@@ -338,18 +341,27 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a

 **Components**:
 - `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
-- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
+- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
 - `prompt.py` - Prompt templates for memory updates
+- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple

-**Data Structure** (stored in `backend/.deer-flow/memory.json`):
+**Per-User Isolation**:
+- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
+- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
+- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
+- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
+- Absolute `storage_path` in config opts out of per-user isolation
+- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
+
+**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
 - **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
 - **History**: `recentMonths`, `earlierContext`, `longTermBackground`
 - **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
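A fact record with the fields listed above, combined with the whitespace-normalized deduplication described for `updater.py`, might be appended like this. `append_fact` and the `fact_` id format are hypothetical; only the field names and the trim-before-compare rule come from the text.

```python
import uuid
from datetime import datetime, timezone


def append_fact(facts: list[dict], content: str, category: str, confidence: float, source: str) -> bool:
    """Append a fact unless one with the same trimmed content already exists."""
    normalized = content.strip()
    if not normalized or any(f["content"].strip() == normalized for f in facts):
        return False  # empty or duplicate content is skipped
    facts.append(
        {
            "id": f"fact_{uuid.uuid4().hex[:8]}",  # hypothetical id format
            "content": normalized,
            "category": category,
            "confidence": confidence,
            "createdAt": datetime.now(timezone.utc).isoformat(),
            "source": source,
        }
    )
    return True
```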

 **Workflow**:
-1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
+1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
 2. Queue debounces (30s default), batches updates, deduplicates per-thread
-3. Background thread invokes LLM to extract context updates and facts
+3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
 4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
 5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
@@ -357,7 +369,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_

 **Configuration** (`config.yaml` → `memory`):
 - `enabled` / `injection_enabled` - Master switches
-- `storage_path` - Path to memory.json
+- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
 - `debounce_seconds` - Wait time before processing (default: 30)
 - `model_name` - LLM for updates (null = default model)
 - `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
@@ -13,6 +13,7 @@ from app.channels.base import Channel
|
|||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -344,8 +345,9 @@ class FeishuChannel(Channel):
|
|||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
paths.ensure_thread_dirs(thread_id)
|
user_id = get_effective_user_id()
|
||||||
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
|
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||||
|
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
|
||||||
|
|
||||||
ext = "png" if type == "image" else "bin"
|
ext = "png" if type == "image" else "bin"
|
||||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from langgraph_sdk.errors import ConflictError
|
|||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -341,14 +342,15 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
|
|||||||
|
|
||||||
attachments: list[ResolvedAttachment] = []
|
attachments: list[ResolvedAttachment] = []
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
|
user_id = get_effective_user_id()
|
||||||
|
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
|
||||||
for virtual_path in artifacts:
|
for virtual_path in artifacts:
|
||||||
# Security: only allow files from the agent outputs directory
|
# Security: only allow files from the agent outputs directory
|
||||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
actual = paths.resolve_virtual_path(thread_id, virtual_path)
|
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
|
||||||
# Verify the resolved path is actually under the outputs directory
|
# Verify the resolved path is actually under the outputs directory
|
||||||
# (guards against path-traversal even after prefix check)
|
# (guards against path-traversal even after prefix check)
|
||||||
try:
|
try:
|
||||||
|
|||||||
+49
-53
@@ -2,7 +2,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import UTC
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -41,77 +40,69 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||||
"""Auto-create the admin user on first boot if no users exist.
|
"""Startup hook: generate init token on first boot; migrate orphan threads otherwise.
|
||||||
|
|
||||||
After admin creation, migrate orphan threads from the LangGraph
|
After admin creation, migrate orphan threads from the LangGraph
|
||||||
store (metadata.owner_id unset) to the admin account. This is the
|
store (metadata.user_id unset) to the admin account. This is the
|
||||||
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
||||||
authentication have existing LangGraph thread data that needs an
|
authentication have existing LangGraph thread data that needs an
|
||||||
owner assigned.
|
owner assigned.
|
||||||
|
First boot (no admin exists):
|
||||||
|
- Generates a one-time ``init_token`` stored in ``app.state.init_token``
|
||||||
|
- Logs the token to stdout so the operator can copy-paste it into the
|
||||||
|
``/setup`` form to create the first admin account interactively.
|
||||||
|
- Does NOT create any user accounts automatically.
|
||||||
|
|
||||||
No SQL persistence migration is needed: the four owner_id columns
|
Subsequent boots (admin already exists):
|
||||||
|
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
||||||
|
existing LangGraph thread metadata that has no user_id.
|
||||||
|
|
||||||
|
No SQL persistence migration is needed: the four user_id columns
|
||||||
(threads_meta, runs, run_events, feedback) only come into existence
|
(threads_meta, runs, run_events, feedback) only come into existence
|
||||||
alongside the auth module via create_all, so freshly created tables
|
alongside the auth module via create_all, so freshly created tables
|
||||||
never contain NULL-owner rows. "Existing persistence DB + new auth"
|
never contain NULL-owner rows.
|
||||||
is not a supported upgrade path — fresh install or wipe-and-retry.
|
|
||||||
|
|
||||||
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve
|
|
||||||
races during admin creation. Only the worker that successfully
|
|
||||||
creates/updates the admin prints the password; losers silently skip.
|
|
||||||
"""
|
"""
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
from app.gateway.auth.credential_file import write_initial_credentials
|
from sqlalchemy import select
|
||||||
from app.gateway.deps import get_local_provider
|
|
||||||
|
|
||||||
def _announce_credentials(email: str, password: str, *, label: str, headline: str) -> None:
|
from app.gateway.deps import get_local_provider
|
||||||
"""Write the password to a 0600 file and log the path (never the secret)."""
|
from deerflow.persistence.engine import get_session_factory
|
||||||
cred_path = write_initial_credentials(email, password, label=label)
|
from deerflow.persistence.user.model import UserRow
|
||||||
logger.info("=" * 60)
|
|
||||||
logger.info(" %s", headline)
|
|
||||||
logger.info(" Credentials written to: %s (mode 0600)", cred_path)
|
|
||||||
logger.info(" Change it after login: Settings -> Account")
|
|
||||||
logger.info("=" * 60)
|
|
||||||
|
|
||||||
provider = get_local_provider()
|
provider = get_local_provider()
|
||||||
user_count = await provider.count_users()
|
admin_count = await provider.count_admin_users()
|
||||||
|
|
||||||
admin = None
|
if admin_count == 0:
|
||||||
|
init_token = secrets.token_urlsafe(32)
|
||||||
|
app.state.init_token = init_token
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(" First boot detected — no admin account exists.")
|
||||||
|
logger.info(" Use the one-time token below to create the admin account.")
|
||||||
|
logger.info(" Copy it into the /setup form when prompted.")
|
||||||
|
logger.info(" INIT TOKEN: %s", init_token)
|
||||||
|
logger.info(" Visit /setup to complete admin account creation.")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
return
|
||||||
|
|
||||||
if user_count == 0:
|
# Admin already exists — run orphan thread migration for any
|
||||||
password = secrets.token_urlsafe(16)
|
# LangGraph thread metadata that pre-dates the auth module.
|
||||||
try:
|
sf = get_session_factory()
|
||||||
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
|
if sf is None:
|
||||||
except ValueError:
|
return
|
||||||
return # Another worker already created the admin.
|
|
||||||
_announce_credentials(admin.email, password, label="initial", headline="Admin account created on first boot")
|
|
||||||
else:
|
|
||||||
# Admin exists but setup never completed — reset password so operator
|
|
||||||
# can always find it in the console without needing the CLI.
|
|
||||||
# Multi-worker guard: if admin was created less than 30s ago, another
|
|
||||||
# worker just created it and will print the password — skip reset.
|
|
||||||
admin = await provider.get_user_by_email("admin@deerflow.dev")
|
|
||||||
if admin and admin.needs_setup:
|
|
||||||
import time
|
|
||||||
|
|
||||||
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
|
async with sf() as session:
|
||||||
if age >= 30:
|
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
||||||
from app.gateway.auth.password import hash_password_async
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||||
|
|
||||||
password = secrets.token_urlsafe(16)
|
if row is None:
|
||||||
admin.password_hash = await hash_password_async(password)
|
return # Should not happen (admin_count > 0 above), but be safe.
|
||||||
admin.token_version += 1
|
|
||||||
await provider.update_user(admin)
|
|
||||||
_announce_credentials(admin.email, password, label="reset", headline="Admin account setup incomplete — password reset")
|
|
||||||
|
|
||||||
if admin is None:
|
admin_id = str(row.id)
|
||||||
return # Nothing to bind orphans to.
|
|
||||||
|
|
||||||
admin_id = str(admin.id)
|
|
||||||
|
|
||||||
# LangGraph store orphan migration — non-fatal.
|
# LangGraph store orphan migration — non-fatal.
|
||||||
# This covers the "no-auth → with-auth" upgrade path for users
|
# This covers the "no-auth → with-auth" upgrade path for users
|
||||||
# whose existing LangGraph thread metadata has no owner_id set.
|
# whose existing LangGraph thread metadata has no user_id set.
|
||||||
store = getattr(app.state, "store", None)
|
store = getattr(app.state, "store", None)
|
||||||
if store is not None:
|
if store is not None:
|
||||||
try:
|
try:
|
||||||
@@ -143,7 +134,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500):


 async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
-    """Migrate LangGraph store threads with no owner_id to the given admin.
+    """Migrate LangGraph store threads with no user_id to the given admin.

     Uses cursor pagination so all orphans are migrated regardless of
     count. Returns the number of rows migrated.
@@ -151,8 +142,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
     migrated = 0
     async for item in _iter_store_items(store, ("threads",)):
         metadata = item.value.get("metadata", {})
-        if not metadata.get("owner_id"):
-            metadata["owner_id"] = admin_user_id
+        if not metadata.get("user_id"):
+            metadata["user_id"] = admin_user_id
             item.value["metadata"] = metadata
             await store.aput(("threads",), item.key, item.value)
             migrated += 1
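The migration loop in the hunk above can be exercised without a real LangGraph store. The sketch below pairs the same loop with a hypothetical in-memory `FakeStore` that mimics only the two calls it needs. The paging helper uses a plain offset, which is an assumption; the diff does not show the body of `_iter_store_items`.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Item:
    key: str
    value: dict


class FakeStore:
    """In-memory stand-in that mimics only the two store calls used below."""

    def __init__(self, items: list[Item]) -> None:
        self._items = {item.key: item for item in items}

    async def asearch(self, namespace: tuple[str, ...], *, limit: int, offset: int) -> list[Item]:
        return list(self._items.values())[offset : offset + limit]

    async def aput(self, namespace: tuple[str, ...], key: str, value: dict) -> None:
        self._items[key] = Item(key, value)


async def iter_store_items(store, namespace, *, page_size: int = 2):
    """Yield every item in the namespace, one page at a time."""
    offset = 0
    while True:
        page = await store.asearch(namespace, limit=page_size, offset=offset)
        if not page:
            return
        for item in page:
            yield item
        offset += len(page)


async def migrate_orphaned_threads(store, admin_user_id: str) -> int:
    """Assign admin_user_id to every thread whose metadata has no user_id."""
    migrated = 0
    async for item in iter_store_items(store, ("threads",)):
        metadata = item.value.get("metadata", {})
        if not metadata.get("user_id"):
            metadata["user_id"] = admin_user_id
            item.value["metadata"] = metadata
            await store.aput(("threads",), item.key, item.value)
            migrated += 1
    return migrated
```

Because the check is `not metadata.get("user_id")`, both a missing key and an empty string count as orphaned, and a second run migrates nothing.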
@@ -374,6 +365,11 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
|||||||
"""
|
"""
|
||||||
return {"status": "healthy", "service": "deer-flow-gateway"}
|
return {"status": "healthy", "service": "deer-flow-gateway"}
|
||||||
|
|
||||||
|
# Ensure init_token always exists on app.state (None until lifespan sets it
|
||||||
|
# if no admin is found). This prevents AttributeError in tests that don't
|
||||||
|
# run the full lifespan.
|
||||||
|
app.state.init_token = None
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ class AuthErrorCode(StrEnum):
|
|||||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||||
NOT_AUTHENTICATED = "not_authenticated"
|
NOT_AUTHENTICATED = "not_authenticated"
|
||||||
|
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
||||||
|
INVALID_INIT_TOKEN = "invalid_init_token"
|
||||||
|
|
||||||
|
|
||||||
class TokenError(StrEnum):
|
class TokenError(StrEnum):
|
||||||
|
|||||||
@@ -78,6 +78,10 @@ class LocalAuthProvider(AuthProvider):
|
|||||||
"""Return total number of registered users."""
|
"""Return total number of registered users."""
|
||||||
return await self._repo.count_users()
|
return await self._repo.count_users()
|
||||||
|
|
||||||
|
async def count_admin_users(self) -> int:
|
||||||
|
"""Return number of admin users."""
|
||||||
|
return await self._repo.count_admin_users()
|
||||||
|
|
||||||
async def update_user(self, user: User) -> User:
|
async def update_user(self, user: User) -> User:
|
||||||
"""Update an existing user."""
|
"""Update an existing user."""
|
||||||
return await self._repo.update_user(user)
|
return await self._repo.update_user(user)
|
||||||
|
|||||||
@@ -83,6 +83,11 @@ class UserRepository(ABC):
|
|||||||
"""Return total number of registered users."""
|
"""Return total number of registered users."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def count_admin_users(self) -> int:
|
||||||
|
"""Return number of users with system_role == 'admin'."""
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||||
"""Get user by OAuth provider and ID.
|
"""Get user by OAuth provider and ID.
|
||||||
|
|||||||
@@ -114,6 +114,11 @@ class SQLiteUserRepository(UserRepository):
|
|||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
return await session.scalar(stmt) or 0
|
return await session.scalar(stmt) or 0
|
||||||
|
|
||||||
|
async def count_admin_users(self) -> int:
|
||||||
|
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
|
||||||
|
async with self._sf() as session:
|
||||||
|
return await session.scalar(stmt) or 0
|
||||||
|
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ _PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
|||||||
"/api/v1/auth/register",
|
"/api/v1/auth/register",
|
||||||
"/api/v1/auth/logout",
|
"/api/v1/auth/logout",
|
||||||
"/api/v1/auth/setup-status",
|
"/api/v1/auth/setup-status",
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -233,18 +233,18 @@ def require_permission(
|
|||||||
# (``threads_meta`` table). We verify ownership via
|
# (``threads_meta`` table). We verify ownership via
|
||||||
# ``ThreadMetaStore.check_access``: it returns True for
|
# ``ThreadMetaStore.check_access``: it returns True for
|
||||||
# missing rows (untracked legacy thread) and for rows whose
|
# missing rows (untracked legacy thread) and for rows whose
|
||||||
# ``owner_id`` is NULL (shared / pre-auth data), so this is
|
# ``user_id`` is NULL (shared / pre-auth data), so this is
|
||||||
# strict-deny rather than strict-allow — only an *existing*
|
# strict-deny rather than strict-allow — only an *existing*
|
||||||
# row with a *different* owner_id triggers 404.
|
# row with a *different* user_id triggers 404.
|
||||||
if owner_check:
|
if owner_check:
|
||||||
thread_id = kwargs.get("thread_id")
|
thread_id = kwargs.get("thread_id")
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
|
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
allowed = await thread_meta_repo.check_access(
|
allowed = await thread_store.check_access(
|
||||||
thread_id,
|
thread_id,
|
||||||
str(auth.user.id),
|
str(auth.user.id),
|
||||||
require_existing=require_existing,
|
require_existing=require_existing,
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ _AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
|||||||
"/api/v1/auth/login/local",
|
"/api/v1/auth/login/local",
|
||||||
"/api/v1/auth/logout",
|
"/api/v1/auth/logout",
|
||||||
"/api/v1/auth/register",
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+16
-10
@@ -1,8 +1,7 @@
|
|||||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||||
|
|
||||||
**Getters** (used by routers): raise 503 when a required dependency is
|
**Getters** (used by routers): raise 503 when a required dependency is
|
||||||
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
missing, except ``get_store`` which returns ``None``.
|
||||||
``None``.
|
|
||||||
|
|
||||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||||
"""
|
"""
|
||||||
@@ -20,6 +19,7 @@ from deerflow.runtime import RunContext, RunManager
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||||
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -31,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
async with langgraph_runtime(app):
|
async with langgraph_runtime(app):
|
||||||
yield
|
yield
|
||||||
"""
|
"""
|
||||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||||
from deerflow.runtime import make_store, make_stream_bridge
|
from deerflow.runtime import make_store, make_stream_bridge
|
||||||
|
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||||
from deerflow.runtime.events.store import make_run_event_store
|
from deerflow.runtime.events.store import make_run_event_store
|
||||||
|
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
@@ -53,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
if sf is not None:
|
if sf is not None:
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
from deerflow.persistence.feedback import FeedbackRepository
|
||||||
from deerflow.persistence.run import RunRepository
|
from deerflow.persistence.run import RunRepository
|
||||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
|
||||||
|
|
||||||
app.state.run_store = RunRepository(sf)
|
app.state.run_store = RunRepository(sf)
|
||||||
app.state.feedback_repo = FeedbackRepository(sf)
|
app.state.feedback_repo = FeedbackRepository(sf)
|
||||||
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
|
||||||
else:
|
else:
|
||||||
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
|
|
||||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
|
||||||
app.state.run_store = MemoryRunStore()
|
app.state.run_store = MemoryRunStore()
|
||||||
app.state.feedback_repo = None
|
app.state.feedback_repo = None
|
||||||
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
|
|
||||||
|
from deerflow.persistence.thread_meta import make_thread_store
|
||||||
|
|
||||||
|
app.state.thread_store = make_thread_store(sf, app.state.store)
|
||||||
|
|
||||||
# Run event store (has its own factory with config-driven backend selection)
|
# Run event store (has its own factory with config-driven backend selection)
|
||||||
run_events_config = getattr(config, "run_events", None)
|
run_events_config = getattr(config, "run_events", None)
|
||||||
@@ -80,7 +80,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Getters -- called by routers per-request
|
# Getters – called by routers per-request
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -110,7 +110,12 @@ def get_store(request: Request):
|
|||||||
return getattr(request.app.state, "store", None)
|
return getattr(request.app.state, "store", None)
|
||||||
|
|
||||||
|
|
||||||
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
|
def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||||
|
"""Return the thread metadata store (SQL or memory-backed)."""
|
||||||
|
val = getattr(request.app.state, "thread_store", None)
|
||||||
|
if val is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
def get_run_context(request: Request) -> RunContext:
|
def get_run_context(request: Request) -> RunContext:
|
||||||
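Review note: the new `get_thread_store` getter replaces the generic `_require(...)` accessor with an explicit "missing means 503" check. A minimal sketch of that pattern follows, using only the standard library. `ServiceUnavailable` and `require_state` are hypothetical stand-ins for `HTTPException(status_code=503, ...)` and the real getter; they are not names from the repository.

```python
from types import SimpleNamespace


class ServiceUnavailable(Exception):
    """Hypothetical stand-in for HTTPException(status_code=503, ...)."""

    status_code = 503


def require_state(state, attr: str, label: str):
    """Return ``state.<attr>``, or raise a 503-style error when it is unset."""
    val = getattr(state, attr, None)
    if val is None:
        raise ServiceUnavailable(f"{label} not available")
    return val


# A state object with the store wired up, and one without it.
ready = SimpleNamespace(thread_store=object())
empty = SimpleNamespace()

store = require_state(ready, "thread_store", "Thread metadata store")

try:
    require_state(empty, "thread_store", "Thread metadata store")
    failed = False
except ServiceUnavailable as exc:
    failed = exc.status_code == 503
```

Unlike `get_store`, which returns `None` when the attribute is absent, callers of this getter never have to handle a missing store themselves.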
@@ -128,10 +133,11 @@ def get_run_context(request: Request) -> RunContext:
|
|||||||
store=get_store(request),
|
store=get_store(request),
|
||||||
event_store=get_run_event_store(request),
|
event_store=get_run_event_store(request),
|
||||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||||
thread_meta_repo=get_thread_meta_repo(request),
|
thread_store=get_thread_store(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Auth helpers (used by authz.py and auth middleware)
|
# Auth helpers (used by authz.py and auth middleware)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -93,14 +93,14 @@ async def authenticate(request):
|
|||||||
|
|
||||||
@auth.on
|
@auth.on
|
||||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||||
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
"""Inject user_id metadata on writes; filter by user_id on reads.
|
||||||
|
|
||||||
Gateway stores thread ownership as ``metadata.owner_id``.
|
Gateway stores thread ownership as ``metadata.user_id``.
|
||||||
This handler ensures LangGraph Server enforces the same isolation.
|
This handler ensures LangGraph Server enforces the same isolation.
|
||||||
"""
|
"""
|
||||||
# On create/update: stamp owner_id into metadata
|
# On create/update: stamp user_id into metadata
|
||||||
metadata = value.setdefault("metadata", {})
|
metadata = value.setdefault("metadata", {})
|
||||||
metadata["owner_id"] = ctx.user.identity
|
metadata["user_id"] = ctx.user.identity
|
||||||
|
|
||||||
# Return filter dict — LangGraph applies it to search/read/delete
|
# Return filter dict — LangGraph applies it to search/read/delete
|
||||||
return {"owner_id": ctx.user.identity}
|
return {"user_id": ctx.user.identity}
|
||||||
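Review note: the handler above does two things at once. It stamps the caller's identity into `metadata.user_id` on writes, and it returns a filter that is applied to reads. The sketch below reproduces that behaviour on plain dicts. `add_user_filter` and `visible` are hypothetical names, and the exact-match filtering in `visible` is an assumption about how the returned filter is applied, simplified for illustration.

```python
def add_user_filter(identity: str, value: dict) -> dict:
    """Stamp the caller's identity into metadata and return the read filter."""
    metadata = value.setdefault("metadata", {})
    # Always overwrite, so a client-supplied user_id cannot survive.
    metadata["user_id"] = identity
    return {"user_id": identity}


def visible(threads: list, flt: dict) -> list:
    """Assumed filter semantics: every filter key must match thread metadata."""
    return [
        t for t in threads
        if all(t["metadata"].get(k) == v for k, v in flt.items())
    ]


# A forged user_id in the payload is replaced by the authenticated identity.
payload = {"metadata": {"user_id": "mallory", "title": "notes"}}
flt = add_user_filter("alice", payload)

threads = [
    {"id": "t1", "metadata": {"user_id": "alice"}},
    {"id": "t2", "metadata": {"user_id": "bob"}},
]
mine = visible(threads, flt)
```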
|
|||||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
|
|
||||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||||
@@ -22,7 +23,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
|||||||
HTTPException: If the path is invalid or outside allowed directories.
|
HTTPException: If the path is invalid or outside allowed directories.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return get_paths().resolve_virtual_path(thread_id, virtual_path)
|
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
status = 403 if "traversal" in str(e) else 400
|
status = 403 if "traversal" in str(e) else 400
|
||||||
raise HTTPException(status_code=status, detail=str(e))
|
raise HTTPException(status_code=status, detail=str(e))
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import secrets
|
||||||
import time
|
import time
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_address, ip_network
|
||||||
|
|
||||||
@@ -378,9 +379,74 @@ async def get_me(request: Request):
|
|||||||
|
|
||||||
@router.get("/setup-status")
|
@router.get("/setup-status")
|
||||||
async def setup_status():
|
async def setup_status():
|
||||||
"""Check if admin account exists. Always False after first boot."""
|
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
|
||||||
user_count = await get_local_provider().count_users()
|
admin_count = await get_local_provider().count_admin_users()
|
||||||
return {"needs_setup": user_count == 0}
|
return {"needs_setup": admin_count == 0}
|
||||||
|
|
||||||
|
|
||||||
|
class InitializeAdminRequest(BaseModel):
|
||||||
|
"""Request model for first-boot admin account creation."""
|
||||||
|
|
||||||
|
email: EmailStr
|
||||||
|
password: str = Field(..., min_length=8)
|
||||||
|
init_token: str | None = Field(default=None, description="One-time initialization token printed to server logs on first boot")
|
||||||
|
|
||||||
|
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
|
||||||
|
"""Create the first admin account on initial system setup.
|
||||||
|
|
||||||
|
Only callable when no admin exists. Returns 409 Conflict if an admin
|
||||||
|
already exists. Requires the one-time ``init_token`` that is logged to
|
||||||
|
stdout at startup whenever the system has no admin account.
|
||||||
|
|
||||||
|
On success the token is consumed (one-time use), the admin account is
|
||||||
|
created with ``needs_setup=False``, and the session cookie is set.
|
||||||
|
"""
|
||||||
|
# Validate the one-time initialization token. The token is generated
|
||||||
|
# at startup and stored in app.state.init_token; it is consumed here on
|
||||||
|
# the first successful call so it cannot be replayed.
|
||||||
|
# Using str | None allows a missing/null token to return 403 (not 422),
|
||||||
|
# giving a consistent error response regardless of whether the token is
|
||||||
|
# absent or incorrect.
|
||||||
|
stored_token: str | None = getattr(request.app.state, "init_token", None)
|
||||||
|
provided_token: str = body.init_token or ""
|
||||||
|
if stored_token is None or not secrets.compare_digest(stored_token.encode("utf-8"), provided_token.encode("utf-8")):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_INIT_TOKEN, message="Invalid or expired initialization token").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
admin_count = await get_local_provider().count_admin_users()
|
||||||
|
if admin_count > 0:
|
||||||
|
# Do NOT consume the token on this error path — consuming it here
|
||||||
|
# would allow an attacker to exhaust the token by calling with the
|
||||||
|
# correct token when admin already exists (denial-of-service).
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
|
||||||
|
except ValueError:
|
||||||
|
# DB unique-constraint race: another concurrent request beat us.
|
||||||
|
# Do NOT consume the token here for the same reason as above.
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Consume the token only after successful initialization — this is the
|
||||||
|
# single place where one-time use is enforced.
|
||||||
|
request.app.state.init_token = None
|
||||||
|
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
|
||||||
|
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||||
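Review note: the `/initialize` handler compares the token in constant time and consumes it only after the admin account has been created. The sketch below isolates that one-time-token logic. `InitTokenGate` is a hypothetical stand-in for `app.state.init_token`, and the token is generated with `secrets.token_urlsafe` as an assumption, since the diff does not show how the real token is produced. The comparison is done on encoded bytes because `secrets.compare_digest` accepts `str` arguments only when they are ASCII.

```python
from __future__ import annotations

import secrets


class InitTokenGate:
    """Hypothetical holder for a one-time initialization token."""

    def __init__(self) -> None:
        self.token: str | None = secrets.token_urlsafe(32)

    def check(self, provided: str | None) -> bool:
        """Constant-time comparison; a missing or spent token never matches."""
        if self.token is None:
            return False
        # Compare bytes so that non-ASCII input cannot raise TypeError.
        return secrets.compare_digest(
            self.token.encode("utf-8"), (provided or "").encode("utf-8")
        )

    def consume(self) -> None:
        """Call only after the admin account has actually been created."""
        self.token = None


gate = InitTokenGate()
valid = gate.token

checks_before = (gate.check(valid), gate.check("wrong"), gate.check(None), gate.check("é"))
gate.consume()
check_after = gate.check(valid)
```

Keeping `consume()` separate from `check()` mirrors the handler's choice to leave the token intact on the 409 paths.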
|
|||||||
@@ -30,11 +30,16 @@ class FeedbackCreateRequest(BaseModel):
|
|||||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackUpsertRequest(BaseModel):
|
||||||
|
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
||||||
|
comment: str | None = Field(default=None, description="Optional text feedback")
|
||||||
|
|
||||||
|
|
||||||
class FeedbackResponse(BaseModel):
|
class FeedbackResponse(BaseModel):
|
||||||
feedback_id: str
|
feedback_id: str
|
||||||
run_id: str
|
run_id: str
|
||||||
thread_id: str
|
thread_id: str
|
||||||
owner_id: str | None = None
|
user_id: str | None = None
|
||||||
message_id: str | None = None
|
message_id: str | None = None
|
||||||
rating: int
|
rating: int
|
||||||
comment: str | None = None
|
comment: str | None = None
|
||||||
@@ -53,6 +58,57 @@ class FeedbackStatsResponse(BaseModel):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||||
|
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||||
|
async def upsert_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
body: FeedbackUpsertRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create or update feedback for a run (idempotent)."""
|
||||||
|
if body.rating not in (1, -1):
|
||||||
|
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||||
|
|
||||||
|
user_id = await get_current_user(request)
|
||||||
|
|
||||||
|
run_store = get_run_store(request)
|
||||||
|
run = await run_store.get(run_id)
|
||||||
|
if run is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
if run.get("thread_id") != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||||
|
|
||||||
|
feedback_repo = get_feedback_repo(request)
|
||||||
|
return await feedback_repo.upsert(
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
rating=body.rating,
|
||||||
|
user_id=user_id,
|
||||||
|
comment=body.comment,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
||||||
|
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
||||||
|
async def delete_run_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete the current user's feedback for a run."""
|
||||||
|
user_id = await get_current_user(request)
|
||||||
|
feedback_repo = get_feedback_repo(request)
|
||||||
|
deleted = await feedback_repo.delete_by_run(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="No feedback found for this run")
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||||
async def create_feedback(
|
async def create_feedback(
|
||||||
@@ -80,7 +136,7 @@ async def create_feedback(
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
rating=body.rating,
|
rating=body.rating,
|
||||||
owner_id=user_id,
|
user_id=user_id,
|
||||||
message_id=body.message_id,
|
message_id=body.message_id,
|
||||||
comment=body.comment,
|
comment=body.comment,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from deerflow.agents.memory.updater import (
|
|||||||
update_memory_fact,
|
update_memory_fact,
|
||||||
)
|
)
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["memory"])
|
router = APIRouter(prefix="/api", tags=["memory"])
|
||||||
|
|
||||||
@@ -147,7 +148,7 @@ async def get_memory() -> MemoryResponse:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
memory_data = get_memory_data()
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@@ -167,7 +168,7 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
The reloaded memory data.
|
The reloaded memory data.
|
||||||
"""
|
"""
|
||||||
memory_data = reload_memory_data()
|
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@@ -181,7 +182,7 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
async def clear_memory() -> MemoryResponse:
|
async def clear_memory() -> MemoryResponse:
|
||||||
"""Clear all persisted memory data."""
|
"""Clear all persisted memory data."""
|
||||||
try:
|
try:
|
||||||
memory_data = clear_memory_data()
|
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||||
|
|
||||||
@@ -202,6 +203,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
|||||||
content=request.content,
|
content=request.content,
|
||||||
category=request.category,
|
category=request.category,
|
||||||
confidence=request.confidence,
|
confidence=request.confidence,
|
||||||
|
user_id=get_effective_user_id(),
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
@@ -221,7 +223,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
|||||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||||
"""Delete a single fact from memory by fact id."""
|
"""Delete a single fact from memory by fact id."""
|
||||||
try:
|
try:
|
||||||
memory_data = delete_memory_fact(fact_id)
|
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
@@ -245,6 +247,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
|||||||
content=request.content,
|
content=request.content,
|
||||||
category=request.category,
|
category=request.category,
|
||||||
confidence=request.confidence,
|
confidence=request.confidence,
|
||||||
|
user_id=get_effective_user_id(),
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
@@ -265,7 +268,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
|||||||
)
|
)
|
||||||
async def export_memory() -> MemoryResponse:
|
async def export_memory() -> MemoryResponse:
|
||||||
"""Export the current memory data."""
|
"""Export the current memory data."""
|
||||||
memory_data = get_memory_data()
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@@ -279,7 +282,7 @@ async def export_memory() -> MemoryResponse:
|
|||||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||||
"""Import and persist memory data."""
|
"""Import and persist memory data."""
|
||||||
try:
|
try:
|
||||||
memory_data = import_memory_data(request.model_dump())
|
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||||
|
|
||||||
@@ -337,7 +340,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
|||||||
Combined memory configuration and current data.
|
Combined memory configuration and current data.
|
||||||
"""
|
"""
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
memory_data = get_memory_data()
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
return MemoryStatusResponse(
|
return MemoryStatusResponse(
|
||||||
config=MemoryConfigResponse(
|
config=MemoryConfigResponse(
|
||||||
|
|||||||
@@ -11,10 +11,11 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, HTTPException, Query, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
from app.gateway.authz import require_permission
|
||||||
|
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||||
from app.gateway.services import sse_consumer, start_run
|
from app.gateway.services import sse_consumer, start_run
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
@@ -85,3 +86,57 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
|||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
return {"status": record.status.value, "error": record.error}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Run-scoped read endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_run(run_id: str, request: Request) -> dict:
|
||||||
|
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
|
||||||
|
run_store = get_run_store(request)
|
||||||
|
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{run_id}/messages")
|
||||||
|
@require_permission("runs", "read")
|
||||||
|
async def run_messages(
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
limit: int = Query(default=50, le=200, ge=1),
|
||||||
|
before_seq: int | None = Query(default=None),
|
||||||
|
after_seq: int | None = Query(default=None),
|
||||||
|
) -> dict:
|
||||||
|
"""Return paginated messages for a run (cursor-based).
|
||||||
|
|
||||||
|
Pagination:
|
||||||
|
- after_seq: messages with seq > after_seq (forward)
|
||||||
|
- before_seq: messages with seq < before_seq (backward)
|
||||||
|
- neither: latest messages
|
||||||
|
|
||||||
|
Response: { data: [...], has_more: bool }
|
||||||
|
"""
|
||||||
|
run = await _resolve_run(run_id, request)
|
||||||
|
event_store = get_run_event_store(request)
|
||||||
|
rows = await event_store.list_messages_by_run(
|
||||||
|
run["thread_id"], run_id,
|
||||||
|
limit=limit + 1,
|
||||||
|
before_seq=before_seq,
|
||||||
|
after_seq=after_seq,
|
||||||
|
)
|
||||||
|
has_more = len(rows) > limit
|
||||||
|
data = rows[:limit] if has_more else rows
|
||||||
|
return {"data": data, "has_more": has_more}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{run_id}/feedback")
|
||||||
|
@require_permission("runs", "read")
|
||||||
|
async def run_feedback(run_id: str, request: Request) -> list[dict]:
|
||||||
|
"""Return all feedback for a run."""
|
||||||
|
run = await _resolve_run(run_id, request)
|
||||||
|
feedback_repo = get_feedback_repo(request)
|
||||||
|
return await feedback_repo.list_by_run(run["thread_id"], run_id)
|
||||||
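Review note: `run_messages` asks the store for `limit + 1` rows and uses the extra row only to compute `has_more`. The sketch below shows that technique against an in-memory list. `fetch_after` is a hypothetical, forward-only stand-in for `list_messages_by_run`; the real store's ordering and `before_seq` handling are not visible in this diff.

```python
def fetch_after(events: list, after_seq: int, limit: int) -> list:
    """Hypothetical store call: rows with seq > after_seq, oldest first."""
    matching = [e for e in events if e["seq"] > after_seq]
    return matching[:limit]


def paginate(rows: list, limit: int) -> dict:
    """Trim the look-ahead row and report whether more data exists."""
    has_more = len(rows) > limit
    return {"data": rows[:limit], "has_more": has_more}


events = [{"seq": i} for i in range(1, 6)]  # seq 1..5
limit = 2

# Fetch one row more than requested, then trim it off.
first_page = paginate(fetch_after(events, 0, limit + 1), limit)
last_page = paginate(fetch_after(events, 4, limit + 1), limit)
```

The look-ahead row avoids a separate count query, at the cost of reading one extra row per request.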
|
|||||||
@@ -20,7 +20,7 @@ from fastapi.responses import Response, StreamingResponse
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||||
from app.gateway.services import sse_consumer, start_run
|
from app.gateway.services import sse_consumer, start_run
|
||||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||||
|
|
||||||
@@ -291,17 +291,62 @@ async def list_thread_messages(
|
|||||||
before_seq: int | None = Query(default=None),
|
before_seq: int | None = Query(default=None),
|
||||||
after_seq: int | None = Query(default=None),
|
after_seq: int | None = Query(default=None),
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Return displayable messages for a thread (across all runs)."""
|
"""Return displayable messages for a thread (across all runs), with feedback attached."""
|
||||||
event_store = get_run_event_store(request)
|
event_store = get_run_event_store(request)
|
||||||
return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||||
|
|
||||||
|
# Attach feedback to the last AI message of each run
|
||||||
|
feedback_repo = get_feedback_repo(request)
|
||||||
|
user_id = await get_current_user(request)
|
||||||
|
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
||||||
|
|
||||||
|
# Find the last ai_message per run_id
|
||||||
|
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
if msg.get("event_type") == "ai_message":
|
||||||
|
last_ai_per_run[msg["run_id"]] = i
|
||||||
|
|
||||||
|
# Attach feedback field
|
||||||
|
last_ai_indices = set(last_ai_per_run.values())
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
if i in last_ai_indices:
|
||||||
|
run_id = msg["run_id"]
|
||||||
|
fb = feedback_map.get(run_id)
|
||||||
|
msg["feedback"] = {
|
||||||
|
"feedback_id": fb["feedback_id"],
|
||||||
|
"rating": fb["rating"],
|
||||||
|
"comment": fb.get("comment"),
|
||||||
|
} if fb else None
|
||||||
|
else:
|
||||||
|
msg["feedback"] = None
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/messages")
|
@router.get("/{thread_id}/runs/{run_id}/messages")
|
||||||
@require_permission("runs", "read", owner_check=True)
|
@require_permission("runs", "read", owner_check=True)
|
||||||
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
|
async def list_run_messages(
|
||||||
"""Return displayable messages for a specific run."""
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
limit: int = Query(default=50, le=200, ge=1),
|
||||||
|
before_seq: int | None = Query(default=None),
|
||||||
|
after_seq: int | None = Query(default=None),
|
||||||
|
) -> dict:
|
||||||
|
"""Return paginated messages for a specific run.
|
||||||
|
|
||||||
|
Response: { data: [...], has_more: bool }
|
||||||
|
"""
|
||||||
event_store = get_run_event_store(request)
|
event_store = get_run_event_store(request)
|
||||||
return await event_store.list_messages_by_run(thread_id, run_id)
|
rows = await event_store.list_messages_by_run(
|
||||||
|
thread_id, run_id,
|
||||||
|
limit=limit + 1,
|
||||||
|
before_seq=before_seq,
|
||||||
|
after_seq=after_seq,
|
||||||
|
)
|
||||||
|
has_more = len(rows) > limit
|
||||||
|
data = rows[:limit] if has_more else rows
|
||||||
|
return {"data": data, "has_more": has_more}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/events")
|
@router.get("/{thread_id}/runs/{run_id}/events")
|
||||||
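Review note: `list_thread_messages` attaches feedback only to the last `ai_message` of each run and sets `feedback` to `None` everywhere else. The sketch below reproduces that two-pass logic on plain dicts. `attach_feedback` is a hypothetical helper, and unlike the handler it returns copies rather than mutating the input messages.

```python
def attach_feedback(messages: list, feedback_by_run: dict) -> list:
    """Return copies of messages with a ``feedback`` field on each one."""
    # Pass 1: remember the index of the last ai_message for every run.
    last_ai = {}
    for i, msg in enumerate(messages):
        if msg.get("event_type") == "ai_message":
            last_ai[msg["run_id"]] = i

    # Pass 2: only those indices can carry feedback; all others get None.
    last_indices = set(last_ai.values())
    result = []
    for i, msg in enumerate(messages):
        fb = feedback_by_run.get(msg["run_id"]) if i in last_indices else None
        result.append({**msg, "feedback": fb})
    return result


messages = [
    {"run_id": "r1", "event_type": "human_message"},
    {"run_id": "r1", "event_type": "ai_message"},
    {"run_id": "r1", "event_type": "ai_message"},
    {"run_id": "r2", "event_type": "ai_message"},
]
feedback_by_run = {"r1": {"feedback_id": "f1", "rating": 1, "comment": None}}

annotated = attach_feedback(messages, feedback_by_run)
```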
|
|||||||
@@ -13,6 +13,7 @@ matching the LangGraph Platform wire format expected by the
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -21,10 +22,11 @@ from fastapi import APIRouter, HTTPException, Request
|
|||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer
|
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||||
@@ -34,7 +36,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"])
|
|||||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
||||||
# inbound model below so a malicious client cannot reflect a forged
|
# inbound model below so a malicious client cannot reflect a forged
|
||||||
# owner identity through the API surface. Defense-in-depth — the
|
# owner identity through the API surface. Defense-in-depth — the
|
||||||
# row-level invariant is still ``threads_meta.owner_id`` populated from
|
# row-level invariant is still ``threads_meta.user_id`` populated from
|
||||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
# the auth contextvar; this list closes the metadata-blob echo gap.
|
||||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
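Review note: the comment above says reserved keys are stripped from every inbound model by a Pydantic validator. The validator itself is not shown in this hunk, so the following is only a sketch of the stripping step as a plain function; `strip_reserved` is a hypothetical name and the real `_strip_reserved` validators may differ:

```python
from __future__ import annotations

_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})


def strip_reserved(metadata: dict | None) -> dict:
    """Drop server-reserved keys from client-supplied metadata (hypothetical helper)."""
    if not metadata:
        return {}
    # Build a new dict so the caller's object is left untouched.
    return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}


cleaned = strip_reserved({"title": "demo", "user_id": "forged", "owner_id": "forged"})
```

With this shape, a client that sends a forged `user_id` in the metadata blob gets it silently dropped rather than echoed back.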
|
||||||
|
|
||||||
@@ -142,11 +144,11 @@ class ThreadHistoryRequest(BaseModel):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
|
||||||
"""Delete local persisted filesystem data for a thread."""
|
"""Delete local persisted filesystem data for a thread."""
|
||||||
path_manager = paths or get_paths()
|
path_manager = paths or get_paths()
|
||||||
try:
|
try:
|
||||||
path_manager.delete_thread_dir(thread_id)
|
path_manager.delete_thread_dir(thread_id, user_id=user_id)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -194,10 +196,10 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
|||||||
and removes the thread_meta row from the configured ThreadMetaStore
|
and removes the thread_meta row from the configured ThreadMetaStore
|
||||||
(sqlite or memory).
|
(sqlite or memory).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
# Clean local filesystem
|
# Clean local filesystem
|
||||||
response = _delete_thread_data(thread_id)
|
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
|
||||||
|
|
||||||
# Remove checkpoints (best-effort)
|
# Remove checkpoints (best-effort)
|
||||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||||
@@ -211,8 +213,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
|||||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
# Remove thread_meta row (best-effort) — required for sqlite backend
|
||||||
# so the deleted thread no longer appears in /threads/search.
|
# so the deleted thread no longer appears in /threads/search.
|
||||||
try:
|
try:
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
await thread_meta_repo.delete(thread_id)
|
await thread_store.delete(thread_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -227,17 +229,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
and an empty checkpoint (so state endpoints work immediately).
|
and an empty checkpoint (so state endpoints work immediately).
|
||||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = time.time()
|
now = time.time()
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
# Idempotency: return existing record when already present
|
||||||
existing_record = await thread_meta_repo.get(thread_id)
|
existing_record = await thread_store.get(thread_id)
|
||||||
if existing_record is not None:
|
if existing_record is not None:
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -249,7 +251,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
|
|
||||||
# Write thread_meta so the thread appears in /threads/search immediately
|
# Write thread_meta so the thread appears in /threads/search immediately
|
||||||
try:
|
try:
|
||||||
await thread_meta_repo.create(
|
await thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
assistant_id=getattr(body, "assistant_id", None),
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
@@ -293,9 +295,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
|||||||
Delegates to the configured ThreadMetaStore implementation
|
Delegates to the configured ThreadMetaStore implementation
|
||||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
repo = get_thread_meta_repo(request)
|
repo = get_thread_store(request)
|
||||||
rows = await repo.search(
|
rows = await repo.search(
|
||||||
metadata=body.metadata or None,
|
metadata=body.metadata or None,
|
||||||
status=body.status,
|
status=body.status,
|
||||||
@@ -320,22 +322,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
|||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||||
"""Merge metadata into a thread record."""
|
"""Merge metadata into a thread record."""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
record = await thread_meta_repo.get(thread_id)
|
record = await thread_store.get(thread_id)
|
||||||
if record is None:
|
if record is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
||||||
try:
|
try:
|
||||||
await thread_meta_repo.update_metadata(thread_id, body.metadata)
|
await thread_store.update_metadata(thread_id, body.metadata)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
||||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||||
|
|
||||||
# Re-read to get the merged metadata + refreshed updated_at
|
# Re-read to get the merged metadata + refreshed updated_at
|
||||||
record = await thread_meta_repo.get(thread_id) or record
|
record = await thread_store.get(thread_id) or record
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
status=record.get("status", "idle"),
|
status=record.get("status", "idle"),
|
||||||
@@ -354,12 +356,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
execution status from the checkpointer. Falls back to the checkpointer
|
execution status from the checkpointer. Falls back to the checkpointer
|
||||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
|
|
||||||
record: dict | None = await thread_meta_repo.get(thread_id)
|
record: dict | None = await thread_store.get(thread_id)
|
||||||
|
|
||||||
# Derive accurate status from the checkpointer
|
# Derive accurate status from the checkpointer
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
@@ -402,6 +404,165 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Event-store-backed message loader
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_LEGACY_CMD_INNER_CONTENT_RE = re.compile(
|
||||||
|
r"ToolMessage\(content=(?P<q>['\"])(?P<inner>.*?)(?P=q)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_legacy_command_repr(content_field: Any) -> Any:
|
||||||
|
"""Recover the inner ToolMessage text from a legacy ``str(Command(...))`` repr.
|
||||||
|
|
||||||
|
Runs captured before the ``on_tool_end`` fix in ``journal.py`` stored
|
||||||
|
``str(Command(update={'messages':[ToolMessage(content='X', ...)]}))`` as the
|
||||||
|
tool_result content. New runs store ``'X'`` directly. For legacy rows, try
|
||||||
|
to extract ``'X'`` defensively; return the original string if extraction
|
||||||
|
fails (still no worse than the checkpoint fallback for summarized threads).
|
||||||
|
"""
|
||||||
|
if not isinstance(content_field, str) or not content_field.startswith("Command(update="):
|
||||||
|
return content_field
|
||||||
|
match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field)
|
||||||
|
return match.group("inner") if match else content_field
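Review note: a standalone sketch of how the pattern above behaves on a legacy row. The regex and the function body are copied from this hunk; the sample strings are made up for illustration and only approximate what `str(Command(...))` produced:

```python
import re
from typing import Any

_LEGACY_CMD_INNER_CONTENT_RE = re.compile(
    r"ToolMessage\(content=(?P<q>['\"])(?P<inner>.*?)(?P=q)",
    re.DOTALL,
)


def sanitize_legacy_command_repr(content_field: Any) -> Any:
    """Return the inner ToolMessage text, or the input unchanged if it is not a legacy repr."""
    if not isinstance(content_field, str) or not content_field.startswith("Command(update="):
        return content_field
    match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field)
    return match.group("inner") if match else content_field


# Illustrative legacy row (made-up sample).
legacy = "Command(update={'messages': [ToolMessage(content='done', tool_call_id='t1')]})"
extracted = sanitize_legacy_command_repr(legacy)

# New-style rows and non-string content pass through untouched.
passthrough = sanitize_legacy_command_repr("done")
```

One limitation worth keeping in mind: the lazy `.*?` stops at the first matching quote character, so content whose repr contains an escaped quote of the same kind would be truncated at that point.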
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_event_store_messages(request: Request, thread_id: str) -> list[dict] | None:
|
||||||
|
"""Load the full message stream for ``thread_id`` from the event store.
|
||||||
|
|
||||||
|
The event store is append-only and unaffected by summarization — the
|
||||||
|
checkpoint's ``channel_values["messages"]`` is rewritten in-place when the
|
||||||
|
SummarizationMiddleware runs, which drops all pre-summarize messages. The
|
||||||
|
event store retains the full transcript, so callers in Gateway mode should
|
||||||
|
prefer it for rendering the conversation history.
|
||||||
|
|
||||||
|
In addition to the core message content, this helper attaches two extra
|
||||||
|
fields to every returned dict:
|
||||||
|
|
||||||
|
- ``run_id``: the ``run_id`` of the event that produced this message.
|
||||||
|
Always present.
|
||||||
|
- ``feedback``: thumbs-up/down data. Present only on the **final
|
||||||
|
``ai_message`` of each run** (matching the per-run feedback semantics
|
||||||
|
of ``POST /api/threads/{id}/runs/{run_id}/feedback``). The frontend uses
|
||||||
|
the presence of this field to decide whether to render the feedback
|
||||||
|
button, which sidesteps the positional-index mapping bug that an
|
||||||
|
out-of-band ``/messages`` fetch exhibited.
|
||||||
|
|
||||||
|
Behaviour contract:
|
||||||
|
|
||||||
|
- **Full pagination.** ``RunEventStore.list_messages`` returns the newest
|
||||||
|
``limit`` records when no cursor is given, so a fixed limit silently
|
||||||
|
drops older messages on long threads. We use ``count_messages()`` only to
|
||||||
|
detect an empty thread, then page forward in ``page_size`` batches with ``after_seq`` cursors.
|
||||||
|
- **Copy-on-read.** Each content dict is copied before ``id`` is patched
|
||||||
|
so the live store object is never mutated; ``MemoryRunEventStore``
|
||||||
|
returns live references.
|
||||||
|
- **Stable ids.** Messages with ``id=None`` (human + tool_result) receive
|
||||||
|
a deterministic ``uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")`` so React
|
||||||
|
keys are stable across requests without altering stored data. AI messages
|
||||||
|
retain their LLM-assigned ``lc_run--*`` ids.
|
||||||
|
- **Legacy Command repr.** Rows captured before the ``journal.py``
|
||||||
|
``on_tool_end`` fix stored ``str(Command(update={...}))`` as the tool
|
||||||
|
result content. ``_sanitize_legacy_command_repr`` extracts the inner
|
||||||
|
ToolMessage text.
|
||||||
|
- **User context.** ``DbRunEventStore`` is user-scoped by default via
|
||||||
|
``resolve_user_id(AUTO)`` in ``runtime/user_context.py``. This helper
|
||||||
|
must run inside a request where ``@require_permission`` has populated
|
||||||
|
the user contextvar. Both callers below are decorated appropriately.
|
||||||
|
Do not call this helper from CLI or migration scripts without passing
|
||||||
|
``user_id=None`` explicitly to the underlying store methods.
|
||||||
|
|
||||||
|
Returns ``None`` when the event store is not configured or has no message
|
||||||
|
events for this thread, so callers fall back to checkpoint messages.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
event_store = get_run_event_store(request)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
total = await event_store.count_messages(thread_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("count_messages failed for thread %s", sanitize_log_param(thread_id))
|
||||||
|
return None
|
||||||
|
if not total:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Batch by page_size to keep memory bounded for very long threads.
|
||||||
|
page_size = 500
|
||||||
|
collected: list[dict] = []
|
||||||
|
after_seq: int | None = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
page = await event_store.list_messages(thread_id, limit=page_size, after_seq=after_seq)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("list_messages failed for thread %s", sanitize_log_param(thread_id))
|
||||||
|
return None
|
||||||
|
if not page:
|
||||||
|
break
|
||||||
|
collected.extend(page)
|
||||||
|
if len(page) < page_size:
|
||||||
|
break
|
||||||
|
next_cursor = page[-1].get("seq")
|
||||||
|
if next_cursor is None or (after_seq is not None and next_cursor <= after_seq):
|
||||||
|
break
|
||||||
|
after_seq = next_cursor
|
||||||
|
|
||||||
|
# Build the message list; track the final ``ai_message`` index per run so
|
||||||
|
# feedback can be attached at the right position (matches thread_runs.py).
|
||||||
|
messages: list[dict] = []
|
||||||
|
last_ai_per_run: dict[str, int] = {}
|
||||||
|
for evt in collected:
|
||||||
|
raw = evt.get("content")
|
||||||
|
if not isinstance(raw, dict) or "type" not in raw:
|
||||||
|
continue
|
||||||
|
content = dict(raw)
|
||||||
|
if content.get("id") is None:
|
||||||
|
content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{thread_id}:{evt['seq']}"))
|
||||||
|
if content.get("type") == "tool":
|
||||||
|
content["content"] = _sanitize_legacy_command_repr(content.get("content"))
|
||||||
|
run_id = evt.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
content["run_id"] = run_id
|
||||||
|
if evt.get("event_type") == "ai_message" and run_id:
|
||||||
|
last_ai_per_run[run_id] = len(messages)
|
||||||
|
messages.append(content)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Attach feedback to the final ai_message of each run. If the feedback
|
||||||
|
# subsystem is unavailable, leave the ``feedback`` field absent entirely
|
||||||
|
# so the frontend hides the button rather than showing it over a broken
|
||||||
|
# write path.
|
||||||
|
feedback_available = False
|
||||||
|
feedback_map: dict[str, dict] = {}
|
||||||
|
try:
|
||||||
|
feedback_repo = get_feedback_repo(request)
|
||||||
|
user_id = await get_current_user(request)
|
||||||
|
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
||||||
|
feedback_available = True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("feedback lookup failed for thread %s", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
if feedback_available:
|
||||||
|
for run_id, idx in last_ai_per_run.items():
|
||||||
|
fb = feedback_map.get(run_id)
|
||||||
|
messages[idx]["feedback"] = (
|
||||||
|
{
|
||||||
|
"feedback_id": fb["feedback_id"],
|
||||||
|
"rating": fb["rating"],
|
||||||
|
"comment": fb.get("comment"),
|
||||||
|
}
|
||||||
|
if fb
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return messages
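Review note: a minimal sketch of the forward cursor pagination used in `_get_event_store_messages`, run against a hypothetical in-memory store. `FakeEventStore`, `load_all`, and the small `page_size` are assumptions for illustration. The fake store follows the docstring's description that a call without a cursor returns the newest `limit` rows; if the real store behaves that way, starting the loop with `after_seq=None` could skip older rows on threads longer than one page, so this sketch starts from an explicit cursor of `0`, assuming `seq` values are positive:

```python
from __future__ import annotations

import asyncio


class FakeEventStore:
    """Hypothetical stand-in for the event store; rows are ordered by ``seq``."""

    def __init__(self, count: int) -> None:
        self._rows = [{"seq": i, "content": {"type": "human"}} for i in range(1, count + 1)]

    async def list_messages(self, thread_id: str, *, limit: int, after_seq: int | None = None) -> list[dict]:
        # Assumed semantics: no cursor -> newest ``limit`` rows;
        # with a cursor -> the next ``limit`` rows after it, oldest first.
        if after_seq is None:
            return self._rows[-limit:]
        return [r for r in self._rows if r["seq"] > after_seq][:limit]


async def load_all(store: FakeEventStore, thread_id: str, page_size: int = 3) -> list[dict]:
    collected: list[dict] = []
    after_seq = 0  # assumption: seq values are positive, so 0 means "from the start"
    while True:
        page = await store.list_messages(thread_id, limit=page_size, after_seq=after_seq)
        if not page:
            break
        collected.extend(page)
        if len(page) < page_size:
            break  # a short page means the stream is exhausted
        next_cursor = page[-1]["seq"]
        if next_cursor <= after_seq:
            break  # guard against a cursor that fails to advance
        after_seq = next_cursor
    return collected


messages = asyncio.run(load_all(FakeEventStore(8), "thread-1"))
```

The non-advancing-cursor guard mirrors the one in the hunk; it keeps a misbehaving store from turning the loop into an infinite one.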
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||||
@@ -440,8 +601,15 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
|||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||||
|
|
||||||
|
values = serialize_channel_values(channel_values)
|
||||||
|
|
||||||
|
# Prefer event-store messages: append-only, immune to summarization.
|
||||||
|
es_messages = await _get_event_store_messages(request, thread_id)
|
||||||
|
if es_messages is not None:
|
||||||
|
values["messages"] = es_messages
|
||||||
|
|
||||||
return ThreadStateResponse(
|
return ThreadStateResponse(
|
||||||
values=serialize_channel_values(channel_values),
|
values=values,
|
||||||
next=next_tasks,
|
next=next_tasks,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||||
@@ -462,10 +630,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
||||||
change immediately in both sqlite and memory backends.
|
change immediately in both sqlite and memory backends.
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
|
|
||||||
# checkpoint_ns must be present in the config for aput — default to ""
|
# checkpoint_ns must be present in the config for aput — default to ""
|
||||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||||
@@ -529,7 +697,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
new_title = body.values["title"]
|
new_title = body.values["title"]
|
||||||
if new_title: # Skip empty strings and None
|
if new_title: # Skip empty strings and None
|
||||||
try:
|
try:
|
||||||
await thread_meta_repo.update_display_name(thread_id, new_title)
|
await thread_store.update_display_name(thread_id, new_title)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -559,6 +727,11 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
if body.before:
|
if body.before:
|
||||||
config["configurable"]["checkpoint_id"] = body.before
|
config["configurable"]["checkpoint_id"] = body.before
|
||||||
|
|
||||||
|
# Load the full event-store message stream once; attach to the latest
|
||||||
|
# checkpoint entry only (matching the prior semantics). The event store
|
||||||
|
# is append-only and immune to summarization.
|
||||||
|
es_messages = await _get_event_store_messages(request, thread_id)
|
||||||
|
|
||||||
entries: list[HistoryEntry] = []
|
entries: list[HistoryEntry] = []
|
||||||
is_latest_checkpoint = True
|
is_latest_checkpoint = True
|
||||||
try:
|
try:
|
||||||
@@ -582,11 +755,17 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
if thread_data := channel_values.get("thread_data"):
|
if thread_data := channel_values.get("thread_data"):
|
||||||
values["thread_data"] = thread_data
|
values["thread_data"] = thread_data
|
||||||
|
|
||||||
# Attach messages from checkpointer only for the latest checkpoint
|
# Attach messages only to the latest checkpoint. Prefer the
|
||||||
|
# event-store stream (complete and unaffected by summarization);
|
||||||
|
# fall back to checkpoint channel_values when the event store is
|
||||||
|
# unavailable or empty.
|
||||||
if is_latest_checkpoint:
|
if is_latest_checkpoint:
|
||||||
messages = channel_values.get("messages")
|
if es_messages is not None:
|
||||||
if messages:
|
values["messages"] = es_messages
|
||||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
else:
|
||||||
|
messages = channel_values.get("messages")
|
||||||
|
if messages:
|
||||||
|
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||||
is_latest_checkpoint = False
|
is_latest_checkpoint = False
|
||||||
|
|
||||||
# Derive next tasks
|
# Derive next tasks
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
PathTraversalError,
|
PathTraversalError,
|
||||||
@@ -69,7 +70,7 @@ async def upload_files(
|
|||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
uploaded_files = []
|
uploaded_files = []
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
sandbox_provider = get_sandbox_provider()
|
||||||
@@ -147,7 +148,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
|||||||
enrich_file_listing(result, thread_id)
|
enrich_file_listing(result, thread_id)
|
||||||
|
|
||||||
# Gateway additionally includes the sandbox-relative path.
|
# Gateway additionally includes the sandbox-relative path.
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
for f in result["files"]:
|
for f in result["files"]:
|
||||||
f["path"] = str(sandbox_uploads / f["filename"])
|
f["path"] = str(sandbox_uploads / f["filename"])
|
||||||
|
|
||||||
|
|||||||
@@ -229,15 +229,15 @@ async def start_run(
|
|||||||
# even for threads that were never explicitly created via POST /threads
|
# even for threads that were never explicitly created via POST /threads
|
||||||
# (e.g. stateless runs).
|
# (e.g. stateless runs).
|
||||||
try:
|
try:
|
||||||
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
await run_ctx.thread_meta_repo.create(
|
await run_ctx.thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=body.assistant_id,
|
assistant_id=body.assistant_id,
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -285,7 +285,7 @@ async def start_run(
|
|||||||
record.task = task
|
record.task = task
|
||||||
|
|
||||||
# Title sync is handled by worker.py's finally block which reads the
|
# Title sync is handled by worker.py's finally block which reads the
|
||||||
# title from the checkpoint and calls thread_meta_repo.update_display_name
|
# title from the checkpoint and calls thread_store.update_display_name
|
||||||
# after the run completes.
|
# after the run completes.
|
||||||
|
|
||||||
return record
|
return record
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ title:
|
|||||||
# checkpointer.py
|
# checkpointer.py
|
||||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
|
||||||
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
|
checkpointer = SqliteSaver.from_conn_string("deerflow.db")
|
||||||
```
|
```
|
||||||
|
|
||||||
```json
|
```json
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
|
|
||||||
from .factory import create_deerflow_agent
|
from .factory import create_deerflow_agent
|
||||||
from .features import Next, Prev, RuntimeFeatures
|
from .features import Next, Prev, RuntimeFeatures
|
||||||
from .lead_agent import make_lead_agent
|
from .lead_agent import make_lead_agent
|
||||||
@@ -18,7 +17,4 @@ __all__ = [
|
|||||||
"make_lead_agent",
|
"make_lead_agent",
|
||||||
"SandboxState",
|
"SandboxState",
|
||||||
"ThreadState",
|
"ThreadState",
|
||||||
"get_checkpointer",
|
|
||||||
"reset_checkpointer",
|
|
||||||
"make_checkpointer",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -164,30 +164,6 @@ Skip simple one-off tasks.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _skill_mutability_label(category: str) -> str:
|
|
||||||
return "[custom, editable]" if category == "custom" else "[built-in]"
|
|
||||||
|
|
||||||
|
|
||||||
def clear_skills_system_prompt_cache() -> None:
|
|
||||||
_get_cached_skills_prompt_section.cache_clear()
|
|
||||||
|
|
||||||
|
|
||||||
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
|
|
||||||
if not skill_evolution_enabled:
|
|
||||||
return ""
|
|
||||||
return """
|
|
||||||
## Skill Self-Evolution
|
|
||||||
After completing a task, consider creating or updating a skill when:
|
|
||||||
- The task required 5+ tool calls to resolve
|
|
||||||
- You overcame non-obvious errors or pitfalls
|
|
||||||
- The user corrected your approach and the corrected version worked
|
|
||||||
- You discovered a non-trivial, recurring workflow
|
|
||||||
If you used a skill and encountered issues not covered by it, patch it immediately.
|
|
||||||
Prefer patch over edit. Before creating a new skill, confirm with the user first.
|
|
||||||
Skip simple one-off tasks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_subagent_section(max_concurrent: int) -> str:
|
def _build_subagent_section(max_concurrent: int) -> str:
|
||||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||||
|
|
||||||
@@ -543,12 +519,13 @@ def _get_memory_context(agent_name: str | None = None) -> str:
|
|||||||
try:
|
try:
|
||||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
if not config.enabled or not config.injection_enabled:
|
if not config.enabled or not config.injection_enabled:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
memory_data = get_memory_data(agent_name)
|
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||||
|
|
||||||
if not memory_content.strip():
|
if not memory_content.strip():
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class ConversationContext:
|
|||||||
messages: list[Any]
|
messages: list[Any]
|
||||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
agent_name: str | None = None
|
agent_name: str | None = None
|
||||||
|
user_id: str | None = None
|
||||||
correction_detected: bool = False
|
correction_detected: bool = False
|
||||||
reinforcement_detected: bool = False
|
reinforcement_detected: bool = False
|
||||||
|
|
||||||
@@ -44,6 +45,7 @@ class MemoryUpdateQueue:
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
messages: list[Any],
|
messages: list[Any],
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -53,6 +55,9 @@ class MemoryUpdateQueue:
|
|||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
messages: The conversation messages.
|
messages: The conversation messages.
|
||||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||||
|
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
|
||||||
|
survives the threading.Timer boundary (ContextVar does not propagate across
|
||||||
|
raw threads).
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
"""
|
"""
|
||||||
@@ -71,6 +76,7 @@ class MemoryUpdateQueue:
|
|||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
|
user_id=user_id,
|
||||||
correction_detected=merged_correction_detected,
|
correction_detected=merged_correction_detected,
|
||||||
reinforcement_detected=merged_reinforcement_detected,
|
reinforcement_detected=merged_reinforcement_detected,
|
||||||
)
|
)
|
||||||
@@ -136,6 +142,7 @@ class MemoryUpdateQueue:
|
|||||||
agent_name=context.agent_name,
|
agent_name=context.agent_name,
|
||||||
correction_detected=context.correction_detected,
|
correction_detected=context.correction_detected,
|
||||||
reinforcement_detected=context.reinforcement_detected,
|
reinforcement_detected=context.reinforcement_detected,
|
||||||
|
user_id=context.user_id,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||||
|
|||||||
@@ -43,17 +43,17 @@ class MemoryStorage(abc.ABC):
|
|||||||
"""Abstract base class for memory storage providers."""
|
"""Abstract base class for memory storage providers."""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data for the given agent."""
|
"""Load memory data for the given agent."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Force reload memory data for the given agent."""
|
"""Force reload memory data for the given agent."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||||
"""Save memory data for the given agent."""
|
"""Save memory data for the given agent."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -63,9 +63,9 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the file memory storage."""
|
"""Initialize the file memory storage."""
|
||||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
|
||||||
# Value: (memory_data, file_mtime)
|
# Value: (memory_data, file_mtime)
|
||||||
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
|
||||||
|
|
||||||
def _validate_agent_name(self, agent_name: str) -> None:
|
def _validate_agent_name(self, agent_name: str) -> None:
|
||||||
"""Validate that the agent name is safe to use in filesystem paths.
|
"""Validate that the agent name is safe to use in filesystem paths.
|
||||||
@@ -78,21 +78,29 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
if not AGENT_NAME_PATTERN.match(agent_name):
|
if not AGENT_NAME_PATTERN.match(agent_name):
|
||||||
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
||||||
|
|
||||||
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
|
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
|
||||||
"""Get the path to the memory file."""
|
"""Get the path to the memory file."""
|
||||||
|
if user_id is not None:
|
||||||
|
if agent_name is not None:
|
||||||
|
self._validate_agent_name(agent_name)
|
||||||
|
return get_paths().user_agent_memory_file(user_id, agent_name)
|
||||||
|
config = get_memory_config()
|
||||||
|
if config.storage_path and Path(config.storage_path).is_absolute():
|
||||||
|
return Path(config.storage_path)
|
||||||
|
return get_paths().user_memory_file(user_id)
|
||||||
|
# Legacy: no user_id
|
||||||
if agent_name is not None:
|
if agent_name is not None:
|
||||||
self._validate_agent_name(agent_name)
|
self._validate_agent_name(agent_name)
|
||||||
return get_paths().agent_memory_file(agent_name)
|
return get_paths().agent_memory_file(agent_name)
|
||||||
|
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
if config.storage_path:
|
if config.storage_path:
|
||||||
p = Path(config.storage_path)
|
p = Path(config.storage_path)
|
||||||
return p if p.is_absolute() else get_paths().base_dir / p
|
return p if p.is_absolute() else get_paths().base_dir / p
|
||||||
return get_paths().memory_file
|
return get_paths().memory_file
|
||||||
|
|
||||||
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
|
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data from file."""
|
"""Load memory data from file."""
|
||||||
file_path = self._get_memory_file_path(agent_name)
|
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||||
|
|
||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
@@ -105,40 +113,42 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
logger.warning("Failed to load memory file: %s", e)
|
logger.warning("Failed to load memory file: %s", e)
|
||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
|
|
||||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data (cached with file modification time check)."""
|
"""Load memory data (cached with file modification time check)."""
|
||||||
file_path = self._get_memory_file_path(agent_name)
|
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||||
except OSError:
|
except OSError:
|
||||||
current_mtime = None
|
current_mtime = None
|
||||||
|
|
||||||
cached = self._memory_cache.get(agent_name)
|
cache_key = (user_id, agent_name)
|
||||||
|
cached = self._memory_cache.get(cache_key)
|
||||||
|
|
||||||
if cached is None or cached[1] != current_mtime:
|
if cached is None or cached[1] != current_mtime:
|
||||||
memory_data = self._load_memory_from_file(agent_name)
|
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||||
self._memory_cache[agent_name] = (memory_data, current_mtime)
|
self._memory_cache[cache_key] = (memory_data, current_mtime)
|
||||||
return memory_data
|
return memory_data
|
||||||
|
|
||||||
return cached[0]
|
return cached[0]
|
||||||
|
|
||||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Reload memory data from file, forcing cache invalidation."""
|
"""Reload memory data from file, forcing cache invalidation."""
|
||||||
file_path = self._get_memory_file_path(agent_name)
|
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||||
memory_data = self._load_memory_from_file(agent_name)
|
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||||
except OSError:
|
except OSError:
|
||||||
mtime = None
|
mtime = None
|
||||||
|
|
||||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
cache_key = (user_id, agent_name)
|
||||||
|
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||||
return memory_data
|
return memory_data
|
||||||
|
|
||||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||||
"""Save memory data to file and update cache."""
|
"""Save memory data to file and update cache."""
|
||||||
file_path = self._get_memory_file_path(agent_name)
|
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -155,7 +165,8 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
except OSError:
|
except OSError:
|
||||||
mtime = None
|
mtime = None
|
||||||
|
|
||||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
cache_key = (user_id, agent_name)
|
||||||
|
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||||
logger.info("Memory saved to %s", file_path)
|
logger.info("Memory saved to %s", file_path)
|
||||||
return True
|
return True
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
|
|||||||
@@ -27,27 +27,28 @@ def _create_empty_memory() -> dict[str, Any]:
|
|||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
|
|
||||||
|
|
||||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||||
return get_memory_storage().save(memory_data, agent_name)
|
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Get the current memory data via storage provider."""
|
"""Get the current memory data via storage provider."""
|
||||||
return get_memory_storage().load(agent_name)
|
return get_memory_storage().load(agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Reload memory data via storage provider."""
|
"""Reload memory data via storage provider."""
|
||||||
return get_memory_storage().reload(agent_name)
|
return get_memory_storage().reload(agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Persist imported memory data via storage provider.
|
"""Persist imported memory data via storage provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_data: Full memory payload to persist.
|
memory_data: Full memory payload to persist.
|
||||||
agent_name: If provided, imports into per-agent memory.
|
agent_name: If provided, imports into per-agent memory.
|
||||||
|
user_id: If provided, scopes memory to a specific user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The saved memory data after storage normalization.
|
The saved memory data after storage normalization.
|
||||||
@@ -56,15 +57,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
|
|||||||
OSError: If persisting the imported memory fails.
|
OSError: If persisting the imported memory fails.
|
||||||
"""
|
"""
|
||||||
storage = get_memory_storage()
|
storage = get_memory_storage()
|
||||||
if not storage.save(memory_data, agent_name):
|
if not storage.save(memory_data, agent_name, user_id=user_id):
|
||||||
raise OSError("Failed to save imported memory data")
|
raise OSError("Failed to save imported memory data")
|
||||||
return storage.load(agent_name)
|
return storage.load(agent_name, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Clear all stored memory data and persist an empty structure."""
|
"""Clear all stored memory data and persist an empty structure."""
|
||||||
cleared_memory = create_empty_memory()
|
cleared_memory = create_empty_memory()
|
||||||
if not _save_memory_to_file(cleared_memory, agent_name):
|
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
|
||||||
raise OSError("Failed to save cleared memory data")
|
raise OSError("Failed to save cleared memory data")
|
||||||
return cleared_memory
|
return cleared_memory
|
||||||
|
|
||||||
@@ -81,6 +82,8 @@ def create_memory_fact(
|
|||||||
category: str = "context",
|
category: str = "context",
|
||||||
confidence: float = 0.5,
|
confidence: float = 0.5,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create a new fact and persist the updated memory data."""
|
"""Create a new fact and persist the updated memory data."""
|
||||||
normalized_content = content.strip()
|
normalized_content = content.strip()
|
||||||
@@ -90,7 +93,7 @@ def create_memory_fact(
|
|||||||
normalized_category = category.strip() or "context"
|
normalized_category = category.strip() or "context"
|
||||||
validated_confidence = _validate_confidence(confidence)
|
validated_confidence = _validate_confidence(confidence)
|
||||||
now = utc_now_iso_z()
|
now = utc_now_iso_z()
|
||||||
memory_data = get_memory_data(agent_name)
|
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
facts = list(memory_data.get("facts", []))
|
facts = list(memory_data.get("facts", []))
|
||||||
facts.append(
|
facts.append(
|
||||||
@@ -105,15 +108,15 @@ def create_memory_fact(
|
|||||||
)
|
)
|
||||||
updated_memory["facts"] = facts
|
updated_memory["facts"] = facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name):
|
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||||
raise OSError("Failed to save memory data after creating fact")
|
raise OSError("Failed to save memory data after creating fact")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
|
|
||||||
|
|
||||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
|
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||||
"""Delete a fact by its id and persist the updated memory data."""
|
"""Delete a fact by its id and persist the updated memory data."""
|
||||||
memory_data = get_memory_data(agent_name)
|
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||||
facts = memory_data.get("facts", [])
|
facts = memory_data.get("facts", [])
|
||||||
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
||||||
if len(updated_facts) == len(facts):
|
if len(updated_facts) == len(facts):
|
||||||
@@ -122,7 +125,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
|
|||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
updated_memory["facts"] = updated_facts
|
updated_memory["facts"] = updated_facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name):
|
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||||
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
@@ -134,9 +137,11 @@ def update_memory_fact(
|
|||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
confidence: float | None = None,
|
confidence: float | None = None,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Update an existing fact and persist the updated memory data."""
|
"""Update an existing fact and persist the updated memory data."""
|
||||||
memory_data = get_memory_data(agent_name)
|
memory_data = get_memory_data(agent_name, user_id=user_id)
|
||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
updated_facts: list[dict[str, Any]] = []
|
updated_facts: list[dict[str, Any]] = []
|
||||||
found = False
|
found = False
|
||||||
@@ -163,7 +168,7 @@ def update_memory_fact(
|
|||||||
|
|
||||||
updated_memory["facts"] = updated_facts
|
updated_memory["facts"] = updated_facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name):
|
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
||||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
@@ -276,6 +281,7 @@ class MemoryUpdater:
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Update memory based on conversation messages.
|
"""Update memory based on conversation messages.
|
||||||
|
|
||||||
@@ -285,6 +291,7 @@ class MemoryUpdater:
|
|||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
|
user_id: If provided, scopes memory to a specific user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if update was successful, False otherwise.
|
True if update was successful, False otherwise.
|
||||||
@@ -298,7 +305,7 @@ class MemoryUpdater:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get current memory
|
# Get current memory
|
||||||
current_memory = get_memory_data(agent_name)
|
current_memory = get_memory_data(agent_name, user_id=user_id)
|
||||||
|
|
||||||
# Format conversation for prompt
|
# Format conversation for prompt
|
||||||
conversation_text = format_conversation_for_update(messages)
|
conversation_text = format_conversation_for_update(messages)
|
||||||
@@ -353,7 +360,7 @@ class MemoryUpdater:
|
|||||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||||
|
|
||||||
# Save
|
# Save
|
||||||
return get_memory_storage().save(updated_memory, agent_name)
|
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||||
@@ -455,6 +462,7 @@ def update_memory_from_conversation(
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Convenience function to update memory from a conversation.
|
"""Convenience function to update memory from a conversation.
|
||||||
|
|
||||||
@@ -464,9 +472,10 @@ def update_memory_from_conversation(
|
|||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
|
user_id: If provided, scopes memory to a specific user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if successful, False otherwise.
|
True if successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from langgraph.runtime import Runtime
|
|||||||
|
|
||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -236,11 +237,16 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
# Queue the filtered conversation for memory update
|
# Queue the filtered conversation for memory update
|
||||||
correction_detected = detect_correction(filtered_messages)
|
correction_detected = detect_correction(filtered_messages)
|
||||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||||
|
# Capture user_id at enqueue time while the request context is still alive.
|
||||||
|
# threading.Timer fires on a different thread where ContextVar values are not
|
||||||
|
# propagated, so we must store user_id explicitly in ConversationContext.
|
||||||
|
user_id = get_effective_user_id()
|
||||||
queue = get_memory_queue()
|
queue = get_memory_queue()
|
||||||
queue.add(
|
queue.add(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
messages=filtered_messages,
|
messages=filtered_messages,
|
||||||
agent_name=self._agent_name,
|
agent_name=self._agent_name,
|
||||||
|
user_id=user_id,
|
||||||
correction_detected=correction_detected,
|
correction_detected=correction_detected,
|
||||||
reinforcement_detected=reinforcement_detected,
|
reinforcement_detected=reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from langgraph.runtime import Runtime
|
|||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadDataState
|
from deerflow.agents.thread_state import ThreadDataState
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -46,32 +47,34 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
|||||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||||
self._lazy_init = lazy_init
|
self._lazy_init = lazy_init
|
||||||
|
|
||||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||||
"""Get the paths for a thread's data directories.
|
"""Get the paths for a thread's data directories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
|
user_id: Optional user ID for per-user path isolation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
|
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)),
|
||||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
|
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)),
|
||||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
|
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||||
"""Create the thread data directories.
|
"""Create the thread data directories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
|
user_id: Optional user ID for per-user path isolation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with the created directory paths.
|
Dictionary with the created directory paths.
|
||||||
"""
|
"""
|
||||||
self._paths.ensure_thread_dirs(thread_id)
|
self._paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||||
return self._get_thread_paths(thread_id)
|
return self._get_thread_paths(thread_id, user_id=user_id)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||||
@@ -84,12 +87,14 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
|||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||||
|
|
||||||
|
user_id = get_effective_user_id()
|
||||||
|
|
||||||
if self._lazy_init:
|
if self._lazy_init:
|
||||||
# Lazy initialization: only compute paths, don't create directories
|
# Lazy initialization: only compute paths, don't create directories
|
||||||
paths = self._get_thread_paths(thread_id)
|
paths = self._get_thread_paths(thread_id, user_id=user_id)
|
||||||
else:
|
else:
|
||||||
# Eager initialization: create directories immediately
|
# Eager initialization: create directories immediately
|
||||||
paths = self._create_thread_directories(thread_id)
|
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
||||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from langchain_core.messages import HumanMessage
|
|||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.utils.file_conversion import extract_outline
|
from deerflow.utils.file_conversion import extract_outline
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -221,7 +222,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
|
||||||
|
|
||||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||||
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from deerflow.config.app_config import get_app_config, reload_app_config
|
|||||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.skills.installer import install_skill_from_archive
|
from deerflow.skills.installer import install_skill_from_archive
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
claim_unique_filename,
|
claim_unique_filename,
|
||||||
@@ -240,7 +241,7 @@ class DeerFlowClient:
|
|||||||
}
|
}
|
||||||
checkpointer = self._checkpointer
|
checkpointer = self._checkpointer
|
||||||
if checkpointer is None:
|
if checkpointer is None:
|
||||||
from deerflow.agents.checkpointer import get_checkpointer
|
from deerflow.runtime.checkpointer import get_checkpointer
|
||||||
|
|
||||||
checkpointer = get_checkpointer()
|
checkpointer = get_checkpointer()
|
||||||
if checkpointer is not None:
|
if checkpointer is not None:
|
||||||
@@ -374,7 +375,7 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
checkpointer = self._checkpointer
|
checkpointer = self._checkpointer
|
||||||
if checkpointer is None:
|
if checkpointer is None:
|
||||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||||
|
|
||||||
checkpointer = get_checkpointer()
|
checkpointer = get_checkpointer()
|
||||||
|
|
||||||
@@ -429,7 +430,7 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
checkpointer = self._checkpointer
|
checkpointer = self._checkpointer
|
||||||
if checkpointer is None:
|
if checkpointer is None:
|
||||||
from deerflow.agents.checkpointer.provider import get_checkpointer
|
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
||||||
|
|
||||||
checkpointer = get_checkpointer()
|
checkpointer = get_checkpointer()
|
||||||
|
|
||||||
@@ -769,19 +770,19 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
from deerflow.agents.memory.updater import get_memory_data
|
from deerflow.agents.memory.updater import get_memory_data
|
||||||
|
|
||||||
return get_memory_data()
|
return get_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
def export_memory(self) -> dict:
|
def export_memory(self) -> dict:
|
||||||
"""Export current memory data for backup or transfer."""
|
"""Export current memory data for backup or transfer."""
|
||||||
from deerflow.agents.memory.updater import get_memory_data
|
from deerflow.agents.memory.updater import get_memory_data
|
||||||
|
|
||||||
return get_memory_data()
|
return get_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
def import_memory(self, memory_data: dict) -> dict:
|
def import_memory(self, memory_data: dict) -> dict:
|
||||||
"""Import and persist full memory data."""
|
"""Import and persist full memory data."""
|
||||||
from deerflow.agents.memory.updater import import_memory_data
|
from deerflow.agents.memory.updater import import_memory_data
|
||||||
|
|
||||||
return import_memory_data(memory_data)
|
return import_memory_data(memory_data, user_id=get_effective_user_id())
|
||||||
|
|
||||||
def get_model(self, name: str) -> dict | None:
|
def get_model(self, name: str) -> dict | None:
|
||||||
"""Get a specific model's configuration by name.
|
"""Get a specific model's configuration by name.
|
||||||
@@ -956,13 +957,13 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
from deerflow.agents.memory.updater import reload_memory_data
|
from deerflow.agents.memory.updater import reload_memory_data
|
||||||
|
|
||||||
return reload_memory_data()
|
return reload_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
def clear_memory(self) -> dict:
|
def clear_memory(self) -> dict:
|
||||||
"""Clear all persisted memory data."""
|
"""Clear all persisted memory data."""
|
||||||
from deerflow.agents.memory.updater import clear_memory_data
|
from deerflow.agents.memory.updater import clear_memory_data
|
||||||
|
|
||||||
return clear_memory_data()
|
return clear_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
|
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
|
||||||
"""Create a single fact manually."""
|
"""Create a single fact manually."""
|
||||||
@@ -1179,7 +1180,7 @@ class DeerFlowClient:
|
|||||||
ValueError: If the path is invalid.
|
ValueError: If the path is invalid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
actual = get_paths().resolve_virtual_path(thread_id, path)
|
actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id())
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
if "traversal" in str(exc):
|
if "traversal" in str(exc):
|
||||||
from deerflow.uploads.manager import PathTraversalError
|
from deerflow.uploads.manager import PathTraversalError
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ except ImportError: # pragma: no cover - Windows fallback
|
|||||||
|
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox import Sandbox
|
from deerflow.sandbox.sandbox import Sandbox
|
||||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||||
|
|
||||||
@@ -260,15 +261,16 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
|
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
|
||||||
"""
|
"""
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
paths.ensure_thread_dirs(thread_id)
|
user_id = get_effective_user_id()
|
||||||
|
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
(paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||||
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
(paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||||
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
(paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||||
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
||||||
# the ACP subprocess writes from the host side, not from within the container).
|
# the ACP subprocess writes from the host side, not from within the container).
|
||||||
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
|
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -480,8 +482,9 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
across multiple processes, preventing container-name conflicts.
|
across multiple processes, preventing container-name conflicts.
|
||||||
"""
|
"""
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
paths.ensure_thread_dirs(thread_id)
|
user_id = get_effective_user_id()
|
||||||
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
|
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||||
|
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
|
||||||
|
|
||||||
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
||||||
locked = False
|
locked = False
|
||||||
|
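The lock file in the hunk above now lives inside the user-scoped thread directory. The locking call itself is not visible in this hunk, so the following is only a sketch of a cross-process lock built on a file opened in append mode, assuming POSIX `fcntl.flock`. The helper name `sandbox_lock` and the file name `sandbox-123.lock` are hypothetical:

```python
import fcntl
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def sandbox_lock(lock_path: str):
    """Hold an exclusive advisory lock on lock_path while the block runs."""
    # Append mode creates the file if needed and never truncates it.
    with open(lock_path, "a", encoding="utf-8") as lock_file:
        locked = False
        try:
            # Blocks until no other process holds the lock.
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
            locked = True
            yield lock_file
        finally:
            if locked:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)


# Illustrative location only; the real path is derived from thread_dir().
lock_path = os.path.join(tempfile.mkdtemp(), "sandbox-123.lock")
with sandbox_lock(lock_path):
    lock_created = os.path.exists(lock_path)
```

Because the lock path is derived from the thread directory, two users with the same `thread_id` would contend on different files under the new layout.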
|||||||
@@ -4,8 +4,12 @@ Controls BOTH the LangGraph checkpointer and the DeerFlow application
|
|||||||
persistence layer (runs, threads metadata, users, etc.). The user
|
persistence layer (runs, threads metadata, users, etc.). The user
|
||||||
configures one backend; the system handles physical separation details.
|
configures one backend; the system handles physical separation details.
|
||||||
|
|
||||||
SQLite mode: checkpointer and app use different .db files in the same
|
SQLite mode: checkpointer and app share a single .db file
|
||||||
directory to avoid write-lock contention. This is automatic.
|
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
|
||||||
|
connection. WAL allows concurrent readers and a single writer without
|
||||||
|
blocking, making a unified file safe for both workloads. Writers
|
||||||
|
that contend for the lock wait via the default 5-second sqlite3
|
||||||
|
busy timeout rather than failing immediately.
|
||||||
|
|
||||||
Postgres mode: both use the same database URL but maintain independent
|
Postgres mode: both use the same database URL but maintain independent
|
||||||
connection pools with different lifecycles.
|
connection pools with different lifecycles.
|
||||||
@@ -40,7 +44,7 @@ class DatabaseConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
sqlite_dir: str = Field(
|
sqlite_dir: str = Field(
|
||||||
default=".deer-flow/data",
|
default=".deer-flow/data",
|
||||||
description=("Directory for SQLite database files. Checkpointer uses {sqlite_dir}/checkpoints.db, application data uses {sqlite_dir}/app.db."),
|
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
|
||||||
)
|
)
|
||||||
postgres_url: str = Field(
|
postgres_url: str = Field(
|
||||||
default="",
|
default="",
|
||||||
@@ -69,21 +73,27 @@ class DatabaseConfig(BaseModel):
|
|||||||
|
|
||||||
return str(Path(self.sqlite_dir).resolve())
|
return str(Path(self.sqlite_dir).resolve())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sqlite_path(self) -> str:
|
||||||
|
"""Unified SQLite file path shared by checkpointer and app."""
|
||||||
|
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
|
||||||
|
|
||||||
|
# Backward-compatible aliases
|
||||||
@property
|
@property
|
||||||
def checkpointer_sqlite_path(self) -> str:
|
def checkpointer_sqlite_path(self) -> str:
|
||||||
"""SQLite file path for the LangGraph checkpointer."""
|
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
|
||||||
return os.path.join(self._resolved_sqlite_dir, "checkpoints.db")
|
return self.sqlite_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_sqlite_path(self) -> str:
|
def app_sqlite_path(self) -> str:
|
||||||
"""SQLite file path for application ORM data."""
|
"""SQLite file path for application ORM data (alias for sqlite_path)."""
|
||||||
return os.path.join(self._resolved_sqlite_dir, "app.db")
|
return self.sqlite_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_sqlalchemy_url(self) -> str:
|
def app_sqlalchemy_url(self) -> str:
|
||||||
"""SQLAlchemy async URL for the application ORM engine."""
|
"""SQLAlchemy async URL for the application ORM engine."""
|
||||||
if self.backend == "sqlite":
|
if self.backend == "sqlite":
|
||||||
return f"sqlite+aiosqlite:///{self.app_sqlite_path}"
|
return f"sqlite+aiosqlite:///{self.sqlite_path}"
|
||||||
if self.backend == "postgres":
|
if self.backend == "postgres":
|
||||||
url = self.postgres_url
|
url = self.postgres_url
|
||||||
if url.startswith("postgresql://"):
|
if url.startswith("postgresql://"):
|
||||||
|
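The docstring above describes a single SQLite file shared by two workloads under WAL journal mode. A small standard-library sketch of that arrangement, assuming two plain `sqlite3` connections stand in for the checkpointer and the application engine. The temporary path and the table names are illustrative only:

```python
import os
import sqlite3
import tempfile

# Illustrative location; the real path comes from DatabaseConfig.sqlite_path.
db_path = os.path.join(tempfile.mkdtemp(), "deerflow.db")

# sqlite3.connect() defaults to timeout=5.0, so a writer that finds the
# database locked keeps retrying for up to five seconds before raising.
checkpointer_conn = sqlite3.connect(db_path)
app_conn = sqlite3.connect(db_path)

for conn in (checkpointer_conn, app_conn):
    # journal_mode=WAL is stored in the file; synchronous is per connection.
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA synchronous=NORMAL")

journal_mode = app_conn.execute("PRAGMA journal_mode").fetchone()[0]

# Each workload owns its own table inside the shared file.
checkpointer_conn.execute("CREATE TABLE checkpoints (id TEXT PRIMARY KEY)")
checkpointer_conn.commit()
app_conn.execute("CREATE TABLE runs (id TEXT PRIMARY KEY)")
app_conn.commit()

checkpointer_conn.execute("INSERT INTO checkpoints VALUES ('c1')")
checkpointer_conn.commit()
app_conn.execute("INSERT INTO runs VALUES ('r1')")
app_conn.commit()

tables = sorted(
    row[0]
    for row in app_conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
)
```

WAL still permits only one writer at a time. The busy timeout is what turns a brief write collision into a short wait instead of an immediate `database is locked` error.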
|||||||
@@ -14,8 +14,9 @@ class MemoryConfig(BaseModel):
|
|||||||
default="",
|
default="",
|
||||||
description=(
|
description=(
|
||||||
"Path to store memory data. "
|
"Path to store memory data. "
|
||||||
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
|
"If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. "
|
||||||
"Absolute paths are used as-is. "
|
"Absolute paths are used as-is and opt out of per-user isolation "
|
||||||
|
"(all users share the same file). "
|
||||||
"Relative paths are resolved against `Paths.base_dir` "
|
"Relative paths are resolved against `Paths.base_dir` "
|
||||||
"(not the backend working directory). "
|
"(not the backend working directory). "
|
||||||
"Note: if you previously set this to `.deer-flow/memory.json`, "
|
"Note: if you previously set this to `.deer-flow/memory.json`, "
|
||||||
|
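The description above distinguishes three cases for the memory storage path. A minimal sketch of that decision, assuming a free-standing helper. `resolve_memory_path` is a hypothetical name, and the real resolution logic may live elsewhere and differ in detail:

```python
from __future__ import annotations

from pathlib import Path


def resolve_memory_path(storage_path: str, base_dir: Path, user_id: str) -> Path:
    """Hypothetical helper mirroring the three documented cases."""
    if not storage_path:
        # Empty: per-user default under the user's isolation scope.
        return base_dir / "users" / user_id / "memory.json"
    candidate = Path(storage_path)
    if candidate.is_absolute():
        # Absolute: used as-is, so every user shares the same file.
        return candidate
    # Relative: resolved against base_dir, not the working directory.
    return base_dir / candidate
```

The practical consequence is that an explicitly configured absolute path silently opts a deployment out of per-user memory isolation.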
|||||||
@@ -7,6 +7,7 @@ from pathlib import Path, PureWindowsPath
|
|||||||
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||||
|
|
||||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||||
|
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||||
|
|
||||||
|
|
||||||
def _default_local_base_dir() -> Path:
|
def _default_local_base_dir() -> Path:
|
||||||
@@ -22,6 +23,13 @@ def _validate_thread_id(thread_id: str) -> str:
|
|||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_user_id(user_id: str) -> str:
|
||||||
|
"""Validate a user ID before using it in filesystem paths."""
|
||||||
|
if not _SAFE_USER_ID_RE.match(user_id):
|
||||||
|
raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
def _join_host_path(base: str, *parts: str) -> str:
|
def _join_host_path(base: str, *parts: str) -> str:
|
||||||
"""Join host filesystem path segments while preserving native style.
|
"""Join host filesystem path segments while preserving native style.
|
||||||
|
|
||||||
@@ -134,44 +142,63 @@ class Paths:
|
|||||||
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||||
return self.agent_dir(name) / "memory.json"
|
return self.agent_dir(name) / "memory.json"
|
||||||
|
|
||||||
def thread_dir(self, thread_id: str) -> Path:
|
def user_dir(self, user_id: str) -> Path:
|
||||||
|
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
|
||||||
|
return self.base_dir / "users" / _validate_user_id(user_id)
|
||||||
|
|
||||||
|
def user_memory_file(self, user_id: str) -> Path:
|
||||||
|
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
||||||
|
return self.user_dir(user_id) / "memory.json"
|
||||||
|
|
||||||
|
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
|
||||||
|
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
|
||||||
|
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
|
||||||
|
|
||||||
|
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
|
Host path for a thread's data.
|
||||||
|
|
||||||
|
When *user_id* is provided:
|
||||||
|
`{base_dir}/users/{user_id}/threads/{thread_id}/`
|
||||||
|
Otherwise (legacy layout):
|
||||||
|
`{base_dir}/threads/{thread_id}/`
|
||||||
|
|
||||||
This directory contains a `user-data/` subdirectory that is mounted
|
This directory contains a `user-data/` subdirectory that is mounted
|
||||||
as `/mnt/user-data/` inside the sandbox.
|
as `/mnt/user-data/` inside the sandbox.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `thread_id` contains unsafe characters (path separators
|
ValueError: If `thread_id` or `user_id` contains unsafe characters (path
|
||||||
or `..`) that could cause directory traversal.
|
separators or `..`) that could cause directory traversal.
|
||||||
"""
|
"""
|
||||||
|
if user_id is not None:
|
||||||
|
return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
|
||||||
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
||||||
|
|
||||||
def sandbox_work_dir(self, thread_id: str) -> Path:
|
def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for the agent's workspace directory.
|
Host path for the agent's workspace directory.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
|
||||||
Sandbox: `/mnt/user-data/workspace/`
|
Sandbox: `/mnt/user-data/workspace/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id) / "user-data" / "workspace"
|
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"
|
||||||
|
|
||||||
def sandbox_uploads_dir(self, thread_id: str) -> Path:
|
def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for user-uploaded files.
|
Host path for user-uploaded files.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
|
||||||
Sandbox: `/mnt/user-data/uploads/`
|
Sandbox: `/mnt/user-data/uploads/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id) / "user-data" / "uploads"
|
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads"
|
||||||
|
|
||||||
def sandbox_outputs_dir(self, thread_id: str) -> Path:
|
def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for agent-generated artifacts.
|
Host path for agent-generated artifacts.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
|
||||||
Sandbox: `/mnt/user-data/outputs/`
|
Sandbox: `/mnt/user-data/outputs/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id) / "user-data" / "outputs"
|
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs"
|
||||||
|
|
||||||
def acp_workspace_dir(self, thread_id: str) -> Path:
|
def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for the ACP workspace of a specific thread.
|
Host path for the ACP workspace of a specific thread.
|
||||||
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
|
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
|
||||||
@@ -180,41 +207,43 @@ class Paths:
|
|||||||
Each thread gets its own isolated ACP workspace so that concurrent
|
Each thread gets its own isolated ACP workspace so that concurrent
|
||||||
sessions cannot read each other's ACP agent outputs.
|
sessions cannot read each other's ACP agent outputs.
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id) / "acp-workspace"
|
return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace"
|
||||||
|
|
||||||
def sandbox_user_data_dir(self, thread_id: str) -> Path:
|
def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for the user-data root.
|
Host path for the user-data root.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/`
|
||||||
Sandbox: `/mnt/user-data/`
|
Sandbox: `/mnt/user-data/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id) / "user-data"
|
return self.thread_dir(thread_id, user_id=user_id) / "user-data"
|
||||||
|
|
||||||
def host_thread_dir(self, thread_id: str) -> str:
|
def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||||
"""Host path for a thread directory, preserving Windows path syntax."""
|
"""Host path for a thread directory, preserving Windows path syntax."""
|
||||||
|
if user_id is not None:
|
||||||
|
return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
|
||||||
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
|
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
|
||||||
|
|
||||||
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
|
def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||||
"""Host path for a thread's user-data root."""
|
"""Host path for a thread's user-data root."""
|
||||||
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
|
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data")
|
||||||
|
|
||||||
def host_sandbox_work_dir(self, thread_id: str) -> str:
|
def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||||
"""Host path for the workspace mount source."""
|
"""Host path for the workspace mount source."""
|
||||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
|
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace")
|
||||||
|
|
||||||
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
|
def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||||
"""Host path for the uploads mount source."""
|
"""Host path for the uploads mount source."""
|
||||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
|
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads")
|
||||||
|
|
||||||
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
|
def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||||
"""Host path for the outputs mount source."""
|
"""Host path for the outputs mount source."""
|
||||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
|
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs")
|
||||||
|
|
||||||
def host_acp_workspace_dir(self, thread_id: str) -> str:
|
def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
||||||
"""Host path for the ACP workspace mount source."""
|
"""Host path for the ACP workspace mount source."""
|
||||||
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
|
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace")
|
||||||
|
|
||||||
def ensure_thread_dirs(self, thread_id: str) -> None:
|
def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None:
|
||||||
"""Create all standard sandbox directories for a thread.
|
"""Create all standard sandbox directories for a thread.
|
||||||
|
|
||||||
Directories are created with mode 0o777 so that sandbox containers
|
Directories are created with mode 0o777 so that sandbox containers
|
||||||
@@ -228,24 +257,24 @@ class Paths:
|
|||||||
ACP agent invocation.
|
ACP agent invocation.
|
||||||
"""
|
"""
|
||||||
for d in [
|
for d in [
|
||||||
self.sandbox_work_dir(thread_id),
|
self.sandbox_work_dir(thread_id, user_id=user_id),
|
||||||
self.sandbox_uploads_dir(thread_id),
|
self.sandbox_uploads_dir(thread_id, user_id=user_id),
|
||||||
self.sandbox_outputs_dir(thread_id),
|
self.sandbox_outputs_dir(thread_id, user_id=user_id),
|
||||||
self.acp_workspace_dir(thread_id),
|
self.acp_workspace_dir(thread_id, user_id=user_id),
|
||||||
]:
|
]:
|
||||||
d.mkdir(parents=True, exist_ok=True)
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
d.chmod(0o777)
|
d.chmod(0o777)
|
||||||
|
|
||||||
def delete_thread_dir(self, thread_id: str) -> None:
|
def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None:
|
||||||
"""Delete all persisted data for a thread.
|
"""Delete all persisted data for a thread.
|
||||||
|
|
||||||
The operation is idempotent: missing thread directories are ignored.
|
The operation is idempotent: missing thread directories are ignored.
|
||||||
"""
|
"""
|
||||||
thread_dir = self.thread_dir(thread_id)
|
thread_dir = self.thread_dir(thread_id, user_id=user_id)
|
||||||
if thread_dir.exists():
|
if thread_dir.exists():
|
||||||
shutil.rmtree(thread_dir)
|
shutil.rmtree(thread_dir)
|
||||||
|
|
||||||
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
|
def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
||||||
"""Resolve a sandbox virtual path to the actual host filesystem path.
|
"""Resolve a sandbox virtual path to the actual host filesystem path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -253,6 +282,7 @@ class Paths:
|
|||||||
virtual_path: Virtual path as seen inside the sandbox, e.g.
|
virtual_path: Virtual path as seen inside the sandbox, e.g.
|
||||||
``/mnt/user-data/outputs/report.pdf``.
|
``/mnt/user-data/outputs/report.pdf``.
|
||||||
Leading slashes are stripped before matching.
|
Leading slashes are stripped before matching.
|
||||||
|
user_id: Optional user ID for user-scoped path resolution.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The resolved absolute host filesystem path.
|
The resolved absolute host filesystem path.
|
||||||
@@ -270,7 +300,7 @@ class Paths:
|
|||||||
raise ValueError(f"Path must start with /{prefix}")
|
raise ValueError(f"Path must start with /{prefix}")
|
||||||
|
|
||||||
relative = stripped[len(prefix) :].lstrip("/")
|
relative = stripped[len(prefix) :].lstrip("/")
|
||||||
base = self.sandbox_user_data_dir(thread_id).resolve()
|
base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve()
|
||||||
actual = (base / relative).resolve()
|
actual = (base / relative).resolve()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
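Every `Paths` helper in this file now takes an optional keyword-only `user_id`. Several unchanged docstrings still show only the legacy `{base_dir}/threads/{thread_id}/...` layout. The sketch below condenses the two layouts and the virtual-path resolution into free functions. The function shapes are simplified, and the containment check after `resolve()` is an assumption, since that part of the method is not visible in the hunk:

```python
from __future__ import annotations

import re
from pathlib import Path

VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")  # same pattern as the diff


def _validate_id(value: str, kind: str) -> str:
    """Reject IDs that could escape their directory (separators, '..')."""
    if not _SAFE_ID_RE.match(value):
        raise ValueError(f"Invalid {kind} {value!r}")
    return value


def thread_dir(base_dir: Path, thread_id: str, *, user_id: str | None = None) -> Path:
    """User-scoped layout when user_id is given, legacy layout otherwise."""
    safe_thread = _validate_id(thread_id, "thread_id")
    if user_id is not None:
        return base_dir / "users" / _validate_id(user_id, "user_id") / "threads" / safe_thread
    return base_dir / "threads" / safe_thread


def resolve_virtual_path(
    base_dir: Path, thread_id: str, virtual_path: str, *, user_id: str | None = None
) -> Path:
    """Map a sandbox path such as /mnt/user-data/outputs/x to a host path."""
    prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
    stripped = virtual_path.lstrip("/")
    if stripped != prefix and not stripped.startswith(prefix + "/"):
        raise ValueError(f"Path must start with /{prefix}")
    relative = stripped[len(prefix):].lstrip("/")
    base = (thread_dir(base_dir, thread_id, user_id=user_id) / "user-data").resolve()
    actual = (base / relative).resolve()
    try:
        # Assumed check: the resolved path must stay inside the user-data root.
        actual.relative_to(base)
    except ValueError:
        raise ValueError("Access denied: path traversal detected") from None
    return actual
```

Passing `user_id=None` keeps the legacy layout, which is why call sites in this change set pass `get_effective_user_id()` explicitly.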
|||||||
@@ -98,6 +98,11 @@ async def init_engine(
|
|||||||
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
||||||
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
||||||
# at WAL checkpoint boundaries instead of every commit.
|
# at WAL checkpoint boundaries instead of every commit.
|
||||||
|
# Note: we do not set PRAGMA busy_timeout here — Python's sqlite3
|
||||||
|
# driver already defaults to a 5-second busy timeout (see the
|
||||||
|
# ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite /
|
||||||
|
# SQLAlchemy's aiosqlite dialect inherit that default. Setting
|
||||||
|
# it again would be a no-op.
|
||||||
@event.listens_for(_engine.sync_engine, "connect")
|
@event.listens_for(_engine.sync_engine, "connect")
|
||||||
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
||||||
cursor = dbapi_conn.cursor()
|
cursor = dbapi_conn.cursor()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import DateTime, String, Text
|
from sqlalchemy import DateTime, String, Text, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from deerflow.persistence.base import Base
|
from deerflow.persistence.base import Base
|
||||||
@@ -13,10 +13,14 @@ from deerflow.persistence.base import Base
|
|||||||
class FeedbackRow(Base):
|
class FeedbackRow(Base):
|
||||||
__tablename__ = "feedback"
|
__tablename__ = "feedback"
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
|
||||||
|
)
|
||||||
|
|
||||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||||
# message_id is an optional RunEventStore event identifier —
|
# message_id is an optional RunEventStore event identifier —
|
||||||
# allows feedback to target a specific message or the entire run
|
# allows feedback to target a specific message or the entire run
|
||||||
|
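The new `uq_feedback_thread_run_user` constraint limits each user to one feedback row per run in a thread. A standard-library sketch of its effect on SQLite, assuming a hand-written table that mirrors only the columns needed here. It bypasses the ORM, and the sample IDs are illustrative:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE feedback (
        feedback_id TEXT PRIMARY KEY,
        run_id      TEXT NOT NULL,
        thread_id   TEXT NOT NULL,
        user_id     TEXT,
        rating      INTEGER NOT NULL,
        CONSTRAINT uq_feedback_thread_run_user UNIQUE (thread_id, run_id, user_id)
    )
    """
)
conn.execute("INSERT INTO feedback VALUES ('f1', 'r1', 't1', 'alice', 1)")

# A second row for the same (thread, run, user) violates the constraint.
try:
    conn.execute("INSERT INTO feedback VALUES ('f2', 'r1', 't1', 'alice', -1)")
    duplicate_rejected = False
except sqlite3.IntegrityError:
    duplicate_rejected = True

# A different user may still rate the same run.
conn.execute("INSERT INTO feedback VALUES ('f3', 'r1', 't1', 'bob', 1)")

# SQLite treats NULLs as distinct in UNIQUE constraints, so rows with a
# NULL user_id are not deduplicated by this constraint.
conn.execute("INSERT INTO feedback VALUES ('f4', 'r1', 't1', NULL, 1)")
conn.execute("INSERT INTO feedback VALUES ('f5', 'r1', 't1', NULL, 1)")

row_count = conn.execute("SELECT COUNT(*) FROM feedback").fetchone()[0]
```

Since `user_id` is nullable in the model, the last two inserts are worth keeping in mind: the constraint only deduplicates rows whose `user_id` is set.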
|||||||
@@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from deerflow.persistence.feedback.model import FeedbackRow
|
from deerflow.persistence.feedback.model import FeedbackRow
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
class FeedbackRepository:
|
class FeedbackRepository:
|
||||||
@@ -33,19 +33,19 @@ class FeedbackRepository:
|
|||||||
run_id: str,
|
run_id: str,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
rating: int,
|
rating: int,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
comment: str | None = None,
|
comment: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Create a feedback record. rating must be +1 or -1."""
|
"""Create a feedback record. rating must be +1 or -1."""
|
||||||
if rating not in (1, -1):
|
if rating not in (1, -1):
|
||||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
|
||||||
row = FeedbackRow(
|
row = FeedbackRow(
|
||||||
feedback_id=str(uuid.uuid4()),
|
feedback_id=str(uuid.uuid4()),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
owner_id=resolved_owner_id,
|
user_id=resolved_user_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
rating=rating,
|
rating=rating,
|
||||||
comment=comment,
|
comment=comment,
|
||||||
@@ -61,14 +61,14 @@ class FeedbackRepository:
|
|||||||
self,
|
self,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
@@ -78,12 +78,12 @@ class FeedbackRepository:
|
|||||||
run_id: str,
|
run_id: str,
|
||||||
*,
|
*,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -94,12 +94,12 @@ class FeedbackRepository:
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -109,19 +109,97 @@ class FeedbackRepository:
|
|||||||
self,
|
self,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return False
|
return False
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return False
|
return False
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def upsert(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
thread_id: str,
|
||||||
|
rating: int,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create or update feedback for (thread_id, run_id, user_id); raises ValueError unless rating is +1 or -1."""
|
||||||
|
if rating not in (1, -1):
|
||||||
|
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert")
|
||||||
|
async with self._sf() as session:
|
||||||
|
stmt = select(FeedbackRow).where(
|
||||||
|
FeedbackRow.thread_id == thread_id,
|
||||||
|
FeedbackRow.run_id == run_id,
|
||||||
|
FeedbackRow.user_id == resolved_user_id,
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.rating = rating
|
||||||
|
row.comment = comment
|
||||||
|
row.created_at = datetime.now(UTC)
|
||||||
|
else:
|
||||||
|
row = FeedbackRow(
|
||||||
|
feedback_id=str(uuid.uuid4()),
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
user_id=resolved_user_id,
|
||||||
|
rating=rating,
|
||||||
|
comment=comment,
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(row)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(row)
|
||||||
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
|
async def delete_by_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
|
) -> bool:
|
||||||
|
"""Delete the current user's feedback for a run. Returns True if a record was deleted."""
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run")
|
||||||
|
async with self._sf() as session:
|
||||||
|
stmt = select(FeedbackRow).where(
|
||||||
|
FeedbackRow.thread_id == thread_id,
|
||||||
|
FeedbackRow.run_id == run_id,
|
||||||
|
FeedbackRow.user_id == resolved_user_id,
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
return False
|
||||||
|
await session.delete(row)
|
||||||
|
await session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def list_by_thread_grouped(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
|
) -> dict[str, dict]:
|
||||||
|
"""Return feedback grouped by run_id for a thread: {run_id: feedback_dict}."""
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped")
|
||||||
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||||
|
if resolved_user_id is not None:
|
||||||
|
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return {row.run_id: self._row_to_dict(row) for row in result.scalars()}
|
||||||
|
|
||||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
||||||
"""Aggregate feedback stats for a run using database-side counting."""
|
"""Aggregate feedback stats for a run using database-side counting."""
|
||||||
stmt = select(
|
stmt = select(
|
||||||
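Reviewer note on the new `upsert`, `delete_by_run`, and `list_by_thread_grouped` methods: feedback is keyed on (thread_id, run_id, user_id), so a second submission by the same user for the same run overwrites the first instead of adding a row. The sketch below is only an illustration of that behaviour, assuming a plain dict in place of the SQLAlchemy session; `InMemoryFeedback` is a hypothetical name and is not part of the repository.

```python
from __future__ import annotations

import uuid
from datetime import datetime, timezone


class InMemoryFeedback:
    """Hypothetical stand-in for FeedbackRepository; a dict replaces the database."""

    def __init__(self) -> None:
        # One entry per (thread_id, run_id, user_id), mirroring the upsert lookup.
        self._rows: dict[tuple[str, str, str | None], dict] = {}

    def upsert(self, *, thread_id: str, run_id: str, user_id: str | None,
               rating: int, comment: str | None = None) -> dict:
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")
        key = (thread_id, run_id, user_id)
        row = self._rows.get(key)
        if row is None:
            # First feedback for this key: create the record with a fresh id.
            row = {"feedback_id": str(uuid.uuid4()), "thread_id": thread_id,
                   "run_id": run_id, "user_id": user_id}
            self._rows[key] = row
        # Both branches set rating, comment, and timestamp, as in the diff.
        row["rating"] = rating
        row["comment"] = comment
        row["created_at"] = datetime.now(timezone.utc)
        return dict(row)

    def delete_by_run(self, *, thread_id: str, run_id: str, user_id: str | None) -> bool:
        # True only when a record existed for this exact key.
        return self._rows.pop((thread_id, run_id, user_id), None) is not None

    def list_by_thread_grouped(self, thread_id: str, *, user_id: str | None) -> dict[str, dict]:
        # user_id=None skips the owner filter, as in the repository method.
        return {
            row["run_id"]: dict(row)
            for row in self._rows.values()
            if row["thread_id"] == thread_id and (user_id is None or row["user_id"] == user_id)
        }
```

Because the existing record is updated in place, its `feedback_id` stays stable across repeated submissions while `rating`, `comment`, and `created_at` change.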
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
script_location = %(here)s
|
script_location = %(here)s
|
||||||
# Default URL for offline mode / autogenerate.
|
# Default URL for offline mode / autogenerate.
|
||||||
# Runtime uses engine from DeerFlow config.
|
# Runtime uses engine from DeerFlow config.
|
||||||
sqlalchemy.url = sqlite+aiosqlite:///./data/app.db
|
sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db
|
||||||
|
|
||||||
[loggers]
|
[loggers]
|
||||||
keys = root,sqlalchemy,alembic
|
keys = root,sqlalchemy,alembic
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class RunEventRow(Base):
|
|||||||
# Owner of the conversation this event belongs to. Nullable for data
|
# Owner of the conversation this event belongs to. Nullable for data
|
||||||
# created before auth was introduced; populated by auth middleware on
|
# created before auth was introduced; populated by auth middleware on
|
||||||
# new writes and by the boot-time orphan migration on existing rows.
|
# new writes and by the boot-time orphan migration on existing rows.
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||||
# "message" | "trace" | "lifecycle"
|
# "message" | "trace" | "lifecycle"
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class RunRow(Base):
|
|||||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.run.model import RunRow
|
from deerflow.persistence.run.model import RunRow
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
from deerflow.runtime.runs.store.base import RunStore
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
class RunRepository(RunStore):
|
class RunRepository(RunStore):
|
||||||
@@ -69,7 +69,7 @@ class RunRepository(RunStore):
|
|||||||
*,
|
*,
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
status="pending",
|
status="pending",
|
||||||
multitask_strategy="reject",
|
multitask_strategy="reject",
|
||||||
metadata=None,
|
metadata=None,
|
||||||
@@ -78,13 +78,13 @@ class RunRepository(RunStore):
|
|||||||
created_at=None,
|
created_at=None,
|
||||||
follow_up_to_run_id=None,
|
follow_up_to_run_id=None,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = RunRow(
|
row = RunRow(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
owner_id=resolved_owner_id,
|
user_id=resolved_user_id,
|
||||||
status=status,
|
status=status,
|
||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata_json=self._safe_json(metadata) or {},
|
metadata_json=self._safe_json(metadata) or {},
|
||||||
@@ -102,14 +102,14 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
@@ -117,13 +117,13 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
limit=100,
|
limit=100,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
|
||||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunRow.user_id == resolved_user_id)
|
||||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -141,14 +141,14 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return
|
return
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -1,13 +1,38 @@
|
|||||||
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
|
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||||
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.store.base import BaseStore
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MemoryThreadMetaStore",
|
"MemoryThreadMetaStore",
|
||||||
"ThreadMetaRepository",
|
"ThreadMetaRepository",
|
||||||
"ThreadMetaRow",
|
"ThreadMetaRow",
|
||||||
"ThreadMetaStore",
|
"ThreadMetaStore",
|
||||||
|
"make_thread_store",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def make_thread_store(
|
||||||
|
session_factory: async_sessionmaker[AsyncSession] | None,
|
||||||
|
store: BaseStore | None = None,
|
||||||
|
) -> ThreadMetaStore:
|
||||||
|
"""Create the appropriate ThreadMetaStore based on available backends.
|
||||||
|
|
||||||
|
Returns a SQL-backed repository when a session factory is available,
|
||||||
|
otherwise falls back to the store-backed MemoryThreadMetaStore; raises ValueError if neither is given.
|
||||||
|
"""
|
||||||
|
if session_factory is not None:
|
||||||
|
return ThreadMetaRepository(session_factory)
|
||||||
|
if store is None:
|
||||||
|
raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)")
|
||||||
|
return MemoryThreadMetaStore(store)
|
||||||
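Reviewer note on `make_thread_store`: the SQL backend takes precedence whenever a session factory is supplied, the LangGraph store is used only as a fallback, and passing neither is an error. A minimal sketch of that selection order follows; `SqlThreadStore`, `MemoryThreadStore`, and `make_thread_store_sketch` are hypothetical stand-ins so the example runs without SQLAlchemy or LangGraph installed.

```python
from __future__ import annotations


class SqlThreadStore:
    """Hypothetical stand-in for ThreadMetaRepository."""

    def __init__(self, session_factory: object) -> None:
        self.session_factory = session_factory


class MemoryThreadStore:
    """Hypothetical stand-in for MemoryThreadMetaStore."""

    def __init__(self, store: object) -> None:
        self.store = store


def make_thread_store_sketch(session_factory: object | None, store: object | None = None):
    # SQL wins whenever a session factory exists, even if a store is also passed.
    if session_factory is not None:
        return SqlThreadStore(session_factory)
    # With no SQL backend, a store is mandatory.
    if store is None:
        raise ValueError("requires either a session_factory (SQL) or a store (memory)")
    return MemoryThreadStore(store)
```

Callers that build both backends at startup can pass both arguments and let the factory decide, which keeps the selection rule in one place.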
|
|||||||
@@ -3,12 +3,21 @@
|
|||||||
Implementations:
|
Implementations:
|
||||||
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
|
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
|
||||||
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
|
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
|
||||||
|
|
||||||
|
All mutating and querying methods accept a ``user_id`` parameter with
|
||||||
|
three-state semantics (see :mod:`deerflow.runtime.user_context`):
|
||||||
|
|
||||||
|
- ``AUTO`` (default): resolve from the request-scoped contextvar.
|
||||||
|
- Explicit ``str``: use the provided value verbatim.
|
||||||
|
- Explicit ``None``: bypass owner filtering (migration/CLI only).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel
|
||||||
|
|
||||||
|
|
||||||
class ThreadMetaStore(abc.ABC):
|
class ThreadMetaStore(abc.ABC):
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -17,14 +26,14 @@ class ThreadMetaStore(abc.ABC):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
owner_id: str | None = None,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get(self, thread_id: str) -> dict | None:
|
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -35,26 +44,33 @@ class ThreadMetaStore(abc.ABC):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def update_status(self, thread_id: str, status: str) -> None:
|
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
"""Merge ``metadata`` into the thread's metadata field.
|
"""Merge ``metadata`` into the thread's metadata field.
|
||||||
|
|
||||||
Existing keys are overwritten by the new values; keys absent from
|
Existing keys are overwritten by the new values; keys absent from
|
||||||
``metadata`` are preserved. No-op if the thread does not exist.
|
``metadata`` are preserved. No-op if the thread does not exist
|
||||||
|
or the owner check fails.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def delete(self, thread_id: str) -> None:
|
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||||
|
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from typing import Any
|
|||||||
from langgraph.store.base import BaseStore
|
from langgraph.store.base import BaseStore
|
||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||||
|
|
||||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||||
|
|
||||||
@@ -21,20 +22,37 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
def __init__(self, store: BaseStore) -> None:
|
def __init__(self, store: BaseStore) -> None:
|
||||||
self._store = store
|
self._store = store
|
||||||
|
|
||||||
|
async def _get_owned_record(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
user_id: str | None | _AutoSentinel,
|
||||||
|
method_name: str,
|
||||||
|
) -> dict | None:
|
||||||
|
"""Fetch a record and verify ownership. Returns a mutable copy, or None."""
|
||||||
|
resolved = resolve_user_id(user_id, method_name=method_name)
|
||||||
|
item = await self._store.aget(THREADS_NS, thread_id)
|
||||||
|
if item is None:
|
||||||
|
return None
|
||||||
|
record = dict(item.value)
|
||||||
|
if resolved is not None and record.get("user_id") != resolved:
|
||||||
|
return None
|
||||||
|
return record
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
owner_id: str | None = None,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
|
||||||
now = time.time()
|
now = time.time()
|
||||||
record: dict[str, Any] = {
|
record: dict[str, Any] = {
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"assistant_id": assistant_id,
|
"assistant_id": assistant_id,
|
||||||
"owner_id": owner_id,
|
"user_id": resolved_user_id,
|
||||||
"display_name": display_name,
|
"display_name": display_name,
|
||||||
"status": "idle",
|
"status": "idle",
|
||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
@@ -45,9 +63,8 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
async def get(self, thread_id: str) -> dict | None:
|
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||||
item = await self._store.aget(THREADS_NS, thread_id)
|
return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get")
|
||||||
return item.value if item is not None else None
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
@@ -56,12 +73,16 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
|
||||||
filter_dict: dict[str, Any] = {}
|
filter_dict: dict[str, Any] = {}
|
||||||
if metadata:
|
if metadata:
|
||||||
filter_dict.update(metadata)
|
filter_dict.update(metadata)
|
||||||
if status:
|
if status:
|
||||||
filter_dict["status"] = status
|
filter_dict["status"] = status
|
||||||
|
if resolved_user_id is not None:
|
||||||
|
filter_dict["user_id"] = resolved_user_id
|
||||||
|
|
||||||
items = await self._store.asearch(
|
items = await self._store.asearch(
|
||||||
THREADS_NS,
|
THREADS_NS,
|
||||||
@@ -71,37 +92,45 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
)
|
)
|
||||||
return [self._item_to_dict(item) for item in items]
|
return [self._item_to_dict(item) for item in items]
|
||||||
|
|
||||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||||
item = await self._store.aget(THREADS_NS, thread_id)
|
item = await self._store.aget(THREADS_NS, thread_id)
|
||||||
if item is None:
|
if item is None:
|
||||||
|
return not require_existing
|
||||||
|
record_user_id = item.value.get("user_id")
|
||||||
|
if record_user_id is None:
|
||||||
|
return True
|
||||||
|
return record_user_id == user_id
|
||||||
|
|
||||||
|
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name")
|
||||||
|
if record is None:
|
||||||
return
|
return
|
||||||
record = dict(item.value)
|
|
||||||
record["display_name"] = display_name
|
record["display_name"] = display_name
|
||||||
record["updated_at"] = time.time()
|
record["updated_at"] = time.time()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def update_status(self, thread_id: str, status: str) -> None:
|
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
item = await self._store.aget(THREADS_NS, thread_id)
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status")
|
||||||
if item is None:
|
if record is None:
|
||||||
return
|
return
|
||||||
record = dict(item.value)
|
|
||||||
record["status"] = status
|
record["status"] = status
|
||||||
record["updated_at"] = time.time()
|
record["updated_at"] = time.time()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
"""Merge ``metadata`` into the in-memory record. No-op if absent."""
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata")
|
||||||
item = await self._store.aget(THREADS_NS, thread_id)
|
if record is None:
|
||||||
if item is None:
|
|
||||||
return
|
return
|
||||||
record = dict(item.value)
|
|
||||||
merged = dict(record.get("metadata") or {})
|
merged = dict(record.get("metadata") or {})
|
||||||
merged.update(metadata)
|
merged.update(metadata)
|
||||||
record["metadata"] = merged
|
record["metadata"] = merged
|
||||||
record["updated_at"] = time.time()
|
record["updated_at"] = time.time()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def delete(self, thread_id: str) -> None:
|
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||||
|
if record is None:
|
||||||
|
return
|
||||||
await self._store.adelete(THREADS_NS, thread_id)
|
await self._store.adelete(THREADS_NS, thread_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -111,7 +140,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
return {
|
return {
|
||||||
"thread_id": item.key,
|
"thread_id": item.key,
|
||||||
"assistant_id": val.get("assistant_id"),
|
"assistant_id": val.get("assistant_id"),
|
||||||
"owner_id": val.get("owner_id"),
|
"user_id": val.get("user_id"),
|
||||||
"display_name": val.get("display_name"),
|
"display_name": val.get("display_name"),
|
||||||
"status": val.get("status", "idle"),
|
"status": val.get("status", "idle"),
|
||||||
"metadata": val.get("metadata", {}),
|
"metadata": val.get("metadata", {}),
|
||||||
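Reviewer note on `MemoryThreadMetaStore.check_access`: the outcome depends on three things, namely whether a record exists, whether it has an owner, and the `require_existing` flag. The sketch below restates the logic from the diff as a pure function over an already-fetched record so the cases can be checked in isolation; `check_access_sketch` is a hypothetical name, and the SQL repository's strict mode may differ since its docstring is only partly visible here.

```python
from __future__ import annotations


def check_access_sketch(record: dict | None, user_id: str, *, require_existing: bool = False) -> bool:
    """Mirror of the in-memory check_access, taking the fetched record directly."""
    if record is None:
        # Missing record: allowed in permissive mode, denied when existence is required.
        return not require_existing
    owner = record.get("user_id")
    if owner is None:
        # Ownerless (shared or pre-auth) records are accessible to everyone.
        return True
    return owner == user_id


# Each case: (record, require_existing, expected result for user "alice").
cases = [
    (None, False, True),
    (None, True, False),
    ({"user_id": None}, True, True),
    ({"user_id": "alice"}, True, True),
    ({"user_id": "bob"}, False, False),
]
results = [check_access_sketch(rec, "alice", require_existing=req) for rec, req, _ in cases]
```

Note that in this in-memory version an ownerless record passes even with `require_existing=True`; only a missing record is treated differently between the two modes.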
|
|||||||
@@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
|
|||||||
|
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
class ThreadMetaRepository(ThreadMetaStore):
|
class ThreadMetaRepository(ThreadMetaStore):
|
||||||
@@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
# Auto-resolve user_id from contextvar when AUTO; explicit None
|
||||||
# creates an orphan row (used by migration scripts).
|
# creates an orphan row (used by migration scripts).
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = ThreadMetaRow(
|
row = ThreadMetaRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
owner_id=resolved_owner_id,
|
user_id=resolved_user_id,
|
||||||
display_name=display_name,
|
display_name=display_name,
|
||||||
metadata_json=metadata or {},
|
metadata_json=metadata or {},
|
||||||
created_at=now,
|
created_at=now,
|
||||||
@@ -59,40 +59,34 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
# Enforce owner filter unless explicitly bypassed (user_id=None).
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
"""Check if ``user_id`` has access to ``thread_id``.
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
|
||||||
|
|
||||||
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
|
|
||||||
"""Check if ``owner_id`` has access to ``thread_id``.
|
|
||||||
|
|
||||||
Two modes — one row, two distinct semantics depending on what
|
Two modes — one row, two distinct semantics depending on what
|
||||||
the caller is about to do:
|
the caller is about to do:
|
||||||
|
|
||||||
- ``require_existing=False`` (default, permissive):
|
- ``require_existing=False`` (default, permissive):
|
||||||
Returns True for: row missing (untracked legacy thread),
|
Returns True for: row missing (untracked legacy thread),
|
||||||
``row.owner_id`` is None (shared / pre-auth data),
|
``row.user_id`` is None (shared / pre-auth data),
|
||||||
or ``row.owner_id == owner_id``. Use for **read-style**
|
or ``row.user_id == user_id``. Use for **read-style**
|
||||||
decorators where treating an untracked thread as accessible
|
decorators where treating an untracked thread as accessible
|
||||||
preserves backward-compat.
|
preserves backward-compat.
|
||||||
|
|
||||||
- ``require_existing=True`` (strict):
|
- ``require_existing=True`` (strict):
|
||||||
Returns True **only** when the row exists AND
|
Returns True **only** when the row exists AND
|
||||||
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
|
(``row.user_id == user_id`` OR ``row.user_id is None``).
|
||||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||||
state-update) so a thread that has *already been deleted*
|
state-update) so a thread that has *already been deleted*
|
||||||
cannot be re-targeted by any caller — closing the
|
cannot be re-targeted by any caller — closing the
|
||||||
@@ -103,9 +97,9 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return not require_existing
|
return not require_existing
|
||||||
if row.owner_id is None:
|
if row.user_id is None:
|
||||||
return True
|
return True
|
||||||
return row.owner_id == owner_id
|
return row.user_id == user_id
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
@@ -114,17 +108,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Search threads with optional metadata and status filters.
|
"""Search threads with optional metadata and status filters.
|
||||||
|
|
||||||
Owner filter is enforced by default: caller must be in a user
|
Owner filter is enforced by default: caller must be in a user
|
||||||
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
||||||
"""
|
"""
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
||||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
||||||
if status:
|
if status:
|
||||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||||
|
|
||||||
@@ -144,24 +138,24 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
||||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||||
if resolved_owner_id is None:
|
if resolved_user_id is None:
|
||||||
return True # explicit bypass
|
return True # explicit bypass
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
return row is not None and row.owner_id == resolved_owner_id
|
return row is not None and row.user_id == resolved_user_id
|
||||||
|
|
||||||
async def update_display_name(
|
async def update_display_name(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
display_name: str,
|
display_name: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the display_name (title) for a thread."""
|
"""Update the display_name (title) for a thread."""
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||||
return
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -171,11 +165,11 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
status: str,
|
status: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||||
return
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -185,20 +179,20 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
metadata: dict,
|
metadata: dict,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Merge ``metadata`` into ``metadata_json``.
|
"""Merge ``metadata`` into ``metadata_json``.
|
||||||
|
|
||||||
Read-modify-write inside a single session/transaction so concurrent
|
Read-modify-write inside a single session/transaction so concurrent
|
||||||
callers see consistent state. No-op if the row does not exist or
|
callers see consistent state. No-op if the row does not exist or
|
||||||
the owner_id check fails.
|
the user_id check fails.
|
||||||
"""
|
"""
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return
|
return
|
||||||
merged = dict(row.metadata_json or {})
|
merged = dict(row.metadata_json or {})
|
||||||
merged.update(metadata)
|
merged.update(metadata)
|
||||||
@@ -210,14 +204,14 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return
|
return
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
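Review note: the two `check_access` modes described in the `ThreadMetaRepository` docstring above can be sketched without a database. This is a minimal illustration, assuming an in-memory dict in place of the SQLAlchemy session; `Row` and the `rows` mapping are hypothetical stand-ins, not part of the repository.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Row:
    """Hypothetical stand-in for ThreadMetaRow; only user_id matters here."""

    user_id: str | None


def check_access(rows: dict[str, Row], thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
    """Mirror the branch order shown in the diff, using a dict lookup instead of session.get()."""
    row = rows.get(thread_id)
    if row is None:
        # Permissive mode treats an untracked thread as accessible; strict mode does not.
        return not require_existing
    if row.user_id is None:
        # Shared / pre-auth row: accessible in both modes.
        return True
    return row.user_id == user_id


rows = {"t-owned": Row("alice"), "t-shared": Row(None)}

owner_ok = check_access(rows, "t-owned", "alice")
other_denied = check_access(rows, "t-owned", "bob")
missing_permissive = check_access(rows, "t-missing", "bob")
missing_strict = check_access(rows, "t-missing", "bob", require_existing=True)
shared_strict = check_access(rows, "t-shared", "bob", require_existing=True)
```

The only case where the two modes differ is the missing row, which is why the docstring reserves the strict mode for destructive or mutating routes.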
|
|||||||
@@ -5,12 +5,18 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
|
|||||||
directly from ``deerflow.runtime``.
|
directly from ``deerflow.runtime``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
|
||||||
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
||||||
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
||||||
from .store import get_store, make_store, reset_store, store_context
|
from .store import get_store, make_store, reset_store, store_context
|
||||||
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# checkpointer
|
||||||
|
"checkpointer_context",
|
||||||
|
"get_checkpointer",
|
||||||
|
"make_checkpointer",
|
||||||
|
"reset_checkpointer",
|
||||||
# runs
|
# runs
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"DisconnectMode",
|
"DisconnectMode",
|
||||||
|
|||||||
+4
-4
@@ -7,12 +7,12 @@ Supported backends: memory, sqlite, postgres.
|
|||||||
|
|
||||||
Usage (e.g. FastAPI lifespan)::
|
Usage (e.g. FastAPI lifespan)::
|
||||||
|
|
||||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||||
|
|
||||||
async with make_checkpointer() as checkpointer:
|
async with make_checkpointer() as checkpointer:
|
||||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||||
|
|
||||||
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
|
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
|
|||||||
|
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
from deerflow.agents.checkpointer.provider import (
|
from deerflow.config.app_config import get_app_config
|
||||||
|
from deerflow.runtime.checkpointer.provider import (
|
||||||
POSTGRES_CONN_REQUIRED,
|
POSTGRES_CONN_REQUIRED,
|
||||||
POSTGRES_INSTALL,
|
POSTGRES_INSTALL,
|
||||||
SQLITE_INSTALL,
|
SQLITE_INSTALL,
|
||||||
)
|
)
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
+1
-1
@@ -7,7 +7,7 @@ Supported backends: memory, sqlite, postgres.
|
|||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
|
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||||
|
|
||||||
# Singleton — reused across calls, closed on process exit
|
# Singleton — reused across calls, closed on process exit
|
||||||
cp = get_checkpointer()
|
cp = get_checkpointer()
|
||||||
@@ -83,8 +83,18 @@ class RunEventStore(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending."""
|
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
|
||||||
|
|
||||||
|
Supports bidirectional cursor pagination:
|
||||||
|
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
||||||
|
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
||||||
|
- neither: return the latest ``limit`` records (ascending)
|
||||||
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def count_messages(self, thread_id: str) -> int:
|
async def count_messages(self, thread_id: str) -> int:
|
||||||
|
|||||||
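Review note: the three pagination modes listed in the `list_messages_by_run` docstring above can be illustrated on a plain list. This is a sketch that follows the list-slicing approach of the in-memory store in this diff; `page_messages` and the sample `events` are hypothetical names used only for illustration.

```python
def page_messages(events, *, limit=50, before_seq=None, after_seq=None):
    """Return one ascending page of events, assuming `events` is already sorted by seq."""
    filtered = list(events)
    if before_seq is not None:
        filtered = [e for e in filtered if e["seq"] < before_seq]
    if after_seq is not None:
        filtered = [e for e in filtered if e["seq"] > after_seq]
    if after_seq is not None:
        # Forward paging: the first `limit` records after the cursor.
        return filtered[:limit]
    # Backward paging or no cursor: the last `limit` records, still ascending.
    return filtered[-limit:] if len(filtered) > limit else filtered


def seqs(page):
    return [e["seq"] for e in page]


events = [{"seq": n} for n in range(1, 11)]  # seq 1..10

latest = seqs(page_messages(events, limit=3))
older = seqs(page_messages(events, limit=3, before_seq=8))
newer = seqs(page_messages(events, limit=2, after_seq=7))
```

A client scrolling up would pass the smallest `seq` it holds as `before_seq`; a client polling for new messages would pass the largest as `after_seq`.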
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.models.run_event import RunEventRow
|
from deerflow.persistence.models.run_event import RunEventRow
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -55,16 +55,22 @@ class DbRunEventStore(RunEventStore):
|
|||||||
return content, metadata or {}
|
return content, metadata or {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _owner_from_context() -> str | None:
|
def _user_id_from_context() -> str | None:
|
||||||
"""Soft read of owner_id from contextvar for write paths.
|
"""Soft read of user_id from contextvar for write paths.
|
||||||
|
|
||||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||||
which is the expected case for background worker writes. HTTP
|
which is the expected case for background worker writes. HTTP
|
||||||
request writes will have the contextvar set by auth middleware
|
request writes will have the contextvar set by auth middleware
|
||||||
and get their user_id stamped automatically.
|
and get their user_id stamped automatically.
|
||||||
|
|
||||||
|
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
|
||||||
|
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
|
||||||
|
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
|
||||||
|
to a VARCHAR column ("type 'UUID' is not supported") — the
|
||||||
|
INSERT would silently roll back and the worker would hang.
|
||||||
"""
|
"""
|
||||||
user = get_current_user()
|
user = get_current_user()
|
||||||
return user.id if user is not None else None
|
return str(user.id) if user is not None else None
|
||||||
|
|
||||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||||
"""Write a single event — low-frequency path only.
|
"""Write a single event — low-frequency path only.
|
||||||
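Review note: the `str(user.id)` coercion in `_user_id_from_context` above can be reproduced with the standard-library `sqlite3` module, which aiosqlite wraps. This is a minimal sketch, assuming a throwaway in-memory table; the table here only borrows the `run_events.user_id` column name from the docstring and is not the real schema.

```python
import sqlite3
import uuid

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE run_events (user_id VARCHAR(64))")

user_id = uuid.uuid4()

# Binding the raw UUID object fails: sqlite3 has no adapter for uuid.UUID.
try:
    conn.execute("INSERT INTO run_events (user_id) VALUES (?)", (user_id,))
    raw_bind_failed = False
except sqlite3.Error:
    raw_bind_failed = True

# Coercing to str at the boundary binds cleanly.
conn.execute("INSERT INTO run_events (user_id) VALUES (?)", (str(user_id),))
stored = conn.execute("SELECT user_id FROM run_events").fetchall()
conn.close()
```

The exact exception class differs between Python versions, so the sketch catches the common base class `sqlite3.Error`.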
@@ -81,7 +87,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||||
else:
|
else:
|
||||||
db_content = content
|
db_content = content
|
||||||
owner_id = self._owner_from_context()
|
user_id = self._user_id_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||||
@@ -92,7 +98,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
owner_id=owner_id,
|
user_id=user_id,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -106,7 +112,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
async def put_batch(self, events):
|
async def put_batch(self, events):
|
||||||
if not events:
|
if not events:
|
||||||
return []
|
return []
|
||||||
owner_id = self._owner_from_context()
|
user_id = self._user_id_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||||
@@ -130,7 +136,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=e["thread_id"],
|
thread_id=e["thread_id"],
|
||||||
run_id=e["run_id"],
|
run_id=e["run_id"],
|
||||||
owner_id=e.get("owner_id", owner_id),
|
user_id=e.get("user_id", user_id),
|
||||||
event_type=e["event_type"],
|
event_type=e["event_type"],
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -149,12 +155,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
limit=50,
|
limit=50,
|
||||||
before_seq=None,
|
before_seq=None,
|
||||||
after_seq=None,
|
after_seq=None,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
if before_seq is not None:
|
if before_seq is not None:
|
||||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||||
if after_seq is not None:
|
if after_seq is not None:
|
||||||
@@ -181,12 +187,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
*,
|
*,
|
||||||
event_types=None,
|
event_types=None,
|
||||||
limit=500,
|
limit=500,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
if event_types:
|
if event_types:
|
||||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||||
@@ -199,27 +205,46 @@ class DbRunEventStore(RunEventStore):
|
|||||||
thread_id,
|
thread_id,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
limit=50,
|
||||||
|
before_seq=None,
|
||||||
|
after_seq=None,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
stmt = select(RunEventRow).where(
|
||||||
if resolved_owner_id is not None:
|
RunEventRow.thread_id == thread_id,
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
RunEventRow.run_id == run_id,
|
||||||
stmt = stmt.order_by(RunEventRow.seq.asc())
|
RunEventRow.category == "message",
|
||||||
async with self._sf() as session:
|
)
|
||||||
result = await session.execute(stmt)
|
if resolved_user_id is not None:
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
|
if before_seq is not None:
|
||||||
|
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||||
|
if after_seq is not None:
|
||||||
|
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||||
|
|
||||||
|
if after_seq is not None:
|
||||||
|
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
else:
|
||||||
|
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
rows = list(result.scalars())
|
||||||
|
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||||
|
|
||||||
async def count_messages(
|
async def count_messages(
|
||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
||||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
return await session.scalar(stmt) or 0
|
return await session.scalar(stmt) or 0
|
||||||
|
|
||||||
@@ -227,13 +252,13 @@ class DbRunEventStore(RunEventStore):
|
|||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
@@ -246,13 +271,13 @@ class DbRunEventStore(RunEventStore):
|
|||||||
thread_id,
|
thread_id,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
|
|||||||
@@ -152,9 +152,17 @@ class JsonlRunEventStore(RunEventStore):
|
|||||||
events = [e for e in events if e.get("event_type") in event_types]
|
events = [e for e in events if e.get("event_type") in event_types]
|
||||||
return events[:limit]
|
return events[:limit]
|
||||||
|
|
||||||
async def list_messages_by_run(self, thread_id, run_id):
|
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||||
events = self._read_run_events(thread_id, run_id)
|
events = self._read_run_events(thread_id, run_id)
|
||||||
return [e for e in events if e.get("category") == "message"]
|
filtered = [e for e in events if e.get("category") == "message"]
|
||||||
|
if before_seq is not None:
|
||||||
|
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
||||||
|
if after_seq is not None:
|
||||||
|
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
|
||||||
|
if after_seq is not None:
|
||||||
|
return filtered[:limit]
|
||||||
|
else:
|
||||||
|
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||||
|
|
||||||
async def count_messages(self, thread_id):
|
async def count_messages(self, thread_id):
|
||||||
all_events = self._read_thread_events(thread_id)
|
all_events = self._read_thread_events(thread_id)
|
||||||
|
|||||||
@@ -97,9 +97,17 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
filtered = [e for e in filtered if e["event_type"] in event_types]
|
filtered = [e for e in filtered if e["event_type"] in event_types]
|
||||||
return filtered[:limit]
|
return filtered[:limit]
|
||||||
|
|
||||||
async def list_messages_by_run(self, thread_id, run_id):
|
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||||
all_events = self._events.get(thread_id, [])
|
all_events = self._events.get(thread_id, [])
|
||||||
return [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||||
|
if before_seq is not None:
|
||||||
|
filtered = [e for e in filtered if e["seq"] < before_seq]
|
||||||
|
if after_seq is not None:
|
||||||
|
filtered = [e for e in filtered if e["seq"] > after_seq]
|
||||||
|
if after_seq is not None:
|
||||||
|
return filtered[:limit]
|
||||||
|
else:
|
||||||
|
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||||
|
|
||||||
async def count_messages(self, thread_id):
|
async def count_messages(self, thread_id):
|
||||||
all_events = self._events.get(thread_id, [])
|
all_events = self._events.get(thread_id, [])
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
# Write buffer
|
# Write buffer
|
||||||
self._buffer: list[dict] = []
|
self._buffer: list[dict] = []
|
||||||
|
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||||
|
|
||||||
# Token accumulators
|
# Token accumulators
|
||||||
self._total_input_tokens = 0
|
self._total_input_tokens = 0
|
||||||
@@ -245,6 +246,19 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
# Tools that update graph state return a ``Command`` (e.g.
|
||||||
|
# ``present_files``). LangGraph later unwraps the inner ToolMessage
|
||||||
|
# into checkpoint state, so to stay checkpoint-aligned we must
|
||||||
|
# extract it here rather than storing ``str(Command(...))``.
|
||||||
|
if isinstance(output, Command):
|
||||||
|
update = getattr(output, "update", None) or {}
|
||||||
|
inner_msgs = update.get("messages") if isinstance(update, dict) else None
|
||||||
|
if isinstance(inner_msgs, list):
|
||||||
|
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
|
||||||
|
if inner_tool_msg is not None:
|
||||||
|
output = inner_tool_msg
|
||||||
|
|
||||||
# Extract fields from ToolMessage object when LangChain provides one.
|
# Extract fields from ToolMessage object when LangChain provides one.
|
||||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||||
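Review note: the `Command` unwrapping added to `on_tool_end` can be illustrated without LangGraph installed. The sketch below uses two stand-in dataclasses (`FakeCommand` and `FakeToolMessage`, both hypothetical) in place of `langgraph.types.Command` and `langchain_core.messages.ToolMessage`; only the branching logic is taken from the hunk above.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FakeToolMessage:
    """Stand-in for ToolMessage (hypothetical, illustration only)."""
    content: str
    tool_call_id: str


@dataclass
class FakeCommand:
    """Stand-in for Command; only the `update` attribute is modelled."""
    update: Any = field(default_factory=dict)


def unwrap_tool_output(output):
    """Return the inner tool message of a command, or the output unchanged."""
    if isinstance(output, FakeCommand):
        update = getattr(output, "update", None) or {}
        inner_msgs = update.get("messages") if isinstance(update, dict) else None
        if isinstance(inner_msgs, list):
            inner = next((m for m in inner_msgs if isinstance(m, FakeToolMessage)), None)
            if inner is not None:
                return inner
    return output


cmd = FakeCommand(update={"messages": [FakeToolMessage("files presented", "call-1")]})
unwrapped = unwrap_tool_output(cmd)
passthrough = unwrap_tool_output("plain text result")
no_message = unwrap_tool_output(FakeCommand(update={"artifacts": []}))
```

A command whose update carries no tool message falls through unchanged, which matches the hunk: `output` is only reassigned when an inner `ToolMessage` is found.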
@@ -381,6 +395,10 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
return
|
return
|
||||||
|
# Skip if a flush is already in flight — avoids concurrent writes
|
||||||
|
# to the same SQLite file from multiple fire-and-forget tasks.
|
||||||
|
if self._pending_flush_tasks:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@@ -389,6 +407,7 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
batch = self._buffer.copy()
|
batch = self._buffer.copy()
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
task = loop.create_task(self._flush_async(batch))
|
task = loop.create_task(self._flush_async(batch))
|
||||||
|
self._pending_flush_tasks.add(task)
|
||||||
task.add_done_callback(self._on_flush_done)
|
task.add_done_callback(self._on_flush_done)
|
||||||
|
|
||||||
async def _flush_async(self, batch: list[dict]) -> None:
|
async def _flush_async(self, batch: list[dict]) -> None:
|
||||||
@@ -404,8 +423,8 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# Return failed events to buffer for retry on next flush
|
# Return failed events to buffer for retry on next flush
|
||||||
self._buffer = batch + self._buffer
|
self._buffer = batch + self._buffer
|
||||||
|
|
||||||
@staticmethod
|
def _on_flush_done(self, task: asyncio.Task) -> None:
|
||||||
def _on_flush_done(task: asyncio.Task) -> None:
|
self._pending_flush_tasks.discard(task)
|
||||||
if task.cancelled():
|
if task.cancelled():
|
||||||
return
|
return
|
||||||
exc = task.exception()
|
exc = task.exception()
|
||||||
@@ -450,10 +469,17 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def flush(self) -> None:
|
async def flush(self) -> None:
|
||||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||||
if self._buffer:
|
if self._pending_flush_tasks:
|
||||||
batch = self._buffer.copy()
|
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||||
self._buffer.clear()
|
|
||||||
await self._store.put_batch(batch)
|
while self._buffer:
|
||||||
|
batch = self._buffer[: self._flush_threshold]
|
||||||
|
del self._buffer[: self._flush_threshold]
|
||||||
|
try:
|
||||||
|
await self._store.put_batch(batch)
|
||||||
|
except Exception:
|
||||||
|
self._buffer = batch + self._buffer
|
||||||
|
raise
|
||||||
|
|
||||||
def get_completion_data(self) -> dict:
|
def get_completion_data(self) -> dict:
|
||||||
"""Return accumulated token and message data for run completion."""
|
"""Return accumulated token and message data for run completion."""
|
||||||
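Review note: the reworked `flush()` waits for in-flight tasks, then drains the buffer in chunks and puts a failed chunk back before re-raising. The sketch below reproduces that behaviour with a hypothetical `MiniJournal` and a `FlakyStore` that fails its first write; neither class exists in the repository, and the chunk size of 2 is an arbitrary choice for the demo.

```python
import asyncio


class FlakyStore:
    """Hypothetical store: the first put_batch call fails, later calls succeed."""

    def __init__(self):
        self.calls = 0
        self.saved = []

    async def put_batch(self, batch):
        self.calls += 1
        if self.calls == 1:
            raise RuntimeError("simulated write failure")
        self.saved.extend(batch)


class MiniJournal:
    def __init__(self, store, flush_threshold=2):
        self._store = store
        self._flush_threshold = flush_threshold
        self._buffer = []
        self._pending_flush_tasks = set()

    async def flush(self):
        # Wait for background flushes first so writes never overlap.
        if self._pending_flush_tasks:
            await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
        while self._buffer:
            batch = self._buffer[: self._flush_threshold]
            del self._buffer[: self._flush_threshold]
            try:
                await self._store.put_batch(batch)
            except Exception:
                # Put the failed chunk back at the front to preserve order.
                self._buffer = batch + self._buffer
                raise


async def demo():
    store = FlakyStore()
    journal = MiniJournal(store)
    journal._buffer = [{"seq": n} for n in range(5)]
    try:
        await journal.flush()
    except RuntimeError:
        pass  # the first attempt fails, but no event is dropped
    buffered_after_failure = len(journal._buffer)
    await journal.flush()  # the retry drains the buffer in chunks of 2
    return buffered_after_failure, [e["seq"] for e in store.saved], store.calls


result = asyncio.run(demo())
```

The caller in `run_agent` already wraps `journal.flush()` in a `try`/`except` that logs a warning, so re-raising here does not abort run cleanup.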
|
|||||||
@@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
|
|||||||
- MemoryRunStore: in-memory dict (development, tests)
|
- MemoryRunStore: in-memory dict (development, tests)
|
||||||
- Future: RunRepository backed by SQLAlchemy ORM
|
- Future: RunRepository backed by SQLAlchemy ORM
|
||||||
|
|
||||||
All methods accept an optional owner_id for user isolation.
|
All methods accept an optional user_id for user isolation.
|
||||||
When owner_id is None, no user filtering is applied (single-user mode).
|
When user_id is None, no user filtering is applied (single-user mode).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -22,7 +22,7 @@ class RunStore(abc.ABC):
|
|||||||
*,
|
*,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
owner_id: str | None = None,
|
user_id: str | None = None,
|
||||||
status: str = "pending",
|
status: str = "pending",
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
@@ -42,7 +42,7 @@ class RunStore(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None = None,
|
user_id: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class MemoryRunStore(RunStore):
|
|||||||
*,
|
*,
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
owner_id=None,
|
user_id=None,
|
||||||
status="pending",
|
status="pending",
|
||||||
multitask_strategy="reject",
|
multitask_strategy="reject",
|
||||||
metadata=None,
|
metadata=None,
|
||||||
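Review note: the module now exposes two resolution paths with different failure modes. `get_effective_user_id()` never raises and falls back to `"default"`, while `resolve_user_id(AUTO)` raises when no user context is set. The sketch below is a simplified model of both; `FakeUser` is a hypothetical stand-in for the real user type, and the pass-through for explicit values is inferred from the docstring rather than copied from the full function body.

```python
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeUser:
    """Hypothetical stand-in for the authenticated user object."""
    id: object


_current_user: ContextVar[Optional[FakeUser]] = ContextVar("current_user", default=None)
DEFAULT_USER_ID = "default"


class _AutoSentinel:
    """Marker meaning 'resolve user_id from the contextvar'."""


AUTO = _AutoSentinel()


def get_effective_user_id() -> str:
    # Filesystem paths always need a bucket, so this never raises.
    user = _current_user.get()
    return DEFAULT_USER_ID if user is None else str(user.id)


def resolve_user_id(value):
    if isinstance(value, _AutoSentinel):
        user = _current_user.get()
        if user is None:
            raise RuntimeError("user_id=AUTO but no user context is set")
        return str(user.id)
    # Explicit str is used verbatim; explicit None means "no filter".
    return value


no_auth_bucket = get_effective_user_id()
try:
    resolve_user_id(AUTO)
    auto_without_user = "resolved"
except RuntimeError:
    auto_without_user = "raised"

token = _current_user.set(FakeUser(id=42))
authed_bucket = get_effective_user_id()
auto_with_user = resolve_user_id(AUTO)
explicit = resolve_user_id("alice")
bypass = resolve_user_id(None)
_current_user.reset(token)
```

Keeping the two helpers separate means a missing auth context fails loudly for database queries but still yields a usable directory for per-thread files.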
@@ -35,7 +35,7 @@ class MemoryRunStore(RunStore):
|
|||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"assistant_id": assistant_id,
|
"assistant_id": assistant_id,
|
||||||
"owner_id": owner_id,
|
"user_id": user_id,
|
||||||
"status": status,
|
"status": status,
|
||||||
"multitask_strategy": multitask_strategy,
|
"multitask_strategy": multitask_strategy,
|
||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
@@ -49,8 +49,8 @@ class MemoryRunStore(RunStore):
|
|||||||
async def get(self, run_id):
|
async def get(self, run_id):
|
||||||
return self._runs.get(run_id)
|
return self._runs.get(run_id)
|
||||||
|
|
||||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
|
||||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
|
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
|
||||||
results.sort(key=lambda r: r["created_at"], reverse=True)
|
results.sort(key=lambda r: r["created_at"], reverse=True)
|
||||||
return results[:limit]
|
return results[:limit]
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class RunContext:
|
|||||||
store: Any | None = field(default=None)
|
store: Any | None = field(default=None)
|
||||||
event_store: Any | None = field(default=None)
|
event_store: Any | None = field(default=None)
|
||||||
run_events_config: Any | None = field(default=None)
|
run_events_config: Any | None = field(default=None)
|
||||||
thread_meta_repo: Any | None = field(default=None)
|
thread_store: Any | None = field(default=None)
|
||||||
follow_up_to_run_id: str | None = field(default=None)
|
follow_up_to_run_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,7 +75,7 @@ async def run_agent(
|
|||||||
store = ctx.store
|
store = ctx.store
|
||||||
event_store = ctx.event_store
|
event_store = ctx.event_store
|
||||||
run_events_config = ctx.run_events_config
|
run_events_config = ctx.run_events_config
|
||||||
thread_meta_repo = ctx.thread_meta_repo
|
thread_store = ctx.thread_store
|
||||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||||
|
|
||||||
run_id = record.run_id
|
run_id = record.run_id
|
||||||
@@ -85,34 +85,7 @@ async def run_agent(
|
|||||||
pre_run_snapshot: dict[str, Any] | None = None
|
pre_run_snapshot: dict[str, Any] | None = None
|
||||||
snapshot_capture_failed = False
|
snapshot_capture_failed = False
|
||||||
|
|
||||||
# Initialize RunJournal for event capture
|
|
||||||
journal = None
|
journal = None
|
||||||
if event_store is not None:
|
|
||||||
from deerflow.runtime.journal import RunJournal
|
|
||||||
|
|
||||||
journal = RunJournal(
|
|
||||||
run_id=run_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
event_store=event_store,
|
|
||||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write human_message event (model_dump format, aligned with checkpoint)
|
|
||||||
human_msg = _extract_human_message(graph_input)
|
|
||||||
if human_msg is not None:
|
|
||||||
msg_metadata = {}
|
|
||||||
if follow_up_to_run_id:
|
|
||||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
|
||||||
await event_store.put(
|
|
||||||
thread_id=thread_id,
|
|
||||||
run_id=run_id,
|
|
||||||
event_type="human_message",
|
|
||||||
category="message",
|
|
||||||
content=human_msg.model_dump(),
|
|
||||||
metadata=msg_metadata or None,
|
|
||||||
)
|
|
||||||
content = human_msg.content
|
|
||||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
|
||||||
|
|
||||||
# Track whether "events" was requested but skipped
|
# Track whether "events" was requested but skipped
|
||||||
if "events" in requested_modes:
|
if "events" in requested_modes:
|
||||||
@@ -122,6 +95,38 @@ async def run_agent(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Initialize RunJournal + write human_message event.
|
||||||
|
# These are inside the try block so any exception (e.g. a DB
|
||||||
|
# error writing the event) flows through the except/finally
|
||||||
|
# path that publishes an "end" event to the SSE bridge —
|
||||||
|
# otherwise a failure here would leave the stream hanging
|
||||||
|
# with no terminator.
|
||||||
|
if event_store is not None:
|
||||||
|
from deerflow.runtime.journal import RunJournal
|
||||||
|
|
||||||
|
journal = RunJournal(
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
event_store=event_store,
|
||||||
|
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
human_msg = _extract_human_message(graph_input)
|
||||||
|
if human_msg is not None:
|
||||||
|
msg_metadata = {}
|
||||||
|
if follow_up_to_run_id:
|
||||||
|
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||||
|
await event_store.put(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
event_type="human_message",
|
||||||
|
category="message",
|
||||||
|
content=human_msg.model_dump(),
|
||||||
|
metadata=msg_metadata or None,
|
||||||
|
)
|
||||||
|
content = human_msg.content
|
||||||
|
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||||
|
|
||||||
# 1. Mark running
|
# 1. Mark running
|
||||||
await run_manager.set_status(run_id, RunStatus.running)
|
await run_manager.set_status(run_id, RunStatus.running)
|
||||||
|
|
||||||
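Review note: moving the journal setup inside the `try` block matters because only the `except`/`finally` path publishes the stream terminator. The sketch below shows the control flow with a hypothetical `FakeBridge` and a setup coroutine that always fails; the names and the simplified status handling are assumptions for illustration, not the real `run_agent` body.

```python
import asyncio


class FakeBridge:
    """Hypothetical SSE bridge that records which runs received an end event."""

    def __init__(self):
        self.ended = []

    async def publish_end(self, run_id):
        self.ended.append(run_id)


async def failing_setup():
    raise RuntimeError("simulated DB error while writing human_message")


async def run(bridge, run_id, setup):
    status = "pending"
    try:
        # Setup runs inside the try block, so a failure here still reaches
        # the finally clause below.
        await setup()
        status = "success"
    except Exception:
        status = "error"
    finally:
        # Always terminate the stream, whatever happened above.
        await bridge.publish_end(run_id)
    return status


bridge = FakeBridge()
status = asyncio.run(run(bridge, "run-1", failing_setup))
```

Had `setup()` been awaited before the `try`, the exception would have propagated without `publish_end` being called, leaving the client stream open.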
@@ -305,12 +310,15 @@ async def run_agent(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||||
|
|
||||||
# Persist token usage + convenience fields to RunStore
|
try:
|
||||||
completion = journal.get_completion_data()
|
# Persist token usage + convenience fields to RunStore
|
||||||
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
completion = journal.get_completion_data()
|
||||||
|
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)
|
||||||
|
|
||||||
# Sync title from checkpoint to threads_meta.display_name
|
# Sync title from checkpoint to threads_meta.display_name
|
||||||
if checkpointer is not None:
|
if checkpointer is not None and thread_store is not None:
|
||||||
try:
|
try:
|
||||||
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
||||||
@@ -318,16 +326,17 @@ async def run_agent(
|
|||||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||||
title = ckpt.get("channel_values", {}).get("title")
|
title = ckpt.get("channel_values", {}).get("title")
|
||||||
if title:
|
if title:
|
||||||
await thread_meta_repo.update_display_name(thread_id, title)
|
await thread_store.update_display_name(thread_id, title)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||||
|
|
||||||
# Update threads_meta status based on run outcome
|
# Update threads_meta status based on run outcome
|
||||||
try:
|
if thread_store is not None:
|
||||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
try:
|
||||||
await thread_meta_repo.update_status(thread_id, final_status)
|
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||||
except Exception:
|
await thread_store.update_status(thread_id, final_status)
|
||||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
except Exception:
|
||||||
|
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||||
|
|
||||||
await bridge.publish_end(run_id)
|
await bridge.publish_end(run_id)
|
||||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ async def make_store() -> AsyncIterator[BaseStore]:
|
|||||||
configured checkpointer.
|
configured checkpointer.
|
||||||
|
|
||||||
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
||||||
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer` so
|
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
||||||
that both singletons always use the same persistence technology::
|
that both singletons always use the same persistence technology::
|
||||||
|
|
||||||
async with make_store() as store:
|
async with make_store() as store:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Async stream bridge factory.
|
"""Async stream bridge factory.
|
||||||
|
|
||||||
Provides an **async context manager** aligned with
|
Provides an **async context manager** aligned with
|
||||||
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer`.
|
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer`.
|
||||||
|
|
||||||
Usage (e.g. FastAPI lifespan)::
|
Usage (e.g. FastAPI lifespan)::
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Request-scoped user context for owner-based authorization.
|
"""Request-scoped user context for user-based authorization.
|
||||||
|
|
||||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||||
auth middleware sets after a successful authentication. Repository
|
auth middleware sets after a successful authentication. Repository
|
||||||
methods read the contextvar via a sentinel default parameter, letting
|
methods read the contextvar via a sentinel default parameter, letting
|
||||||
routers stay free of ``owner_id`` boilerplate.
|
routers stay free of ``user_id`` boilerplate.
|
||||||
|
|
||||||
Three-state semantics for the repository ``owner_id`` parameter (the
|
Three-state semantics for the repository ``user_id`` parameter (the
|
||||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||||
|
|
||||||
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||||
@@ -91,16 +91,35 @@ def require_current_user() -> CurrentUser:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Sentinel-based owner_id resolution
|
# Effective user_id helpers (filesystem isolation)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
DEFAULT_USER_ID: Final[str] = "default"
|
||||||
|
|
||||||
|
|
||||||
|
def get_effective_user_id() -> str:
|
||||||
|
"""Return the current user's id as a string, or DEFAULT_USER_ID if unset.
|
||||||
|
|
||||||
|
Unlike :func:`require_current_user` this never raises — it is designed
|
||||||
|
for filesystem-path resolution where a valid user bucket is always needed.
|
||||||
|
"""
|
||||||
|
user = _current_user.get()
|
||||||
|
if user is None:
|
||||||
|
return DEFAULT_USER_ID
|
||||||
|
return str(user.id)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sentinel-based user_id resolution
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
# Repository methods accept an ``owner_id`` keyword-only argument that
|
# Repository methods accept a ``user_id`` keyword-only argument that
|
||||||
# defaults to ``AUTO``. The three possible values drive distinct
|
# defaults to ``AUTO``. The three possible values drive distinct
|
||||||
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
# behaviours; see the docstring on :func:`resolve_user_id`.
|
||||||
|
|
||||||
|
|
||||||
class _AutoSentinel:
|
class _AutoSentinel:
|
||||||
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
"""Singleton marker meaning 'resolve user_id from contextvar'."""
|
||||||
|
|
||||||
_instance: _AutoSentinel | None = None
|
_instance: _AutoSentinel | None = None
|
||||||
|
|
||||||
@@ -116,12 +135,12 @@ class _AutoSentinel:
|
|||||||
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||||
|
|
||||||
|
|
||||||
def resolve_owner_id(
|
def resolve_user_id(
|
||||||
value: str | None | _AutoSentinel,
|
value: str | None | _AutoSentinel,
|
||||||
*,
|
*,
|
||||||
method_name: str = "repository method",
|
method_name: str = "repository method",
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Resolve the owner_id parameter passed to a repository method.
|
"""Resolve the user_id parameter passed to a repository method.
|
||||||
|
|
||||||
Three-state semantics:
|
Three-state semantics:
|
||||||
|
|
||||||
@@ -131,16 +150,16 @@ def resolve_owner_id(
|
|||||||
- Explicit ``str``: use the provided id verbatim, overriding any
|
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||||
contextvar value. Useful for tests and admin-override flows.
|
contextvar value. Useful for tests and admin-override flows.
|
||||||
- Explicit ``None``: no filter — the repository should skip the
|
- Explicit ``None``: no filter — the repository should skip the
|
||||||
owner_id WHERE clause entirely. Reserved for migration scripts
|
user_id WHERE clause entirely. Reserved for migration scripts
|
||||||
and CLI tools that intentionally bypass isolation.
|
and CLI tools that intentionally bypass isolation.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, _AutoSentinel):
|
if isinstance(value, _AutoSentinel):
|
||||||
user = _current_user.get()
|
user = _current_user.get()
|
||||||
if user is None:
|
if user is None:
|
||||||
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
|
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
|
||||||
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
||||||
# ``UUID`` for the API surface, but the persistence layer
|
# ``UUID`` for the API surface, but the persistence layer
|
||||||
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
|
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
|
||||||
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
||||||
# not supported"). Honour the documented return type here
|
# not supported"). Honour the documented return type here
|
||||||
# rather than ripple a type change through every caller.
|
# rather than ripple a type change through every caller.
|
||||||
|
|||||||
@@ -200,8 +200,9 @@ def _get_acp_workspace_host_path(thread_id: str | None = None) -> str | None:
|
|||||||
if thread_id is not None:
|
if thread_id is not None:
|
||||||
try:
|
try:
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
host_path = get_paths().acp_workspace_dir(thread_id)
|
host_path = get_paths().acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
||||||
if host_path.exists():
|
if host_path.exists():
|
||||||
return str(host_path)
|
return str(host_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -33,11 +33,12 @@ def _get_work_dir(thread_id: str | None) -> str:
|
|||||||
An absolute physical filesystem path to use as the working directory.
|
An absolute physical filesystem path to use as the working directory.
|
||||||
"""
|
"""
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
if thread_id:
|
if thread_id:
|
||||||
try:
|
try:
|
||||||
work_dir = paths.acp_workspace_dir(thread_id)
|
work_dir = paths.acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id)
|
logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id)
|
||||||
work_dir = paths.base_dir / "acp-workspace"
|
work_dir = paths.base_dir / "acp-workspace"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from langgraph.typing import ContextT
|
|||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||||
|
|
||||||
@@ -47,7 +48,7 @@ def _normalize_presented_filepath(
|
|||||||
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
||||||
|
|
||||||
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
|
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
|
||||||
actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
|
actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
|
||||||
else:
|
else:
|
||||||
actual_path = Path(filepath).expanduser().resolve()
|
actual_path = Path(filepath).expanduser().resolve()
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
|||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
|
|
||||||
class PathTraversalError(ValueError):
|
class PathTraversalError(ValueError):
|
||||||
@@ -33,7 +34,7 @@ def validate_thread_id(thread_id: str) -> None:
|
|||||||
def get_uploads_dir(thread_id: str) -> Path:
|
def get_uploads_dir(thread_id: str) -> Path:
|
||||||
"""Return the uploads directory path for a thread (no side effects)."""
|
"""Return the uploads directory path for a thread (no side effects)."""
|
||||||
validate_thread_id(thread_id)
|
validate_thread_id(thread_id)
|
||||||
return get_paths().sandbox_uploads_dir(thread_id)
|
return get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
|
|
||||||
|
|
||||||
def ensure_uploads_dir(thread_id: str) -> Path:
|
def ensure_uploads_dir(thread_id: str) -> Path:
|
||||||
|
|||||||
@@ -23,9 +23,7 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
postgres = [
|
postgres = ["deerflow-harness[postgres]"]
|
||||||
"deerflow-harness[postgres]",
|
|
||||||
]
|
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
||||||
|
|||||||
@@ -0,0 +1,160 @@
|
|||||||
|
"""One-time migration: move legacy thread dirs and memory into per-user layout.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
|
||||||
|
|
||||||
|
The script is idempotent — re-running it after a successful migration is a no-op.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from deerflow.config.paths import Paths, get_paths
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_thread_dirs(
|
||||||
|
paths: Paths,
|
||||||
|
thread_owner_map: dict[str, str],
|
||||||
|
*,
|
||||||
|
dry_run: bool = False,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Move legacy thread directories into per-user layout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: Paths instance.
|
||||||
|
thread_owner_map: Mapping of thread_id -> user_id from threads_meta table.
|
||||||
|
dry_run: If True, only log what would happen.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of migration report entries.
|
||||||
|
"""
|
||||||
|
report: list[dict] = []
|
||||||
|
legacy_threads = paths.base_dir / "threads"
|
||||||
|
if not legacy_threads.exists():
|
||||||
|
logger.info("No legacy threads directory found — nothing to migrate.")
|
||||||
|
return report
|
||||||
|
|
||||||
|
for thread_dir in sorted(legacy_threads.iterdir()):
|
||||||
|
if not thread_dir.is_dir():
|
||||||
|
continue
|
||||||
|
thread_id = thread_dir.name
|
||||||
|
user_id = thread_owner_map.get(thread_id, "default")
|
||||||
|
dest = paths.base_dir / "users" / user_id / "threads" / thread_id
|
||||||
|
|
||||||
|
entry = {"thread_id": thread_id, "user_id": user_id, "action": ""}
|
||||||
|
|
||||||
|
if dest.exists():
|
||||||
|
conflicts_dir = paths.base_dir / "migration-conflicts" / thread_id
|
||||||
|
entry["action"] = f"conflict -> {conflicts_dir}"
|
||||||
|
if not dry_run:
|
||||||
|
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(thread_dir), str(conflicts_dir))
|
||||||
|
logger.warning("Conflict for thread %s: moved to %s", thread_id, conflicts_dir)
|
||||||
|
else:
|
||||||
|
entry["action"] = f"moved -> {dest}"
|
||||||
|
if not dry_run:
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(thread_dir), str(dest))
|
||||||
|
logger.info("Migrated thread %s -> user %s", thread_id, user_id)
|
||||||
|
|
||||||
|
report.append(entry)
|
||||||
|
|
||||||
|
# Clean up empty legacy threads dir
|
||||||
|
if not dry_run and legacy_threads.exists() and not any(legacy_threads.iterdir()):
|
||||||
|
legacy_threads.rmdir()
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_memory(
|
||||||
|
paths: Paths,
|
||||||
|
user_id: str = "default",
|
||||||
|
*,
|
||||||
|
dry_run: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Move legacy global memory.json into per-user layout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: Paths instance.
|
||||||
|
user_id: Target user to receive the legacy memory.
|
||||||
|
dry_run: If True, only log.
|
||||||
|
"""
|
||||||
|
legacy_mem = paths.base_dir / "memory.json"
|
||||||
|
if not legacy_mem.exists():
|
||||||
|
logger.info("No legacy memory.json found — nothing to migrate.")
|
||||||
|
return
|
||||||
|
|
||||||
|
dest = paths.user_memory_file(user_id)
|
||||||
|
if dest.exists():
|
||||||
|
legacy_backup = paths.base_dir / "memory.legacy.json"
|
||||||
|
logger.warning("Destination %s exists; renaming legacy to %s", dest, legacy_backup)
|
||||||
|
if not dry_run:
|
||||||
|
legacy_mem.rename(legacy_backup)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Migrating memory.json -> %s", dest)
|
||||||
|
if not dry_run:
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(legacy_mem), str(dest))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
|
||||||
|
"""Query threads_meta table for thread_id -> user_id mapping.
|
||||||
|
|
||||||
|
Uses raw sqlite3 to avoid async dependencies.
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
db_path = paths.base_dir / "deer-flow.db"
|
||||||
|
if not db_path.exists():
|
||||||
|
logger.info("No database found at %s — using empty owner map.", db_path)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
conn = sqlite3.connect(str(db_path))
|
||||||
|
try:
|
||||||
|
cursor = conn.execute("SELECT thread_id, user_id FROM threads_meta WHERE user_id IS NOT NULL")
|
||||||
|
return {row[0]: row[1] for row in cursor.fetchall()}
|
||||||
|
except sqlite3.OperationalError as e:
|
||||||
|
logger.warning("Failed to query threads_meta: %s", e)
|
||||||
|
return {}
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
    parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    paths = get_paths()
    logger.info("Base directory: %s", paths.base_dir)
    logger.info("Dry run: %s", args.dry_run)

    owner_map = _build_owner_map_from_db(paths)
    logger.info("Found %d thread ownership records in DB", len(owner_map))

    report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
    migrate_memory(paths, user_id="default", dry_run=args.dry_run)

    if report:
        logger.info("Migration report:")
        for entry in report:
            logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
    else:
        logger.info("No threads to migrate.")

    unowned = [e for e in report if e["user_id"] == "default"]
    if unowned:
        logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
        for e in unowned:
            logger.warning(" %s", e["thread_id"])


if __name__ == "__main__":
    main()
|
||||||
@@ -3,16 +3,16 @@
 The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
 ahead of every router, plus ``@require_permission(owner_check=True)``
 decorators that read ``request.state.auth`` and call
-``thread_meta_repo.check_access``. Router-level unit tests construct
+``thread_store.check_access``. Router-level unit tests construct
 **bare** FastAPI apps that include only one router — they have neither
-the auth middleware nor a real thread_meta_repo, so the decorators raise
+the auth middleware nor a real thread_store, so the decorators raise
 401 (TestClient path) or ValueError (direct-call path).
 
 This module provides two surfaces:
 
 1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
    ``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
-   request, plus a permissive ``thread_meta_repo`` mock on
+   request, plus a permissive ``thread_store`` mock on
    ``app.state``. Use from TestClient-based router tests.
 
 2. :func:`call_unwrapped` — invokes the underlying function bypassing
@@ -86,20 +86,20 @@ def make_authed_test_app(
     user_factory: Callable[[], User] | None = None,
     owner_check_passes: bool = True,
 ) -> FastAPI:
-    """Build a FastAPI test app with stub auth + permissive thread_meta_repo.
+    """Build a FastAPI test app with stub auth + permissive thread_store.
 
     Args:
         user_factory: Override the default test user. Must return a fully
             populated :class:`User`. Useful for cross-user isolation tests
             that need a stable id across requests.
-        owner_check_passes: When True (default), ``thread_meta_repo.check_access``
+        owner_check_passes: When True (default), ``thread_store.check_access``
             returns True for every call so ``@require_permission(owner_check=True)``
             never blocks the route under test. Pass False to verify that
             permission failures surface correctly.
 
     Returns:
         A ``FastAPI`` app with the stub middleware installed and
-        ``app.state.thread_meta_repo`` set to a permissive mock. The
+        ``app.state.thread_store`` set to a permissive mock. The
         caller is still responsible for ``app.include_router(...)``.
     """
     factory = user_factory or _make_stub_user
@@ -108,7 +108,7 @@ def make_authed_test_app(
 
     repo = MagicMock()
     repo.check_access = AsyncMock(return_value=owner_check_passes)
-    app.state.thread_meta_repo = repo
+    app.state.thread_store = repo
 
     return app
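The permissive-store pattern in this hunk can be illustrated with the standard library alone. This is a sketch under assumptions, not the helper itself: `make_stub_state` is a hypothetical name, a `SimpleNamespace` stands in for `app.state`, and the `(thread_id, user_id)` arguments passed to `check_access` are guessed for illustration, since its real signature does not appear in this diff.

```python
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock


def make_stub_state(owner_check_passes: bool = True) -> SimpleNamespace:
    """Hypothetical stand-in for ``app.state`` carrying a mock thread store."""
    store = MagicMock()
    # AsyncMock makes ``await store.check_access(...)`` return the fixed value.
    store.check_access = AsyncMock(return_value=owner_check_passes)
    return SimpleNamespace(thread_store=store)


permissive = make_stub_state()
strict = make_stub_state(owner_check_passes=False)

# Argument shape is an assumption; the mock accepts any call signature.
allowed = asyncio.run(permissive.thread_store.check_access("thread-1", "user-1"))
denied = asyncio.run(strict.thread_store.check_access("thread-1", "user-1"))
```

Because the mock accepts any arguments, a test built this way checks only how the route reacts to an allow or deny decision, not the ownership logic itself.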
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def provisioner_module():
 # Auto-set user context for every test unless marked no_auto_user
 # ---------------------------------------------------------------------------
 #
-# Repository methods read ``owner_id`` from a contextvar by default
+# Repository methods read ``user_id`` from a contextvar by default
 # (see ``deerflow.runtime.user_context``). Without this fixture, every
 # pre-existing persistence test would raise RuntimeError because the
 # contextvar is unset. The fixture sets a default test user on every
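The contextvar behaviour this comment describes can be sketched in isolation. The names below (`_current_user_id`, `require_user_id`, `effective_user_id`) are hypothetical stand-ins, not the actual contents of `deerflow.runtime.user_context`. The only behaviour taken from the document is that an unset context raises `RuntimeError` for strict readers and that the effective id falls back to `"default"`.

```python
from contextvars import ContextVar

DEFAULT_USER_ID = "default"  # fallback the docs name for no-auth mode

# Hypothetical stand-in for the project's user contextvar.
_current_user_id = ContextVar("current_user_id", default=None)


def require_user_id() -> str:
    """Strict accessor: raise when no user has been set."""
    user_id = _current_user_id.get()
    if user_id is None:
        raise RuntimeError("user context is not set")
    return user_id


def effective_user_id() -> str:
    """Lenient accessor: fall back to the default user."""
    return _current_user_id.get() or DEFAULT_USER_ID


before = effective_user_id()
token = _current_user_id.set("test-user")  # what an autouse fixture might do
try:
    during = require_user_id()
finally:
    _current_user_id.reset(token)  # restore the previous (unset) state
after = effective_user_id()
```

Resetting with the token restores the unset state, so one test's user cannot leak into the next.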
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ def test_get_thread_mounts_includes_acp_workspace(tmp_path, monkeypatch):
     """_get_thread_mounts must include /mnt/acp-workspace (read-only) for docker sandbox."""
     aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
     monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
+    monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
 
     mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-3")
 
@@ -95,6 +96,7 @@ def test_get_thread_mounts_preserves_windows_host_path_style(tmp_path, monkeypat
     aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
     monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow")
     monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
+    monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
 
     mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10")
 
|
|||||||
@@ -231,7 +231,7 @@ class TestResolveAttachments:
         mock_paths = MagicMock()
         mock_paths.sandbox_outputs_dir.return_value = outputs_dir
 
-        def resolve_side_effect(tid, vpath):
+        def resolve_side_effect(tid, vpath, *, user_id=None):
             if "data.csv" in vpath:
                 return good_file
             return tmp_path / "missing.txt"
|
|||||||
@@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 
 import deerflow.config.app_config as app_config_module
-from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
 from deerflow.config.checkpointer_config import (
     CheckpointerConfig,
     get_checkpointer_config,
     load_checkpointer_config_from_dict,
     set_checkpointer_config,
 )
+from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
 
 
 @pytest.fixture(autouse=True)
@@ -78,7 +78,7 @@ class TestGetCheckpointer:
|
|||||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||||
cp = get_checkpointer()
|
cp = get_checkpointer()
|
||||||
assert cp is not None
|
assert cp is not None
|
||||||
assert isinstance(cp, InMemorySaver)
|
assert isinstance(cp, InMemorySaver)
|
||||||
@@ -178,7 +178,7 @@ class TestAsyncCheckpointer:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||||
"""Async SQLite setup should move mkdir off the event loop."""
|
"""Async SQLite setup should move mkdir off the event loop."""
|
||||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||||
|
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
||||||
@@ -195,11 +195,11 @@ class TestAsyncCheckpointer:
|
|||||||
mock_module.AsyncSqliteSaver = mock_saver_cls
|
mock_module.AsyncSqliteSaver = mock_saver_cls
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
||||||
patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||||
patch(
|
patch(
|
||||||
"deerflow.agents.checkpointer.async_provider.resolve_sqlite_conn_str",
|
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||||
return_value="/tmp/resolved/test.db",
|
return_value="/tmp/resolved/test.db",
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ class TestCheckpointerNoneFix:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
||||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||||
|
|
||||||
# Mock get_app_config to return a config with checkpointer=None and database=None
|
# Mock get_app_config to return a config with checkpointer=None and database=None
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.checkpointer = None
|
mock_config.checkpointer = None
|
||||||
mock_config.database = None
|
mock_config.database = None
|
||||||
|
|
||||||
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||||
async with make_checkpointer() as checkpointer:
|
async with make_checkpointer() as checkpointer:
|
||||||
# Should return InMemorySaver, not None
|
# Should return InMemorySaver, not None
|
||||||
assert checkpointer is not None
|
assert checkpointer is not None
|
||||||
@@ -36,13 +36,13 @@ class TestCheckpointerNoneFix:
|
|||||||
|
|
||||||
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
||||||
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
||||||
from deerflow.agents.checkpointer.provider import checkpointer_context
|
from deerflow.runtime.checkpointer.provider import checkpointer_context
|
||||||
|
|
||||||
# Mock get_app_config to return a config with checkpointer=None
|
# Mock get_app_config to return a config with checkpointer=None
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.checkpointer = None
|
mock_config.checkpointer = None
|
||||||
|
|
||||||
with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config):
|
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||||
with checkpointer_context() as checkpointer:
|
with checkpointer_context() as checkpointer:
|
||||||
# Should return InMemorySaver, not None
|
# Should return InMemorySaver, not None
|
||||||
assert checkpointer is not None
|
assert checkpointer is not None
|
||||||
|
|||||||
@@ -817,7 +817,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._agent_name = "custom-agent"
|
client._agent_name = "custom-agent"
|
||||||
client._available_skills = {"test_skill"}
|
client._available_skills = {"test_skill"}
|
||||||
@@ -842,7 +842,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
@@ -867,7 +867,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
@@ -886,7 +886,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=None),
|
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
@@ -1015,7 +1015,7 @@ class TestThreadQueries:
|
|||||||
mock_checkpointer = MagicMock()
|
mock_checkpointer = MagicMock()
|
||||||
mock_checkpointer.list.return_value = []
|
mock_checkpointer.list.return_value = []
|
||||||
|
|
||||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||||
# No internal checkpointer, should fetch from provider
|
# No internal checkpointer, should fetch from provider
|
||||||
result = client.list_threads()
|
result = client.list_threads()
|
||||||
|
|
||||||
@@ -1069,7 +1069,7 @@ class TestThreadQueries:
|
|||||||
mock_checkpointer = MagicMock()
|
mock_checkpointer = MagicMock()
|
||||||
mock_checkpointer.list.return_value = []
|
mock_checkpointer.list.return_value = []
|
||||||
|
|
||||||
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||||
result = client.get_thread("t99")
|
result = client.get_thread("t99")
|
||||||
|
|
||||||
assert result["thread_id"] == "t99"
|
assert result["thread_id"] == "t99"
|
||||||
@@ -1241,7 +1241,10 @@ class TestMemoryManagement:
         with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import:
             result = client.import_memory(imported)
 
-        mock_import.assert_called_once_with(imported)
+        assert mock_import.call_count == 1
+        call_args = mock_import.call_args
+        assert call_args.args == (imported,)
+        assert "user_id" in call_args.kwargs
         assert result == imported
 
     def test_reload_memory(self, client):
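The hunk above replaces `assert_called_once_with` with checks on `call_args`, presumably because the call now carries a `user_id` keyword whose value the test does not want to pin. A small standalone sketch of the difference, using only `unittest.mock`; the payload and the `"default"` value are illustrative assumptions:

```python
from unittest.mock import MagicMock

payload = {"facts": []}
mock_import = MagicMock(return_value=payload)

# Simulate a caller that forwards an extra keyword argument.
mock_import(payload, user_id="default")

call = mock_import.call_args
called_once = mock_import.call_count == 1
positional_ok = call.args == (payload,)
has_user_id = "user_id" in call.kwargs

# The exact-match assertion rejects the call once the keyword is present.
try:
    mock_import.assert_called_once_with(payload)
    strict_passes = True
except AssertionError:
    strict_passes = False
```

Inspecting `call_args.args` and `call_args.kwargs` separately keeps the positional contract strict while only requiring that the keyword is present.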
|
||||||
@@ -1487,9 +1490,12 @@ class TestUploads:
 
 class TestArtifacts:
     def test_get_artifact(self, client):
+        from deerflow.runtime.user_context import get_effective_user_id
+
         with tempfile.TemporaryDirectory() as tmp:
             paths = Paths(base_dir=tmp)
-            outputs = paths.sandbox_outputs_dir("t1")
+            user_id = get_effective_user_id()
+            outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
             outputs.mkdir(parents=True)
             (outputs / "result.txt").write_text("artifact content")
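Passing `user_id` matters because the outputs directory now sits under a per-user prefix. As a rough illustration only, the hypothetical helper below follows the layout given in the project docs (`users/{user_id}/threads/{thread_id}/user-data/outputs`). It is not the real `Paths.sandbox_outputs_dir`, and the keyword default of `"default"` is an assumption.

```python
from pathlib import Path


def outputs_dir(base_dir: Path, thread_id: str, *, user_id: str = "default") -> Path:
    """Hypothetical helper mirroring the documented per-user directory layout."""
    return base_dir / "users" / user_id / "threads" / thread_id / "user-data" / "outputs"


base = Path("/tmp/deer-flow-example")  # illustrative base directory
alice_outputs = outputs_dir(base, "t1", user_id="alice")
default_outputs = outputs_dir(base, "t1")
```

Two users with the same thread id resolve to different directories, so a test has to create its fixture files under the same `user_id` the client will resolve at read time.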
|
||||||
|
|
||||||
@@ -1500,9 +1506,12 @@ class TestArtifacts:
|
|||||||
assert "text" in mime
|
assert "text" in mime
|
||||||
|
|
||||||
def test_get_artifact_not_found(self, client):
|
def test_get_artifact_not_found(self, client):
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
user_id = get_effective_user_id()
|
||||||
|
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
@@ -1513,9 +1522,12 @@ class TestArtifacts:
|
|||||||
client.get_artifact("t1", "bad/path/file.txt")
|
client.get_artifact("t1", "bad/path/file.txt")
|
||||||
|
|
||||||
def test_get_artifact_path_traversal(self, client):
|
def test_get_artifact_path_traversal(self, client):
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
user_id = get_effective_user_id()
|
||||||
|
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
with pytest.raises(PathTraversalError):
|
with pytest.raises(PathTraversalError):
|
||||||
@@ -1699,13 +1711,16 @@ class TestScenarioFileLifecycle:
|
|||||||
|
|
||||||
def test_upload_then_read_artifact(self, client):
|
def test_upload_then_read_artifact(self, client):
|
||||||
"""Upload a file, simulate agent producing artifact, read it back."""
|
"""Upload a file, simulate agent producing artifact, read it back."""
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
uploads_dir = tmp_path / "uploads"
|
uploads_dir = tmp_path / "uploads"
|
||||||
uploads_dir.mkdir()
|
uploads_dir.mkdir()
|
||||||
|
|
||||||
paths = Paths(base_dir=tmp_path)
|
paths = Paths(base_dir=tmp_path)
|
||||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact")
|
user_id = get_effective_user_id()
|
||||||
|
outputs_dir = paths.sandbox_outputs_dir("t-artifact", user_id=user_id)
|
||||||
outputs_dir.mkdir(parents=True)
|
outputs_dir.mkdir(parents=True)
|
||||||
|
|
||||||
# Upload phase
|
# Upload phase
|
||||||
@@ -1844,7 +1859,7 @@ class TestScenarioAgentRecreation:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config_a)
|
client._ensure_agent(config_a)
|
||||||
first_agent = client._agent
|
first_agent = client._agent
|
||||||
@@ -1872,7 +1887,7 @@ class TestScenarioAgentRecreation:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
@@ -1897,7 +1912,7 @@ class TestScenarioAgentRecreation:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
client.reset_agent()
|
client.reset_agent()
|
||||||
@@ -1955,11 +1970,14 @@ class TestScenarioThreadIsolation:
|
|||||||
|
|
||||||
def test_artifacts_isolated_per_thread(self, client):
|
def test_artifacts_isolated_per_thread(self, client):
|
||||||
"""Artifacts in thread-A are not accessible from thread-B."""
|
"""Artifacts in thread-A are not accessible from thread-B."""
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
outputs_a = paths.sandbox_outputs_dir("thread-a")
|
user_id = get_effective_user_id()
|
||||||
|
outputs_a = paths.sandbox_outputs_dir("thread-a", user_id=user_id)
|
||||||
outputs_a.mkdir(parents=True)
|
outputs_a.mkdir(parents=True)
|
||||||
paths.sandbox_user_data_dir("thread-b").mkdir(parents=True)
|
paths.sandbox_outputs_dir("thread-b", user_id=user_id).mkdir(parents=True)
|
||||||
(outputs_a / "result.txt").write_text("thread-a artifact")
|
(outputs_a / "result.txt").write_text("thread-a artifact")
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
@@ -2864,9 +2882,12 @@ class TestUploadDeleteSymlink:
|
|||||||
class TestArtifactHardening:
|
class TestArtifactHardening:
|
||||||
def test_artifact_directory_rejected(self, client):
|
def test_artifact_directory_rejected(self, client):
|
||||||
"""get_artifact rejects paths that resolve to a directory."""
|
"""get_artifact rejects paths that resolve to a directory."""
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
subdir = paths.sandbox_outputs_dir("t1") / "subdir"
|
user_id = get_effective_user_id()
|
||||||
|
subdir = paths.sandbox_outputs_dir("t1", user_id=user_id) / "subdir"
|
||||||
subdir.mkdir(parents=True)
|
subdir.mkdir(parents=True)
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
@@ -2875,9 +2896,12 @@ class TestArtifactHardening:
|
|||||||
|
|
||||||
def test_artifact_leading_slash_stripped(self, client):
|
def test_artifact_leading_slash_stripped(self, client):
|
||||||
"""Paths with leading slash are handled correctly."""
|
"""Paths with leading slash are handled correctly."""
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
outputs = paths.sandbox_outputs_dir("t1")
|
user_id = get_effective_user_id()
|
||||||
|
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
||||||
outputs.mkdir(parents=True)
|
outputs.mkdir(parents=True)
|
||||||
(outputs / "file.txt").write_text("content")
|
(outputs / "file.txt").write_text("content")
|
||||||
|
|
||||||
@@ -2991,9 +3015,12 @@ class TestBugArtifactPrefixMatchTooLoose:
|
|||||||
|
|
||||||
def test_exact_prefix_without_subpath_accepted(self, client):
|
def test_exact_prefix_without_subpath_accepted(self, client):
|
||||||
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
user_id = get_effective_user_id()
|
||||||
|
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
# Accepted at prefix check, but fails because it's a directory.
|
# Accepted at prefix check, but fails because it's a directory.
|
||||||
|
|||||||
@@ -262,8 +262,9 @@ class TestFileUploadIntegration:
|
|||||||
|
|
||||||
# Physically exists
|
# Physically exists
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
assert (get_paths().sandbox_uploads_dir(tid) / "readme.txt").exists()
|
assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists()
|
||||||
|
|
||||||
def test_upload_duplicate_rename(self, e2e_env, tmp_path):
|
def test_upload_duplicate_rename(self, e2e_env, tmp_path):
|
||||||
"""Uploading two files with the same name auto-renames the second."""
|
"""Uploading two files with the same name auto-renames the second."""
|
||||||
@@ -472,12 +473,13 @@ class TestArtifactAccess:
|
|||||||
def test_get_artifact_happy_path(self, e2e_env):
|
def test_get_artifact_happy_path(self, e2e_env):
|
||||||
"""Write a file to outputs, then read it back via get_artifact()."""
|
"""Write a file to outputs, then read it back via get_artifact()."""
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||||
tid = str(uuid.uuid4())
|
tid = str(uuid.uuid4())
|
||||||
|
|
||||||
# Create an output file in the thread's outputs directory
|
# Create an output file in the thread's outputs directory
|
||||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||||
outputs_dir.mkdir(parents=True, exist_ok=True)
|
outputs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
(outputs_dir / "result.txt").write_text("hello artifact")
|
(outputs_dir / "result.txt").write_text("hello artifact")
|
||||||
|
|
||||||
@@ -488,11 +490,12 @@ class TestArtifactAccess:
|
|||||||
def test_get_artifact_nested_path(self, e2e_env):
|
def test_get_artifact_nested_path(self, e2e_env):
|
||||||
"""Artifacts in subdirectories are accessible."""
|
"""Artifacts in subdirectories are accessible."""
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||||
tid = str(uuid.uuid4())
|
tid = str(uuid.uuid4())
|
||||||
|
|
||||||
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
||||||
sub = outputs_dir / "charts"
|
sub = outputs_dir / "charts"
|
||||||
sub.mkdir(parents=True, exist_ok=True)
|
sub.mkdir(parents=True, exist_ok=True)
|
||||||
(sub / "data.json").write_text('{"x": 1}')
|
(sub / "data.json").write_text('{"x": 1}')
|
||||||
|
|||||||
+116
-134
@@ -1,21 +1,19 @@
|
|||||||
"""Tests for _ensure_admin_user() in app.py.
|
"""Tests for _ensure_admin_user() in app.py.
|
||||||
|
|
||||||
Covers: first-boot admin creation, auto-reset on needs_setup=True,
|
Covers: first-boot no-op (admin creation removed), orphan migration
|
||||||
no-op on needs_setup=False, migration, and edge cases.
|
when admin exists, no-op on no admin found, and edge cases.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from datetime import UTC, datetime, timedelta
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||||
|
|
||||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
from app.gateway.auth.models import User
|
|
||||||
|
|
||||||
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||||
|
|
||||||
@@ -35,53 +33,90 @@ def _make_app_stub(store=None):
|
|||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(user_count=0, admin_user=None):
|
def _make_provider(admin_count=0):
|
||||||
p = AsyncMock()
|
p = AsyncMock()
|
||||||
p.count_users = AsyncMock(return_value=user_count)
|
p.count_users = AsyncMock(return_value=admin_count)
|
||||||
p.create_user = AsyncMock(
|
p.count_admin_users = AsyncMock(return_value=admin_count)
|
||||||
side_effect=lambda **kw: User(
|
p.create_user = AsyncMock()
|
||||||
email=kw["email"],
|
|
||||||
password_hash="hashed",
|
|
||||||
system_role=kw.get("system_role", "user"),
|
|
||||||
needs_setup=kw.get("needs_setup", False),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
p.get_user_by_email = AsyncMock(return_value=admin_user)
|
|
||||||
p.update_user = AsyncMock(side_effect=lambda u: u)
|
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
# ── First boot: no users ─────────────────────────────────────────────────
|
def _make_session_factory(admin_row=None):
|
||||||
|
"""Build a mock async session factory that returns a row from execute()."""
|
||||||
|
row_result = MagicMock()
|
||||||
|
row_result.scalar_one_or_none.return_value = admin_row
|
||||||
|
|
||||||
|
execute_result = MagicMock()
|
||||||
|
execute_result.scalar_one_or_none.return_value = admin_row
|
||||||
|
|
||||||
|
session = AsyncMock()
|
||||||
|
session.execute = AsyncMock(return_value=execute_result)
|
||||||
|
|
||||||
|
# Async context manager
|
||||||
|
session_cm = AsyncMock()
|
||||||
|
session_cm.__aenter__ = AsyncMock(return_value=session)
|
||||||
|
session_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
sf = MagicMock()
|
||||||
|
sf.return_value = session_cm
|
||||||
|
return sf
|
||||||
|
|
||||||
|
|
||||||
def test_first_boot_creates_admin():
|
# ── First boot: no admin → generate init_token, return early ─────────────
|
||||||
"""count_users==0 → create admin with needs_setup=True."""
|
|
||||||
provider = _make_provider(user_count=0)
|
|
||||||
|
def test_first_boot_does_not_create_admin():
|
||||||
|
"""admin_count==0 → generate init_token, do NOT create admin automatically."""
|
||||||
|
provider = _make_provider(admin_count=0)
|
||||||
app = _make_app_stub()
|
app = _make_app_stub()
|
||||||
|
app.state.init_token = None # lifespan sets this
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
from app.gateway.app import _ensure_admin_user
|
||||||
from app.gateway.app import _ensure_admin_user
|
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
provider.create_user.assert_called_once()
|
provider.create_user.assert_not_called()
|
||||||
call_kwargs = provider.create_user.call_args[1]
|
# init_token must have been set on app.state
|
||||||
assert call_kwargs["email"] == "admin@deerflow.dev"
|
assert app.state.init_token is not None
|
||||||
assert call_kwargs["system_role"] == "admin"
|
assert len(app.state.init_token) > 10
|
||||||
assert call_kwargs["needs_setup"] is True
|
|
||||||
assert len(call_kwargs["password"]) > 10 # random password generated
|
|
||||||
|
|
||||||
|
|
||||||
def test_first_boot_triggers_migration_if_store_present():
|
def test_first_boot_skips_migration():
|
||||||
"""First boot with store → _migrate_orphaned_threads called."""
|
"""No admin → return early before any migration attempt."""
|
||||||
provider = _make_provider(user_count=0)
|
provider = _make_provider(admin_count=0)
|
||||||
|
store = AsyncMock()
|
||||||
|
store.asearch = AsyncMock(return_value=[])
|
||||||
|
app = _make_app_stub(store=store)
|
||||||
|
app.state.init_token = None # lifespan sets this
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
store.asearch.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Admin exists: migration runs when admin row found ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_exists_triggers_migration():
|
||||||
|
"""Admin exists and admin row found → _migrate_orphaned_threads called."""
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
admin_row = MagicMock()
|
||||||
|
admin_row.id = uuid4()
|
||||||
|
|
||||||
|
provider = _make_provider(admin_count=1)
|
||||||
|
sf = _make_session_factory(admin_row=admin_row)
|
||||||
store = AsyncMock()
|
store = AsyncMock()
|
||||||
store.asearch = AsyncMock(return_value=[])
|
store.asearch = AsyncMock(return_value=[])
|
||||||
app = _make_app_stub(store=store)
|
app = _make_app_stub(store=store)
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
@@ -89,140 +124,87 @@ def test_first_boot_triggers_migration_if_store_present():
|
|||||||
store.asearch.assert_called_once()
|
store.asearch.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_first_boot_no_store_skips_migration():
|
def test_admin_exists_no_admin_row_skips_migration():
|
||||||
"""First boot without store → no crash, migration skipped."""
|
"""Admin count > 0 but DB row missing (edge case) → skip migration gracefully."""
|
||||||
provider = _make_provider(user_count=0)
|
provider = _make_provider(admin_count=2)
|
||||||
|
sf = _make_session_factory(admin_row=None)
|
||||||
|
store = AsyncMock()
|
||||||
|
app = _make_app_stub(store=store)
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
store.asearch.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_exists_no_store_skips_migration():
|
||||||
|
"""Admin exists, row found, but no store → no crash, no migration."""
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
admin_row = MagicMock()
|
||||||
|
admin_row.id = uuid4()
|
||||||
|
|
||||||
|
provider = _make_provider(admin_count=1)
|
||||||
|
sf = _make_session_factory(admin_row=admin_row)
|
||||||
app = _make_app_stub(store=None)
|
app = _make_app_stub(store=None)
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
provider.create_user.assert_called_once()
|
# No explicit assertion: the test passes if _ensure_admin_user() returns without raising.
|
||||||
|
|
||||||
|
|
||||||
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
|
def test_admin_exists_session_factory_none_skips_migration():
|
||||||
|
"""get_session_factory() returns None → return early, no crash."""
|
||||||
|
provider = _make_provider(admin_count=1)
|
||||||
def test_needs_setup_true_resets_password():
|
store = AsyncMock()
|
||||||
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
|
app = _make_app_stub(store=store)
|
||||||
admin = User(
|
|
||||||
email="admin@deerflow.dev",
|
|
||||||
password_hash="old-hash",
|
|
||||||
system_role="admin",
|
|
||||||
needs_setup=True,
|
|
||||||
token_version=0,
|
|
||||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
|
||||||
)
|
|
||||||
provider = _make_provider(user_count=1, admin_user=admin)
|
|
||||||
app = _make_app_stub()
|
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
with patch("deerflow.persistence.engine.get_session_factory", return_value=None):
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
# Password was reset
|
store.asearch.assert_not_called()
|
||||||
provider.update_user.assert_called_once()
|
|
||||||
updated = provider.update_user.call_args[0][0]
|
|
||||||
assert updated.password_hash == "new-hash"
|
|
||||||
assert updated.token_version == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_needs_setup_true_consecutive_resets_increment_version():
|
|
||||||
"""Two boots with needs_setup=True → token_version increments each time."""
|
|
||||||
admin = User(
|
|
||||||
email="admin@deerflow.dev",
|
|
||||||
password_hash="hash",
|
|
||||||
system_role="admin",
|
|
||||||
needs_setup=True,
|
|
||||||
token_version=3,
|
|
||||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
|
||||||
)
|
|
||||||
provider = _make_provider(user_count=1, admin_user=admin)
|
|
||||||
app = _make_app_stub()
|
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
|
||||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
|
||||||
from app.gateway.app import _ensure_admin_user
|
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
|
||||||
|
|
||||||
updated = provider.update_user.call_args[0][0]
|
|
||||||
assert updated.token_version == 4
|
|
||||||
|
|
||||||
|
|
||||||
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_needs_setup_false_no_reset():
|
|
||||||
"""Admin with needs_setup=False → no password reset, no update."""
|
|
||||||
admin = User(
|
|
||||||
email="admin@deerflow.dev",
|
|
||||||
password_hash="stable-hash",
|
|
||||||
system_role="admin",
|
|
||||||
needs_setup=False,
|
|
||||||
token_version=2,
|
|
||||||
)
|
|
||||||
provider = _make_provider(user_count=1, admin_user=admin)
|
|
||||||
app = _make_app_stub()
|
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
|
||||||
from app.gateway.app import _ensure_admin_user
|
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
|
||||||
|
|
||||||
provider.update_user.assert_not_called()
|
|
||||||
assert admin.password_hash == "stable-hash"
|
|
||||||
assert admin.token_version == 2
|
|
||||||
|
|
||||||
|
|
||||||
# ── Edge cases ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_admin_email_found_no_crash():
|
|
||||||
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
|
|
||||||
provider = _make_provider(user_count=3, admin_user=None)
|
|
||||||
app = _make_app_stub()
|
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
|
||||||
from app.gateway.app import _ensure_admin_user
|
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
|
||||||
|
|
||||||
provider.update_user.assert_not_called()
|
|
||||||
provider.create_user.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_migration_failure_is_non_fatal():
|
def test_migration_failure_is_non_fatal():
|
||||||
"""_migrate_orphaned_threads exception is caught and logged."""
|
"""_migrate_orphaned_threads exception is caught and logged."""
|
||||||
provider = _make_provider(user_count=0)
|
from uuid import uuid4
|
||||||
|
|
||||||
|
admin_row = MagicMock()
|
||||||
|
admin_row.id = uuid4()
|
||||||
|
|
||||||
|
provider = _make_provider(admin_count=1)
|
||||||
|
sf = _make_session_factory(admin_row=admin_row)
|
||||||
store = AsyncMock()
|
store = AsyncMock()
|
||||||
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
||||||
app = _make_app_stub(store=store)
|
app = _make_app_stub(store=store)
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
provider.create_user.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
||||||
"""First boot finds Store-only legacy threads → stamps admin's id.
|
"""First boot finds Store-only legacy threads → stamps admin's id.
|
||||||
|
|
||||||
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
||||||
(no auth) accumulates threads in the LangGraph Store namespace
|
(no auth) accumulates threads in the LangGraph Store namespace
|
||||||
``("threads",)`` with no ``metadata.owner_id``. After upgrading to
|
``("threads",)`` with no ``metadata.user_id``. After upgrading to
|
||||||
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
||||||
rewrite each unowned item with the freshly created admin's id.
|
rewrite each unowned item with the freshly created admin's id.
|
||||||
"""
|
"""
|
||||||
@@ -233,7 +215,7 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
|||||||
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
||||||
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
||||||
SimpleNamespace(key="t3", value={"metadata": {}}),
|
SimpleNamespace(key="t3", value={"metadata": {}}),
|
||||||
SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}),
|
SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
|
||||||
]
|
]
|
||||||
store = AsyncMock()
|
store = AsyncMock()
|
||||||
# asearch returns the entire batch on first call, then an empty page
|
# asearch returns the entire batch on first call, then an empty page
|
||||||
@@ -253,11 +235,11 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
|||||||
assert len(aput_calls) == 3
|
assert len(aput_calls) == 3
|
||||||
rewritten_keys = {call[1] for call in aput_calls}
|
rewritten_keys = {call[1] for call in aput_calls}
|
||||||
assert rewritten_keys == {"t1", "t2", "t3"}
|
assert rewritten_keys == {"t1", "t2", "t3"}
|
||||||
# Each rewrite carries the new owner_id; titles preserved where present.
|
# Each rewrite carries the new user_id; titles preserved where present.
|
||||||
by_key = {call[1]: call[2] for call in aput_calls}
|
by_key = {call[1]: call[2] for call in aput_calls}
|
||||||
assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42"
|
assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
|
||||||
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
||||||
assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42"
|
assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
|
||||||
# The pre-owned item must NOT have been rewritten.
|
# The pre-owned item must NOT have been rewritten.
|
||||||
assert "t4" not in rewritten_keys
|
assert "t4" not in rewritten_keys
|
||||||
|
|
||||||
|
|||||||
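The migration test above expects that items in the `("threads",)` namespace without `metadata.user_id` are rewritten with the admin's id, that existing metadata such as `title` is preserved, and that items that already carry a `user_id` are left alone. A minimal sketch of that loop follows. `FakeStore` is a hypothetical in-memory stand-in exposing only `asearch` and `aput`, and `migrate_orphaned_threads` is an illustrative function written for this note; neither is the project's `_migrate_orphaned_threads`.

```python
import asyncio
from types import SimpleNamespace


class FakeStore:
    """Hypothetical in-memory store with just the two calls the sketch needs."""

    def __init__(self, items):
        self._items = dict(items)  # key -> value dict
        self.puts = []  # keys written through aput, in call order

    async def asearch(self, namespace, *, limit=100, offset=0):
        keys = sorted(self._items)[offset : offset + limit]
        return [SimpleNamespace(key=k, value=self._items[k]) for k in keys]

    async def aput(self, namespace, key, value):
        self._items[key] = value
        self.puts.append(key)


async def migrate_orphaned_threads(store, admin_id, page_size=100):
    """Stamp admin_id on every thread that has no user_id; return how many were rewritten."""
    migrated = 0
    offset = 0
    while True:
        page = await store.asearch(("threads",), limit=page_size, offset=offset)
        if not page:
            break  # an empty page marks the end of the namespace
        for item in page:
            metadata = dict(item.value.get("metadata") or {})
            if metadata.get("user_id"):
                continue  # already owned: must not be rewritten
            metadata["user_id"] = admin_id
            await store.aput(("threads",), item.key, {**item.value, "metadata": metadata})
            migrated += 1
        offset += len(page)
    return migrated
```

Because the tests also require migration failures to be non-fatal, a caller in this style would presumably wrap the call in a `try`/`except` that logs and continues startup.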
@@ -60,8 +60,8 @@ class TestFeedbackRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_create_with_owner(self, tmp_path):
|
async def test_create_with_owner(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
|
record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||||
assert record["owner_id"] == "user-1"
|
assert record["user_id"] == "user-1"
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -97,10 +97,10 @@ class TestFeedbackRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_run(self, tmp_path):
|
async def test_list_by_run(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
|
||||||
await repo.create(run_id="r2", thread_id="t1", rating=1)
|
await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
|
||||||
results = await repo.list_by_run("t1", "r1")
|
results = await repo.list_by_run("t1", "r1", user_id=None)
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert all(r["run_id"] == "r1" for r in results)
|
assert all(r["run_id"] == "r1" for r in results)
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
@@ -135,9 +135,9 @@ class TestFeedbackRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_aggregate_by_run(self, tmp_path):
|
async def test_aggregate_by_run(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
|
||||||
stats = await repo.aggregate_by_run("t1", "r1")
|
stats = await repo.aggregate_by_run("t1", "r1")
|
||||||
assert stats["total"] == 3
|
assert stats["total"] == 3
|
||||||
assert stats["positive"] == 2
|
assert stats["positive"] == 2
|
||||||
@@ -154,6 +154,80 @@ class TestFeedbackRepository:
|
|||||||
assert stats["negative"] == 0
|
assert stats["negative"] == 0
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_upsert_creates_new(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||||
|
assert record["rating"] == 1
|
||||||
|
assert record["feedback_id"]
|
||||||
|
assert record["user_id"] == "u1"
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_upsert_updates_existing(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||||
|
second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
|
||||||
|
assert second["feedback_id"] == first["feedback_id"]
|
||||||
|
assert second["rating"] == -1
|
||||||
|
assert second["comment"] == "changed my mind"
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_upsert_different_users_separate(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||||
|
r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
|
||||||
|
assert r1["feedback_id"] != r2["feedback_id"]
|
||||||
|
assert r1["rating"] == 1
|
||||||
|
assert r2["rating"] == -1
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_upsert_invalid_rating(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_by_run(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||||
|
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||||
|
assert deleted is True
|
||||||
|
results = await repo.list_by_run("t1", "r1", user_id="u1")
|
||||||
|
assert len(results) == 0
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_by_run_nonexistent(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||||
|
assert deleted is False
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_by_thread_grouped(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||||
|
await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
|
||||||
|
await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
|
||||||
|
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||||
|
assert "r1" in grouped
|
||||||
|
assert "r2" in grouped
|
||||||
|
assert "r3" not in grouped
|
||||||
|
assert grouped["r1"]["rating"] == 1
|
||||||
|
assert grouped["r2"]["rating"] == -1
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_by_thread_grouped_empty(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||||
|
assert grouped == {}
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
|
||||||
# -- Follow-up association --
|
# -- Follow-up association --
|
||||||
|
|
||||||
|
|||||||
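The new feedback tests above imply that `upsert()` keeps one record per thread, run and user, that only ratings of `1` and `-1` are accepted, and that `delete_by_run()` reports whether anything was removed. The sketch below models that behaviour in memory. It is synchronous for brevity, whereas the repository under test is async, and `InMemoryFeedbackRepo` with its key choice is an assumption inferred from the tests rather than the project's storage schema.

```python
import uuid


class InMemoryFeedbackRepo:
    """Hypothetical in-memory model of the upsert semantics exercised by the tests."""

    def __init__(self):
        self._rows = {}  # (thread_id, run_id, user_id) -> record dict

    def upsert(self, *, run_id, thread_id, rating, user_id, comment=None):
        if rating not in (1, -1):
            raise ValueError("rating must be 1 or -1")
        key = (thread_id, run_id, user_id)
        existing = self._rows.get(key)
        record = {
            # Reuse the id on update so callers see the same feedback entry.
            "feedback_id": existing["feedback_id"] if existing else str(uuid.uuid4()),
            "run_id": run_id,
            "thread_id": thread_id,
            "user_id": user_id,
            "rating": rating,
            "comment": comment,
        }
        self._rows[key] = record
        return record

    def delete_by_run(self, *, thread_id, run_id, user_id):
        """Return True if a record existed and was removed, False otherwise."""
        return self._rows.pop((thread_id, run_id, user_id), None) is not None

    def list_by_thread_grouped(self, thread_id, *, user_id):
        """Map run_id -> record for one user's feedback in one thread."""
        return {
            record["run_id"]: record
            for (tid, _run, uid), record in self._rows.items()
            if tid == thread_id and uid == user_id
        }
```

Keying on the user as well as the run is what lets two users rate the same run independently while a single user changing their mind overwrites their earlier rating.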
@@ -0,0 +1,229 @@
|
|||||||
|
"""Tests for the POST /api/v1/auth/initialize endpoint.
|
||||||
|
|
||||||
|
Covers: first-boot admin creation, rejection when system already
|
||||||
|
initialized, invalid/missing init_token, password strength validation,
|
||||||
|
and public accessibility (no auth cookie required).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-initialize-admin-min-32")
|
||||||
|
|
||||||
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
|
|
||||||
|
_TEST_SECRET = "test-secret-key-initialize-admin-min-32"
|
||||||
|
_INIT_TOKEN = "test-init-token-for-initialization-tests"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _setup_auth(tmp_path):
|
||||||
|
"""Fresh SQLite engine + auth config per test."""
|
||||||
|
from app.gateway import deps
|
||||||
|
from deerflow.persistence.engine import close_engine, init_engine
|
||||||
|
|
||||||
|
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||||
|
url = f"sqlite+aiosqlite:///{tmp_path}/init_admin.db"
|
||||||
|
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
|
||||||
|
deps._cached_local_provider = None
|
||||||
|
deps._cached_repo = None
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
deps._cached_local_provider = None
|
||||||
|
deps._cached_repo = None
|
||||||
|
asyncio.run(close_engine())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def client(_setup_auth):
|
||||||
|
from app.gateway.app import create_app
|
||||||
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
|
|
||||||
|
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||||
|
app = create_app()
|
||||||
|
# Pre-set the init token on app.state (normally done by the lifespan on
|
||||||
|
# first boot; tests don't run the lifespan because it requires config.yaml).
|
||||||
|
app.state.init_token = _INIT_TOKEN
|
||||||
|
# Do NOT use TestClient as a context manager — that would trigger the
|
||||||
|
# full lifespan which requires config.yaml. The auth endpoints work
|
||||||
|
# without the lifespan (persistence engine is set up by _setup_auth).
|
||||||
|
yield TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_payload(**extra):
|
||||||
|
"""Build a valid /initialize payload with the test init_token."""
|
||||||
|
return {
|
||||||
|
"email": "admin@example.com",
|
||||||
|
"password": "Str0ng!Pass99",
|
||||||
|
"init_token": _INIT_TOKEN,
|
||||||
|
**extra,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Happy path ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_creates_admin_and_sets_cookie(client):
|
||||||
|
"""POST /initialize when no admin exists → 201, session cookie set."""
|
||||||
|
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["email"] == "admin@example.com"
|
||||||
|
assert data["system_role"] == "admin"
|
||||||
|
assert "access_token" in resp.cookies
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_needs_setup_false(client):
|
||||||
|
"""Newly created admin via /initialize has needs_setup=False."""
|
||||||
|
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||||
|
me = client.get("/api/v1/auth/me")
|
||||||
|
assert me.status_code == 200
|
||||||
|
assert me.json()["needs_setup"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Token validation ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_rejects_wrong_token(client):
|
||||||
|
"""Wrong init_token → 403 invalid_init_token."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json={**_init_payload(), "init_token": "wrong-token"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
assert resp.json()["detail"]["code"] == "invalid_init_token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_rejects_empty_token(client):
|
||||||
|
"""Empty init_token → 403 (fails constant-time comparison against stored token)."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json={**_init_payload(), "init_token": ""},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_token_consumed_after_success(client):
|
||||||
|
"""After a successful /initialize the token is consumed and cannot be reused."""
|
||||||
|
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||||
|
# The token is now None; any subsequent call with the old token must be rejected (403)
|
||||||
|
resp2 = client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json={**_init_payload(), "email": "other@example.com"},
|
||||||
|
)
|
||||||
|
assert resp2.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ── Rejection when already initialized ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_rejected_when_admin_exists(client):
|
||||||
|
"""Second call to /initialize after admin exists → 409 system_already_initialized.
|
||||||
|
|
||||||
|
The first call consumes the token. Re-setting it on app.state simulates
|
||||||
|
what would happen if the operator somehow restarted or manually refreshed
|
||||||
|
the token (e.g., in testing).
|
||||||
|
"""
|
||||||
|
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||||
|
# Re-set the token so the second attempt can pass token validation
|
||||||
|
# and reach the admin-exists check.
|
||||||
|
client.app.state.init_token = _INIT_TOKEN
|
||||||
|
resp2 = client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json={**_init_payload(), "email": "other@example.com"},
|
||||||
|
)
|
||||||
|
assert resp2.status_code == 409
|
||||||
|
body = resp2.json()
|
||||||
|
assert body["detail"]["code"] == "system_already_initialized"
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_token_not_consumed_on_admin_exists(client):
|
||||||
|
"""Token is NOT consumed when the admin-exists guard rejects the request.
|
||||||
|
|
||||||
|
This prevents a DoS where an attacker calls with the correct token when
|
||||||
|
admin already exists and permanently burns the init token.
|
||||||
|
"""
|
||||||
|
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||||
|
# Token consumed by success above; re-simulate the scenario:
|
||||||
|
# admin exists, token is still valid (re-set), call should 409 and NOT consume token.
|
||||||
|
client.app.state.init_token = _INIT_TOKEN
|
||||||
|
client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json={**_init_payload(), "email": "other@example.com"},
|
||||||
|
)
|
||||||
|
# Token must still be set (not consumed) after the 409 rejection.
|
||||||
|
assert client.app.state.init_token == _INIT_TOKEN
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_register_does_not_block_initialization(client):
|
||||||
|
"""/register creating a user before /initialize doesn't block admin creation."""
|
||||||
|
# Register a regular user first
|
||||||
|
client.post("/api/v1/auth/register", json={"email": "regular@example.com", "password": "Tr0ub4dor3a"})
|
||||||
|
# /initialize should still succeed (checks admin_count, not total user_count)
|
||||||
|
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||||
|
assert resp.status_code == 201
|
||||||
|
assert resp.json()["system_role"] == "admin"
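
The token and admin-exists tests above imply a specific ordering of checks: validate the token first, then check for an existing admin, and consume the token only on success. The following is a minimal sketch of that ordering, not the project's actual handler. `SetupState` and `try_initialize` are hypothetical names, and the use of `hmac.compare_digest` is an assumption based on the "constant-time comparison" wording in the test docstring.

```python
import hmac
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class SetupState:
    """Hypothetical stand-in for the app state the tests manipulate."""

    init_token: Optional[str]
    admin_count: int = 0


def try_initialize(state: SetupState, supplied_token: str) -> Tuple[int, str]:
    """Return (status_code, code) in the order the tests above expect."""
    stored = state.init_token
    # A consumed (None) token can never match; compare_digest avoids timing leaks.
    if stored is None or not hmac.compare_digest(stored.encode(), supplied_token.encode()):
        return 403, "invalid_init_token"
    # This check runs before the token is consumed, so a 409 leaves the token intact.
    # It counts admins only, so pre-existing regular users do not block setup.
    if state.admin_count > 0:
        return 409, "system_already_initialized"
    state.admin_count += 1
    state.init_token = None  # single use: consumed only on success
    return 201, "created"
```

Placing the admin-exists check before consumption is what keeps a 409 response from burning the token, which is the behaviour `test_initialize_token_not_consumed_on_admin_exists` asserts.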
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoint is public (no cookie required) ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_accessible_without_cookie(client):
|
||||||
|
"""No access_token cookie needed for /initialize."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json=_init_payload(),
|
||||||
|
cookies={},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
|
||||||
|
# ── Password validation ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_rejects_short_password(client):
|
||||||
|
"""Password shorter than 8 chars → 422."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json={**_init_payload(), "password": "short"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_rejects_common_password(client):
|
||||||
|
"""Common password → 422."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
json={**_init_payload(), "password": "password123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ── setup-status reflects initialization ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_status_before_initialization(client):
|
||||||
|
"""setup-status returns needs_setup=True before /initialize is called."""
|
||||||
|
resp = client.get("/api/v1/auth/setup-status")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["needs_setup"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_status_after_initialization(client):
|
||||||
|
"""setup-status returns needs_setup=False after /initialize succeeds."""
|
||||||
|
client.post("/api/v1/auth/initialize", json=_init_payload())
|
||||||
|
resp = client.get("/api/v1/auth/setup-status")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["needs_setup"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_status_true_when_only_regular_user_exists(client):
|
||||||
|
"""setup-status returns needs_setup=True even when regular users exist (no admin)."""
|
||||||
|
client.post("/api/v1/auth/register", json={"email": "regular@example.com", "password": "Tr0ub4dor3a"})
|
||||||
|
resp = client.get("/api/v1/auth/setup-status")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["needs_setup"] is True
|
||||||
@@ -152,8 +152,10 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path):
|
|||||||
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
||||||
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
||||||
from deerflow.config import paths as paths_module
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.runtime import user_context as uc_module
|
||||||
|
|
||||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||||
|
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||||
result = _get_work_dir("thread-abc-123")
|
result = _get_work_dir("thread-abc-123")
|
||||||
expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace"
|
expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace"
|
||||||
assert result == str(expected)
|
assert result == str(expected)
|
||||||
@@ -310,8 +312,10 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
|
|||||||
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
||||||
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
||||||
from deerflow.config import paths as paths_module
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.runtime import user_context as uc_module
|
||||||
|
|
||||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||||
|
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||||
|
|||||||
@@ -175,46 +175,46 @@ def _make_ctx(user_id):
|
|||||||
def test_filter_injects_user_id():
|
def test_filter_injects_user_id():
|
||||||
value = {}
|
value = {}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
assert value["metadata"]["owner_id"] == "user-a"
|
assert value["metadata"]["user_id"] == "user-a"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_preserves_existing_metadata():
|
def test_filter_preserves_existing_metadata():
|
||||||
value = {"metadata": {"title": "hello"}}
|
value = {"metadata": {"title": "hello"}}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
assert value["metadata"]["owner_id"] == "user-a"
|
assert value["metadata"]["user_id"] == "user-a"
|
||||||
assert value["metadata"]["title"] == "hello"
|
assert value["metadata"]["title"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_returns_user_id_dict():
|
def test_filter_returns_user_id_dict():
|
||||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||||
assert result == {"owner_id": "user-x"}
|
assert result == {"user_id": "user-x"}
|
||||||
|
|
||||||
|
|
||||||
def test_filter_read_write_consistency():
|
def test_filter_read_write_consistency():
|
||||||
value = {}
|
value = {}
|
||||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||||
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
assert value["metadata"]["user_id"] == filter_dict["user_id"]
|
||||||
|
|
||||||
|
|
||||||
def test_different_users_different_filters():
|
def test_different_users_different_filters():
|
||||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||||
assert f_a["owner_id"] != f_b["owner_id"]
|
assert f_a["user_id"] != f_b["user_id"]
|
||||||
|
|
||||||
|
|
||||||
def test_filter_overrides_conflicting_user_id():
|
def test_filter_overrides_conflicting_user_id():
|
||||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||||
value = {"metadata": {"owner_id": "attacker"}}
|
value = {"metadata": {"user_id": "attacker"}}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||||
assert value["metadata"]["owner_id"] == "real-owner"
|
assert value["metadata"]["user_id"] == "real-owner"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_with_empty_metadata():
|
def test_filter_with_empty_metadata():
|
||||||
"""Explicit empty metadata dict is fine."""
|
"""Explicit empty metadata dict is fine."""
|
||||||
value = {"metadata": {}}
|
value = {"metadata": {}}
|
||||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||||
assert value["metadata"]["owner_id"] == "user-z"
|
assert value["metadata"]["user_id"] == "user-z"
|
||||||
assert result == {"owner_id": "user-z"}
|
assert result == {"user_id": "user-z"}
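
The renamed assertions above describe a filter that stamps the caller's identity into `metadata["user_id"]` and returns the matching read filter. A minimal sketch of that behaviour follows. `owner_filter_sketch` is a hypothetical name, and reading the identity from `ctx.user_id` is an assumption, since the shape produced by `_make_ctx` is not shown in this hunk.

```python
import asyncio
from types import SimpleNamespace


async def owner_filter_sketch(ctx, value: dict) -> dict:
    """Write the caller's id into metadata and return the matching read filter."""
    user_id = ctx.user_id  # assumption: the real ctx may expose identity differently
    # Create the metadata dict if it is missing, keeping any existing keys.
    metadata = value.setdefault("metadata", {})
    # Always overwrite, so a client-supplied user_id cannot spoof ownership.
    metadata["user_id"] = user_id
    # The same key is used for writes and reads, which keeps the two consistent.
    return {"user_id": user_id}


# Usage: a conflicting user_id supplied by the client is replaced.
value = {"metadata": {"title": "hello", "user_id": "attacker"}}
result = asyncio.run(owner_filter_sketch(SimpleNamespace(user_id="real-owner"), value))
```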
|
||||||
|
|
||||||
|
|
||||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
|||||||
agent_name="lead_agent",
|
agent_name="lead_agent",
|
||||||
correction_detected=True,
|
correction_detected=True,
|
||||||
reinforcement_detected=False,
|
reinforcement_detected=False,
|
||||||
|
user_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -88,4 +89,5 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
|||||||
agent_name="lead_agent",
|
agent_name="lead_agent",
|
||||||
correction_detected=False,
|
correction_detected=False,
|
||||||
reinforcement_detected=True,
|
reinforcement_detected=True,
|
||||||
|
user_id=None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
"""Tests for user_id propagation through memory queue."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_context_has_user_id():
|
||||||
|
ctx = ConversationContext(thread_id="t1", messages=[], user_id="alice")
|
||||||
|
assert ctx.user_id == "alice"
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_context_user_id_default_none():
|
||||||
|
ctx = ConversationContext(thread_id="t1", messages=[])
|
||||||
|
assert ctx.user_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_add_stores_user_id():
|
||||||
|
q = MemoryUpdateQueue()
|
||||||
|
with patch.object(q, "_reset_timer"):
|
||||||
|
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||||
|
assert len(q._queue) == 1
|
||||||
|
assert q._queue[0].user_id == "alice"
|
||||||
|
q.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_queue_process_passes_user_id_to_updater():
|
||||||
|
q = MemoryUpdateQueue()
|
||||||
|
with patch.object(q, "_reset_timer"):
|
||||||
|
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
||||||
|
|
||||||
|
mock_updater = MagicMock()
|
||||||
|
mock_updater.update_memory.return_value = True
|
||||||
|
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||||
|
q._process_queue()
|
||||||
|
|
||||||
|
mock_updater.update_memory.assert_called_once()
|
||||||
|
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
||||||
|
assert call_kwargs["user_id"] == "alice"
|
||||||
@@ -258,12 +258,13 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
update_fact.assert_called_once_with(
|
assert update_fact.call_count == 1
|
||||||
fact_id="fact_edit",
|
call_kwargs = update_fact.call_args.kwargs
|
||||||
content="User prefers spaces",
|
assert call_kwargs.get("fact_id") == "fact_edit"
|
||||||
category=None,
|
assert call_kwargs.get("content") == "User prefers spaces"
|
||||||
confidence=None,
|
assert call_kwargs.get("category") is None
|
||||||
)
|
assert call_kwargs.get("confidence") is None
|
||||||
|
assert "user_id" in call_kwargs
|
||||||
assert response.json()["facts"] == updated_memory["facts"]
|
assert response.json()["facts"] == updated_memory["facts"]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
"""Tests for per-user memory storage isolation."""
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base_dir(tmp_path: Path) -> Path:
|
||||||
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def storage() -> FileMemoryStorage:
|
||||||
|
return FileMemoryStorage()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserIsolatedStorage:
|
||||||
|
def test_save_and_load_per_user(self, storage: FileMemoryStorage, base_dir: Path):
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
memory_a = create_empty_memory()
|
||||||
|
memory_a["user"]["workContext"]["summary"] = "User A context"
|
||||||
|
storage.save(memory_a, user_id="alice")
|
||||||
|
|
||||||
|
memory_b = create_empty_memory()
|
||||||
|
memory_b["user"]["workContext"]["summary"] = "User B context"
|
||||||
|
storage.save(memory_b, user_id="bob")
|
||||||
|
|
||||||
|
loaded_a = storage.load(user_id="alice")
|
||||||
|
loaded_b = storage.load(user_id="bob")
|
||||||
|
|
||||||
|
assert loaded_a["user"]["workContext"]["summary"] == "User A context"
|
||||||
|
assert loaded_b["user"]["workContext"]["summary"] == "User B context"
|
||||||
|
|
||||||
|
def test_user_memory_file_location(self, base_dir: Path):
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
s = FileMemoryStorage()
|
||||||
|
memory = create_empty_memory()
|
||||||
|
s.save(memory, user_id="alice")
|
||||||
|
expected_path = base_dir / "users" / "alice" / "memory.json"
|
||||||
|
assert expected_path.exists()
|
||||||
|
|
||||||
|
def test_cache_isolated_per_user(self, base_dir: Path):
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
s = FileMemoryStorage()
|
||||||
|
memory_a = create_empty_memory()
|
||||||
|
memory_a["user"]["workContext"]["summary"] = "A"
|
||||||
|
s.save(memory_a, user_id="alice")
|
||||||
|
|
||||||
|
memory_b = create_empty_memory()
|
||||||
|
memory_b["user"]["workContext"]["summary"] = "B"
|
||||||
|
s.save(memory_b, user_id="bob")
|
||||||
|
|
||||||
|
loaded_a = s.load(user_id="alice")
|
||||||
|
assert loaded_a["user"]["workContext"]["summary"] == "A"
|
||||||
|
|
||||||
|
def test_no_user_id_uses_legacy_path(self, base_dir: Path):
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||||
|
s = FileMemoryStorage()
|
||||||
|
memory = create_empty_memory()
|
||||||
|
s.save(memory, user_id=None)
|
||||||
|
expected_path = base_dir / "memory.json"
|
||||||
|
assert expected_path.exists()
|
||||||
|
|
||||||
|
def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
|
||||||
|
"""user_id=None (legacy) and user_id='alice' must use different files and caches."""
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
||||||
|
s = FileMemoryStorage()
|
||||||
|
|
||||||
|
legacy_mem = create_empty_memory()
|
||||||
|
legacy_mem["user"]["workContext"]["summary"] = "legacy"
|
||||||
|
s.save(legacy_mem, user_id=None)
|
||||||
|
|
||||||
|
user_mem = create_empty_memory()
|
||||||
|
user_mem["user"]["workContext"]["summary"] = "alice"
|
||||||
|
s.save(user_mem, user_id="alice")
|
||||||
|
|
||||||
|
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
|
||||||
|
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
|
||||||
|
|
||||||
|
def test_user_agent_memory_file_location(self, base_dir: Path):
|
||||||
|
"""Per-user per-agent memory uses the user_agent_memory_file path."""
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
s = FileMemoryStorage()
|
||||||
|
memory = create_empty_memory()
|
||||||
|
memory["user"]["workContext"]["summary"] = "agent scoped"
|
||||||
|
s.save(memory, "test-agent", user_id="alice")
|
||||||
|
expected_path = base_dir / "users" / "alice" / "agents" / "test-agent" / "memory.json"
|
||||||
|
assert expected_path.exists()
|
||||||
|
|
||||||
|
def test_cache_key_is_user_agent_tuple(self, base_dir: Path):
|
||||||
|
"""Cache keys must be (user_id, agent_name) tuples, not bare agent names."""
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
s = FileMemoryStorage()
|
||||||
|
memory = create_empty_memory()
|
||||||
|
s.save(memory, user_id="alice")
|
||||||
|
# After save, cache should have tuple key
|
||||||
|
assert ("alice", None) in s._memory_cache
|
||||||
|
|
||||||
|
def test_reload_with_user_id(self, base_dir: Path):
|
||||||
|
"""reload() with user_id should force re-read from the user-scoped file."""
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
paths = Paths(base_dir)
|
||||||
|
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
||||||
|
s = FileMemoryStorage()
|
||||||
|
memory = create_empty_memory()
|
||||||
|
memory["user"]["workContext"]["summary"] = "initial"
|
||||||
|
s.save(memory, user_id="alice")
|
||||||
|
|
||||||
|
# Load once to prime cache
|
||||||
|
s.load(user_id="alice")
|
||||||
|
|
||||||
|
# Write updated content directly to file
|
||||||
|
user_file = base_dir / "users" / "alice" / "memory.json"
|
||||||
|
import json
|
||||||
|
|
||||||
|
updated = create_empty_memory()
|
||||||
|
updated["user"]["workContext"]["summary"] = "updated"
|
||||||
|
user_file.write_text(json.dumps(updated))
|
||||||
|
|
||||||
|
# reload should pick up the new content
|
||||||
|
reloaded = s.reload(user_id="alice")
|
||||||
|
assert reloaded["user"]["workContext"]["summary"] == "updated"
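
The storage tests above pin down three file locations and a tuple-shaped cache key. The sketch below restates that layout as a pure function. `memory_file_sketch` and `cache_key_sketch` are hypothetical helpers, not the methods on `Paths` or `FileMemoryStorage`. The legacy per-agent location (an agent name with no user) is not exercised by these tests, so the sketch declines to guess it.

```python
from pathlib import Path
from typing import Optional, Tuple


def memory_file_sketch(base_dir: Path, user_id: Optional[str], agent_name: Optional[str] = None) -> Path:
    """Resolve the memory file for a (user, agent) pair as the tests describe."""
    if user_id is None:
        if agent_name is not None:
            # Not covered by the tests above, so this sketch does not model it.
            raise ValueError("legacy per-agent layout is not modeled in this sketch")
        return base_dir / "memory.json"  # legacy, pre-isolation location
    user_root = base_dir / "users" / user_id
    if agent_name is None:
        return user_root / "memory.json"
    return user_root / "agents" / agent_name / "memory.json"


def cache_key_sketch(user_id: Optional[str], agent_name: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
    """Key the cache on both values so users sharing an agent name never collide."""
    return (user_id, agent_name)
```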
|
||||||
@@ -0,0 +1,156 @@
|
|||||||
|
"""Owner isolation tests for MemoryThreadMetaStore.
|
||||||
|
|
||||||
|
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
|
||||||
|
the in-memory LangGraph Store backend used when database.backend=memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
|
|
||||||
|
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||||
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
|
|
||||||
|
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
||||||
|
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
||||||
|
|
||||||
|
|
||||||
|
def _as_user(user):
|
||||||
|
class _Ctx:
|
||||||
|
def __enter__(self):
|
||||||
|
self._token = set_current_user(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
reset_current_user(self._token)
|
||||||
|
|
||||||
|
return _Ctx()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store():
|
||||||
|
return MemoryThreadMetaStore(InMemoryStore())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_search_isolation(store):
|
||||||
|
"""search() returns only threads owned by the current user."""
|
||||||
|
with _as_user(USER_A):
|
||||||
|
await store.create("t-alpha", display_name="A's thread")
|
||||||
|
with _as_user(USER_B):
|
||||||
|
await store.create("t-beta", display_name="B's thread")
|
||||||
|
|
||||||
|
with _as_user(USER_A):
|
||||||
|
results = await store.search()
|
||||||
|
assert [r["thread_id"] for r in results] == ["t-alpha"]
|
||||||
|
|
||||||
|
with _as_user(USER_B):
|
||||||
|
results = await store.search()
|
||||||
|
assert [r["thread_id"] for r in results] == ["t-beta"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_get_isolation(store):
|
||||||
|
"""get() returns None for threads owned by another user."""
|
||||||
|
with _as_user(USER_A):
|
||||||
|
await store.create("t-alpha", display_name="A's thread")
|
||||||
|
|
||||||
|
with _as_user(USER_B):
|
||||||
|
assert await store.get("t-alpha") is None
|
||||||
|
|
||||||
|
with _as_user(USER_A):
|
||||||
|
result = await store.get("t-alpha")
|
||||||
|
assert result is not None
|
||||||
|
assert result["display_name"] == "A's thread"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_update_display_name_denied(store):
|
||||||
|
"""User B cannot rename User A's thread."""
|
||||||
|
with _as_user(USER_A):
|
||||||
|
await store.create("t-alpha", display_name="original")
|
||||||
|
|
||||||
|
with _as_user(USER_B):
|
||||||
|
await store.update_display_name("t-alpha", "hacked")
|
||||||
|
|
||||||
|
with _as_user(USER_A):
|
||||||
|
row = await store.get("t-alpha")
|
||||||
|
assert row is not None
|
||||||
|
assert row["display_name"] == "original"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_update_status_denied(store):
|
||||||
|
"""User B cannot change status of User A's thread."""
|
||||||
|
with _as_user(USER_A):
|
||||||
|
await store.create("t-alpha")
|
||||||
|
|
||||||
|
with _as_user(USER_B):
|
||||||
|
await store.update_status("t-alpha", "error")
|
||||||
|
|
||||||
|
with _as_user(USER_A):
|
||||||
|
row = await store.get("t-alpha")
|
||||||
|
assert row is not None
|
||||||
|
assert row["status"] == "idle"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_update_metadata_denied(store):
|
||||||
|
"""User B cannot modify metadata of User A's thread."""
|
||||||
|
with _as_user(USER_A):
|
||||||
|
await store.create("t-alpha", metadata={"key": "original"})
|
||||||
|
|
||||||
|
with _as_user(USER_B):
|
||||||
|
await store.update_metadata("t-alpha", {"key": "hacked"})
|
||||||
|
|
||||||
|
with _as_user(USER_A):
|
||||||
|
row = await store.get("t-alpha")
|
||||||
|
assert row is not None
|
||||||
|
assert row["metadata"]["key"] == "original"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_delete_denied(store):
|
||||||
|
"""User B cannot delete User A's thread."""
|
||||||
|
with _as_user(USER_A):
|
||||||
|
await store.create("t-alpha")
|
||||||
|
|
||||||
|
with _as_user(USER_B):
|
||||||
|
await store.delete("t-alpha")
|
||||||
|
|
||||||
|
with _as_user(USER_A):
|
||||||
|
row = await store.get("t-alpha")
|
||||||
|
assert row is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_no_context_raises(store):
|
||||||
|
"""Calling methods without user context raises RuntimeError."""
|
||||||
|
with pytest.raises(RuntimeError, match="no user context is set"):
|
||||||
|
await store.search()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
async def test_explicit_none_bypasses_filter(store):
|
||||||
|
"""user_id=None bypasses isolation (migration/CLI escape hatch)."""
|
||||||
|
with _as_user(USER_A):
|
||||||
|
await store.create("t-alpha")
|
||||||
|
with _as_user(USER_B):
|
||||||
|
await store.create("t-beta")
|
||||||
|
|
||||||
|
all_rows = await store.search(user_id=None)
|
||||||
|
assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"}
|
||||||
|
|
||||||
|
row = await store.get("t-alpha", user_id=None)
|
||||||
|
assert row is not None
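
The isolation tests above distinguish three cases: no argument (use the ambient user, or fail if none is set), an explicit `user_id=None` (bypass the filter), and an explicit id. One common way to tell "not passed" apart from "passed as None" is a sentinel default. The sketch below illustrates that pattern with `contextvars`; all names ending in `_sketch` are hypothetical, and this is not the code in `deerflow.runtime.user_context`.

```python
import contextvars
from types import SimpleNamespace

_current_user = contextvars.ContextVar("current_user_sketch", default=None)
_FROM_CONTEXT = object()  # sentinel: the caller did not pass user_id at all


def set_user_sketch(user):
    """Bind a user to the current context and return a token for later reset."""
    return _current_user.set(user)


def reset_user_sketch(token):
    """Restore whatever user (if any) was bound before the matching set call."""
    _current_user.reset(token)


def resolve_user_id_sketch(user_id=_FROM_CONTEXT):
    """Return the id to filter on, or None to bypass isolation."""
    if user_id is None:
        return None  # explicit escape hatch for migrations and CLI tools
    if user_id is not _FROM_CONTEXT:
        return user_id  # explicit id wins over the ambient user
    user = _current_user.get()
    if user is None:
        raise RuntimeError("no user context is set")
    return user.id


# Usage: bind a user, resolve, then restore the previous state.
token = set_user_sketch(SimpleNamespace(id="user-a"))
resolved = resolve_user_id_sketch()
reset_user_sketch(token)
```

Failing loudly when no user is bound means a missing auth context surfaces as an error rather than silently returning every user's threads.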
|
||||||
@@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
|||||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
result = import_memory_data(imported_memory)
|
result = import_memory_data(imported_memory)
|
||||||
|
|
||||||
mock_storage.save.assert_called_once_with(imported_memory, None)
|
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
|
||||||
mock_storage.load.assert_called_once_with(None)
|
mock_storage.load.assert_called_once_with(None, user_id=None)
|
||||||
assert result == imported_memory
|
assert result == imported_memory
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Tests for user_id propagation in memory updater."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_memory_data_passes_user_id():
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.load.return_value = {"version": "1.0"}
|
||||||
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
|
get_memory_data(user_id="alice")
|
||||||
|
mock_storage.load.assert_called_once_with(None, user_id="alice")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_memory_passes_user_id():
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.save.return_value = True
|
||||||
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
|
_save_memory_to_file({"version": "1.0"}, user_id="bob")
|
||||||
|
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_memory_data_passes_user_id():
|
||||||
|
mock_storage = MagicMock()
|
||||||
|
mock_storage.save.return_value = True
|
||||||
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
|
clear_memory_data(user_id="charlie")
|
||||||
|
# Verify save was called with user_id
|
||||||
|
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
"""Tests for per-user data migration."""
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base_dir(tmp_path: Path) -> Path:
|
||||||
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def paths(base_dir: Path) -> Paths:
|
||||||
|
return Paths(base_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMigrateThreadDirs:
|
||||||
|
def test_moves_thread_to_user_dir(self, base_dir: Path, paths: Paths):
|
||||||
|
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||||
|
legacy.mkdir(parents=True)
|
||||||
|
(legacy / "file.txt").write_text("hello")
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||||
|
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||||
|
|
||||||
|
expected = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace" / "file.txt"
|
||||||
|
assert expected.exists()
|
||||||
|
assert expected.read_text() == "hello"
|
||||||
|
assert not (base_dir / "threads" / "t1").exists()
|
||||||
|
|
||||||
|
def test_unowned_thread_goes_to_default(self, base_dir: Path, paths: Paths):
|
||||||
|
legacy = base_dir / "threads" / "t2" / "user-data" / "workspace"
|
||||||
|
legacy.mkdir(parents=True)
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||||
|
migrate_thread_dirs(paths, thread_owner_map={})
|
||||||
|
|
||||||
|
expected = base_dir / "users" / "default" / "threads" / "t2"
|
||||||
|
assert expected.exists()
|
||||||
|
|
||||||
|
def test_idempotent_skip_already_migrated(self, base_dir: Path, paths: Paths):
|
||||||
|
new_dir = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
||||||
|
new_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||||
|
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||||
|
assert new_dir.exists()
|
||||||
|
|
||||||
|
def test_conflict_preserved(self, base_dir: Path, paths: Paths):
|
||||||
|
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||||
|
legacy.mkdir(parents=True)
|
||||||
|
(legacy / "old.txt").write_text("old")
|
||||||
|
|
||||||
|
dest = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
||||||
|
dest.mkdir(parents=True)
|
||||||
|
(dest / "new.txt").write_text("new")
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||||
|
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
||||||
|
|
||||||
|
assert (dest / "new.txt").read_text() == "new"
|
||||||
|
conflicts = base_dir / "migration-conflicts" / "t1"
|
||||||
|
assert conflicts.exists()
|
||||||
|
|
||||||
|
def test_cleans_up_empty_legacy_dir(self, base_dir: Path, paths: Paths):
|
||||||
|
legacy = base_dir / "threads" / "t1" / "user-data"
|
||||||
|
legacy.mkdir(parents=True)
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||||
|
migrate_thread_dirs(paths, thread_owner_map={})
|
||||||
|
|
||||||
|
assert not (base_dir / "threads").exists()
|
||||||
|
|
||||||
|
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
|
||||||
|
legacy = base_dir / "threads" / "t1" / "user-data"
|
||||||
|
legacy.mkdir(parents=True)
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_thread_dirs
|
||||||
|
report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)
|
||||||
|
|
||||||
|
assert len(report) == 1
|
||||||
|
assert (base_dir / "threads" / "t1").exists() # not moved
|
||||||
|
assert not (base_dir / "users" / "alice" / "threads" / "t1").exists()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMigrateMemory:
|
||||||
|
def test_moves_global_memory(self, base_dir: Path, paths: Paths):
|
||||||
|
legacy_mem = base_dir / "memory.json"
|
||||||
|
legacy_mem.write_text(json.dumps({"version": "1.0", "facts": []}))
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_memory
|
||||||
|
migrate_memory(paths, user_id="default")
|
||||||
|
|
||||||
|
expected = base_dir / "users" / "default" / "memory.json"
|
||||||
|
assert expected.exists()
|
||||||
|
assert not legacy_mem.exists()
|
||||||
|
|
||||||
|
def test_skips_if_destination_exists(self, base_dir: Path, paths: Paths):
|
||||||
|
legacy_mem = base_dir / "memory.json"
|
||||||
|
legacy_mem.write_text(json.dumps({"version": "old"}))
|
||||||
|
|
||||||
|
dest = base_dir / "users" / "default" / "memory.json"
|
||||||
|
dest.parent.mkdir(parents=True)
|
||||||
|
dest.write_text(json.dumps({"version": "new"}))
|
||||||
|
|
||||||
|
from scripts.migrate_user_isolation import migrate_memory
|
||||||
|
migrate_memory(paths, user_id="default")
|
||||||
|
|
||||||
|
assert json.loads(dest.read_text())["version"] == "new"
|
||||||
|
assert (base_dir / "memory.legacy.json").exists()
|
||||||
|
|
||||||
|
def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
|
||||||
|
from scripts.migrate_user_isolation import migrate_memory
|
||||||
|
migrate_memory(paths, user_id="default") # should not raise
|
||||||
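Review note: the migration tests above pin down four behaviours. Owned threads move under `users/{user_id}/threads/`, unowned threads go to `default`, an existing destination is preserved while the legacy copy is parked under `migration-conflicts/`, and `dry_run=True` only reports. The following is a minimal sketch of that contract, not the code in `scripts/migrate_user_isolation.py`. The function name `migrate_thread_dirs_sketch`, the report dictionary keys, and taking a plain `base_dir` instead of a `Paths` object are all assumptions made for illustration.

```python
import shutil
from pathlib import Path


def migrate_thread_dirs_sketch(
    base_dir: Path,
    thread_owner_map: dict,
    dry_run: bool = False,
) -> list:
    """Move base_dir/threads/<id> to base_dir/users/<owner>/threads/<id>.

    Hypothetical helper: report entries are dicts with assumed keys
    "thread_id", "user_id" and "action".
    """
    report = []
    legacy_root = base_dir / "threads"
    if not legacy_root.is_dir():
        return report  # nothing to migrate, so a re-run is a no-op

    for legacy in sorted(p for p in legacy_root.iterdir() if p.is_dir()):
        thread_id = legacy.name
        # Threads without a known owner fall back to the "default" user.
        owner = thread_owner_map.get(thread_id, "default")
        dest = base_dir / "users" / owner / "threads" / thread_id

        if dest.exists():
            # Keep the destination untouched; park the legacy copy instead.
            target = base_dir / "migration-conflicts" / thread_id
            action = "conflict"
        else:
            target = dest
            action = "move"

        report.append({"thread_id": thread_id, "user_id": owner, "action": action})
        if not dry_run:
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(legacy), str(target))

    # Remove the legacy root once it is empty (skipped in dry-run mode).
    if not dry_run and not any(legacy_root.iterdir()):
        legacy_root.rmdir()
    return report
```

The sketch does not handle a second conflict for the same thread ID (an already existing `migration-conflicts/{thread_id}` directory); a real script would need a policy for that case.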
@@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer
|
|||||||
owner filter directly by switching the ``user_context`` contextvar
|
owner filter directly by switching the ``user_context`` contextvar
|
||||||
between two users. The safety property under test is:
|
between two users. The safety property under test is:
|
||||||
|
|
||||||
After a repository write with owner_id=A, a subsequent read with
|
After a repository write with user_id=A, a subsequent read with
|
||||||
owner_id=B must not return the row, and vice versa.
|
user_id=B must not return the row, and vice versa.
|
||||||
|
|
||||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||||
that a request cookie reaches the ``set_current_user`` call. Together
|
that a request cookie reaches the ``set_current_user`` call. Together
|
||||||
@@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path):
|
|||||||
await cleanup()
|
await cleanup()
|
||||||
|
|
||||||
|
|
||||||
# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ──
|
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.no_auto_user
|
@pytest.mark.no_auto_user
|
||||||
async def test_explicit_none_bypasses_filter(tmp_path):
|
async def test_explicit_none_bypasses_filter(tmp_path):
|
||||||
"""Migration scripts pass owner_id=None to see all rows regardless of owner."""
|
"""Migration scripts pass user_id=None to see all rows regardless of owner."""
|
||||||
from deerflow.persistence.engine import get_session_factory
|
from deerflow.persistence.engine import get_session_factory
|
||||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||||
|
|
||||||
@@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path):
|
|||||||
await repo.create("t-beta")
|
await repo.create("t-beta")
|
||||||
|
|
||||||
# Migration-style read: no contextvar, explicit None bypass.
|
# Migration-style read: no contextvar, explicit None bypass.
|
||||||
all_rows = await repo.search(owner_id=None)
|
all_rows = await repo.search(user_id=None)
|
||||||
thread_ids = {r["thread_id"] for r in all_rows}
|
thread_ids = {r["thread_id"] for r in all_rows}
|
||||||
assert thread_ids == {"t-alpha", "t-beta"}
|
assert thread_ids == {"t-alpha", "t-beta"}
|
||||||
|
|
||||||
# Explicit get with None does not apply the filter either.
|
# Explicit get with None does not apply the filter either.
|
||||||
row_a = await repo.get("t-alpha", owner_id=None)
|
row_a = await repo.get("t-alpha", user_id=None)
|
||||||
assert row_a is not None
|
assert row_a is not None
|
||||||
row_b = await repo.get("t-beta", owner_id=None)
|
row_b = await repo.get("t-beta", user_id=None)
|
||||||
assert row_b is not None
|
assert row_b is not None
|
||||||
finally:
|
finally:
|
||||||
await cleanup()
|
await cleanup()
|
||||||
|
|||||||
@@ -0,0 +1,167 @@
|
|||||||
|
"""Tests for user-scoped path resolution in Paths."""
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from deerflow.config.paths import Paths
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def paths(tmp_path: Path) -> Paths:
|
||||||
|
return Paths(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateUserId:
|
||||||
|
def test_valid_user_id(self, paths: Paths):
|
||||||
|
d = paths.user_dir("u-abc-123")
|
||||||
|
assert d == paths.base_dir / "users" / "u-abc-123"
|
||||||
|
|
||||||
|
def test_rejects_path_traversal(self, paths: Paths):
|
||||||
|
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||||
|
paths.user_dir("../escape")
|
||||||
|
|
||||||
|
def test_rejects_slash(self, paths: Paths):
|
||||||
|
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||||
|
paths.user_dir("foo/bar")
|
||||||
|
|
||||||
|
def test_rejects_empty(self, paths: Paths):
|
||||||
|
with pytest.raises(ValueError, match="Invalid user_id"):
|
||||||
|
paths.user_dir("")
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserDir:
|
||||||
|
def test_user_dir(self, paths: Paths):
|
||||||
|
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserMemoryFile:
|
||||||
|
def test_user_memory_file(self, paths: Paths):
|
||||||
|
assert paths.user_memory_file("bob") == paths.base_dir / "users" / "bob" / "memory.json"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserAgentMemoryFile:
|
||||||
|
def test_user_agent_memory_file(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
||||||
|
assert paths.user_agent_memory_file("bob", "myagent") == expected
|
||||||
|
|
||||||
|
def test_user_agent_memory_file_lowercases_name(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
||||||
|
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserThreadDir:
|
||||||
|
def test_user_thread_dir(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
|
||||||
|
assert paths.thread_dir("t1", user_id="u1") == expected
|
||||||
|
|
||||||
|
def test_thread_dir_no_user_id_falls_back_to_legacy(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "threads" / "t1"
|
||||||
|
assert paths.thread_dir("t1") == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserSandboxDirs:
|
||||||
|
def test_sandbox_work_dir(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "workspace"
|
||||||
|
assert paths.sandbox_work_dir("t1", user_id="u1") == expected
|
||||||
|
|
||||||
|
def test_sandbox_uploads_dir(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "uploads"
|
||||||
|
assert paths.sandbox_uploads_dir("t1", user_id="u1") == expected
|
||||||
|
|
||||||
|
def test_sandbox_outputs_dir(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "outputs"
|
||||||
|
assert paths.sandbox_outputs_dir("t1", user_id="u1") == expected
|
||||||
|
|
||||||
|
def test_sandbox_user_data_dir(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data"
|
||||||
|
assert paths.sandbox_user_data_dir("t1", user_id="u1") == expected
|
||||||
|
|
||||||
|
def test_acp_workspace_dir(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "acp-workspace"
|
||||||
|
assert paths.acp_workspace_dir("t1", user_id="u1") == expected
|
||||||
|
|
||||||
|
def test_legacy_sandbox_work_dir(self, paths: Paths):
|
||||||
|
expected = paths.base_dir / "threads" / "t1" / "user-data" / "workspace"
|
||||||
|
assert paths.sandbox_work_dir("t1") == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestHostPathsWithUserId:
|
||||||
|
def test_host_thread_dir_with_user_id(self, paths: Paths):
|
||||||
|
result = paths.host_thread_dir("t1", user_id="u1")
|
||||||
|
assert "users" in result
|
||||||
|
assert "u1" in result
|
||||||
|
assert "threads" in result
|
||||||
|
assert "t1" in result
|
||||||
|
|
||||||
|
def test_host_thread_dir_legacy(self, paths: Paths):
|
||||||
|
result = paths.host_thread_dir("t1")
|
||||||
|
assert "threads" in result
|
||||||
|
assert "t1" in result
|
||||||
|
assert "users" not in result
|
||||||
|
|
||||||
|
def test_host_sandbox_user_data_dir_with_user_id(self, paths: Paths):
|
||||||
|
result = paths.host_sandbox_user_data_dir("t1", user_id="u1")
|
||||||
|
assert "users" in result
|
||||||
|
assert "user-data" in result
|
||||||
|
|
||||||
|
def test_host_sandbox_work_dir_with_user_id(self, paths: Paths):
|
||||||
|
result = paths.host_sandbox_work_dir("t1", user_id="u1")
|
||||||
|
assert "workspace" in result
|
||||||
|
|
||||||
|
def test_host_sandbox_uploads_dir_with_user_id(self, paths: Paths):
|
||||||
|
result = paths.host_sandbox_uploads_dir("t1", user_id="u1")
|
||||||
|
assert "uploads" in result
|
||||||
|
|
||||||
|
def test_host_sandbox_outputs_dir_with_user_id(self, paths: Paths):
|
||||||
|
result = paths.host_sandbox_outputs_dir("t1", user_id="u1")
|
||||||
|
assert "outputs" in result
|
||||||
|
|
||||||
|
def test_host_acp_workspace_dir_with_user_id(self, paths: Paths):
|
||||||
|
result = paths.host_acp_workspace_dir("t1", user_id="u1")
|
||||||
|
assert "acp-workspace" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsureAndDeleteWithUserId:
|
||||||
|
def test_ensure_thread_dirs_creates_user_scoped(self, paths: Paths):
|
||||||
|
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||||
|
assert paths.sandbox_work_dir("t1", user_id="u1").is_dir()
|
||||||
|
assert paths.sandbox_uploads_dir("t1", user_id="u1").is_dir()
|
||||||
|
assert paths.sandbox_outputs_dir("t1", user_id="u1").is_dir()
|
||||||
|
assert paths.acp_workspace_dir("t1", user_id="u1").is_dir()
|
||||||
|
|
||||||
|
def test_delete_thread_dir_removes_user_scoped(self, paths: Paths):
|
||||||
|
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||||
|
assert paths.thread_dir("t1", user_id="u1").exists()
|
||||||
|
paths.delete_thread_dir("t1", user_id="u1")
|
||||||
|
assert not paths.thread_dir("t1", user_id="u1").exists()
|
||||||
|
|
||||||
|
def test_delete_thread_dir_idempotent(self, paths: Paths):
|
||||||
|
paths.delete_thread_dir("nonexistent", user_id="u1") # should not raise
|
||||||
|
|
||||||
|
def test_ensure_thread_dirs_legacy_still_works(self, paths: Paths):
|
||||||
|
paths.ensure_thread_dirs("t1")
|
||||||
|
assert paths.sandbox_work_dir("t1").is_dir()
|
||||||
|
|
||||||
|
def test_user_scoped_and_legacy_are_independent(self, paths: Paths):
|
||||||
|
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||||
|
paths.ensure_thread_dirs("t1")
|
||||||
|
# Both exist independently
|
||||||
|
assert paths.thread_dir("t1", user_id="u1").exists()
|
||||||
|
assert paths.thread_dir("t1").exists()
|
||||||
|
# Deleting one doesn't affect the other
|
||||||
|
paths.delete_thread_dir("t1", user_id="u1")
|
||||||
|
assert not paths.thread_dir("t1", user_id="u1").exists()
|
||||||
|
assert paths.thread_dir("t1").exists()
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveVirtualPathWithUserId:
|
||||||
|
def test_resolve_virtual_path_with_user_id(self, paths: Paths):
|
||||||
|
paths.ensure_thread_dirs("t1", user_id="u1")
|
||||||
|
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt", user_id="u1")
|
||||||
|
expected_base = paths.sandbox_user_data_dir("t1", user_id="u1").resolve()
|
||||||
|
assert str(result).startswith(str(expected_base))
|
||||||
|
|
||||||
|
def test_resolve_virtual_path_legacy(self, paths: Paths):
|
||||||
|
paths.ensure_thread_dirs("t1")
|
||||||
|
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt")
|
||||||
|
expected_base = paths.sandbox_user_data_dir("t1").resolve()
|
||||||
|
assert str(result).startswith(str(expected_base))
|
||||||
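Review note: the path tests above describe a layout in which `user_id` is an optional keyword argument. With it, paths resolve under `users/{user_id}/threads/{thread_id}`; without it, they fall back to the legacy `threads/{thread_id}` layout. Unsafe IDs raise `ValueError` with a message containing "Invalid user_id". A minimal sketch of that behaviour follows. The class name `SketchPaths` and the allowed-character rule are assumptions; the validation in `deerflow.config.paths.Paths` may differ.

```python
from __future__ import annotations

import re
from pathlib import Path

# Assumed rule: letters, digits, underscore and hyphen only. This rejects
# "", "../escape" and "foo/bar", the three cases the tests exercise.
_SAFE_ID = re.compile(r"[A-Za-z0-9_-]+")


class SketchPaths:
    """Hypothetical stand-in for the user-scoped part of Paths."""

    def __init__(self, base_dir: Path) -> None:
        self.base_dir = Path(base_dir)

    def user_dir(self, user_id: str) -> Path:
        # fullmatch() avoids accepting a trailing newline, unlike "$".
        if not _SAFE_ID.fullmatch(user_id):
            raise ValueError(f"Invalid user_id: {user_id!r}")
        return self.base_dir / "users" / user_id

    def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
        if user_id is None:
            # Legacy layout, kept for callers that have no user context.
            return self.base_dir / "threads" / thread_id
        return self.user_dir(user_id) / "threads" / thread_id

    def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
        return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"
```

Routing every user-scoped helper through `user_dir()` keeps the validation in one place, so no derived path can skip it.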
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Tests:
|
Tests:
|
||||||
1. DatabaseConfig property derivation (paths, URLs)
|
1. DatabaseConfig property derivation (paths, URLs)
|
||||||
2. MemoryRunStore CRUD + owner_id filtering
|
2. MemoryRunStore CRUD + user_id filtering
|
||||||
3. Base.to_dict() via inspect mixin
|
3. Base.to_dict() via inspect mixin
|
||||||
4. Engine init/close lifecycle (memory + SQLite)
|
4. Engine init/close lifecycle (memory + SQLite)
|
||||||
5. Postgres missing-dep error message
|
5. Postgres missing-dep error message
|
||||||
@@ -24,18 +24,19 @@ class TestDatabaseConfig:
|
|||||||
assert c.backend == "memory"
|
assert c.backend == "memory"
|
||||||
assert c.pool_size == 5
|
assert c.pool_size == 5
|
||||||
|
|
||||||
def test_sqlite_paths_are_different(self):
|
def test_sqlite_paths_unified(self):
|
||||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
|
c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
|
||||||
assert c.checkpointer_sqlite_path.endswith("checkpoints.db")
|
assert c.sqlite_path.endswith("deerflow.db")
|
||||||
assert c.app_sqlite_path.endswith("app.db")
|
assert "mydata" in c.sqlite_path
|
||||||
assert "mydata" in c.checkpointer_sqlite_path
|
# Backward-compatible aliases point to the same file
|
||||||
assert c.checkpointer_sqlite_path != c.app_sqlite_path
|
assert c.checkpointer_sqlite_path == c.sqlite_path
|
||||||
|
assert c.app_sqlite_path == c.sqlite_path
|
||||||
|
|
||||||
def test_app_sqlalchemy_url_sqlite(self):
|
def test_app_sqlalchemy_url_sqlite(self):
|
||||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
|
c = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
|
||||||
url = c.app_sqlalchemy_url
|
url = c.app_sqlalchemy_url
|
||||||
assert url.startswith("sqlite+aiosqlite:///")
|
assert url.startswith("sqlite+aiosqlite:///")
|
||||||
assert "app.db" in url
|
assert "deerflow.db" in url
|
||||||
|
|
||||||
def test_app_sqlalchemy_url_postgres(self):
|
def test_app_sqlalchemy_url_postgres(self):
|
||||||
c = DatabaseConfig(
|
c = DatabaseConfig(
|
||||||
@@ -105,17 +106,17 @@ class TestMemoryRunStore:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread_owner_filter(self, store):
|
async def test_list_by_thread_owner_filter(self, store):
|
||||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
await store.put("r1", thread_id="t1", user_id="alice")
|
||||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
await store.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await store.list_by_thread("t1", owner_id="alice")
|
rows = await store.list_by_thread("t1", user_id="alice")
|
||||||
assert len(rows) == 1
|
assert len(rows) == 1
|
||||||
assert rows[0]["owner_id"] == "alice"
|
assert rows[0]["user_id"] == "alice"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_owner_none_returns_all(self, store):
|
async def test_owner_none_returns_all(self, store):
|
||||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
await store.put("r1", thread_id="t1", user_id="alice")
|
||||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
await store.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await store.list_by_thread("t1", owner_id=None)
|
rows = await store.list_by_thread("t1", user_id=None)
|
||||||
assert len(rows) == 2
|
assert len(rows) == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
present_file_tool_module,
|
present_file_tool_module,
|
||||||
"get_paths",
|
"get_paths",
|
||||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path, *, user_id=None: artifact_path),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = present_file_tool_module.present_file_tool.func(
|
result = present_file_tool_module.present_file_tool.func(
|
||||||
|
|||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Tests for paginated list_messages_by_run, exercised against MemoryRunEventStore."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base_store():
|
||||||
|
return MemoryRunEventStore()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_messages_by_run_default_returns_all(base_store):
|
||||||
|
store = base_store
|
||||||
|
for i in range(7):
|
||||||
|
await store.put(
|
||||||
|
thread_id="t1", run_id="run-a",
|
||||||
|
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||||
|
category="message", content=f"msg-a-{i}",
|
||||||
|
)
|
||||||
|
for i in range(3):
|
||||||
|
await store.put(
|
||||||
|
thread_id="t1", run_id="run-b",
|
||||||
|
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||||
|
)
|
||||||
|
await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
|
||||||
|
|
||||||
|
msgs = await store.list_messages_by_run("t1", "run-a")
|
||||||
|
assert len(msgs) == 7
|
||||||
|
assert all(m["category"] == "message" for m in msgs)
|
||||||
|
assert all(m["run_id"] == "run-a" for m in msgs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_messages_by_run_with_limit(base_store):
|
||||||
|
store = base_store
|
||||||
|
for i in range(7):
|
||||||
|
await store.put(
|
||||||
|
thread_id="t1", run_id="run-a",
|
||||||
|
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||||
|
category="message", content=f"msg-a-{i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
|
||||||
|
assert len(msgs) == 3
|
||||||
|
seqs = [m["seq"] for m in msgs]
|
||||||
|
assert seqs == sorted(seqs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_messages_by_run_after_seq(base_store):
|
||||||
|
store = base_store
|
||||||
|
for i in range(7):
|
||||||
|
await store.put(
|
||||||
|
thread_id="t1", run_id="run-a",
|
||||||
|
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||||
|
category="message", content=f"msg-a-{i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||||
|
cursor_seq = all_msgs[2]["seq"]
|
||||||
|
msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)
|
||||||
|
assert all(m["seq"] > cursor_seq for m in msgs)
|
||||||
|
assert len(msgs) == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_messages_by_run_before_seq(base_store):
|
||||||
|
store = base_store
|
||||||
|
for i in range(7):
|
||||||
|
await store.put(
|
||||||
|
thread_id="t1", run_id="run-a",
|
||||||
|
event_type="human_message" if i % 2 == 0 else "ai_message",
|
||||||
|
category="message", content=f"msg-a-{i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
||||||
|
cursor_seq = all_msgs[4]["seq"]
|
||||||
|
msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)
|
||||||
|
assert all(m["seq"] < cursor_seq for m in msgs)
|
||||||
|
assert len(msgs) == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_messages_by_run_does_not_include_other_run(base_store):
|
||||||
|
store = base_store
|
||||||
|
for i in range(7):
|
||||||
|
await store.put(
|
||||||
|
thread_id="t1", run_id="run-a",
|
||||||
|
event_type="human_message", category="message", content=f"msg-a-{i}",
|
||||||
|
)
|
||||||
|
for i in range(3):
|
||||||
|
await store.put(
|
||||||
|
thread_id="t1", run_id="run-b",
|
||||||
|
event_type="human_message", category="message", content=f"msg-b-{i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
msgs = await store.list_messages_by_run("t1", "run-b")
|
||||||
|
assert len(msgs) == 3
|
||||||
|
assert all(m["run_id"] == "run-b" for m in msgs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_messages_by_run_empty_run(base_store):
|
||||||
|
store = base_store
|
||||||
|
msgs = await store.list_messages_by_run("t1", "nonexistent")
|
||||||
|
assert msgs == []
|
||||||
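Review note: the tests above fix the cursor semantics of `list_messages_by_run`. `after_seq` and `before_seq` are exclusive bounds and results come back in ascending `seq` order. Below is a minimal sketch of that filtering over a list of event dicts, not the `MemoryRunEventStore` implementation. The function name `page_by_seq` is hypothetical. Which end of the range a bare `limit` keeps is not determined by the tests; this sketch assumes the newest rows.

```python
from __future__ import annotations


def page_by_seq(
    events: list[dict],
    *,
    limit: int | None = None,
    before_seq: int | None = None,
    after_seq: int | None = None,
) -> list[dict]:
    """Return events in ascending seq order, filtered by an exclusive cursor."""
    rows = sorted(events, key=lambda e: e["seq"])

    if after_seq is not None:
        # Forward paging: the oldest rows strictly after the cursor.
        rows = [e for e in rows if e["seq"] > after_seq]
        return rows if limit is None else rows[:limit]

    if before_seq is not None:
        rows = [e for e in rows if e["seq"] < before_seq]

    if limit is None:
        return rows
    # Backward paging (assumed): keep the newest `limit` rows. max() keeps
    # limit=0 from turning into rows[-0:], which would return everything.
    return rows[max(len(rows) - limit, 0):]
```

With seven events numbered 1 to 7, `after_seq=3` yields four rows and `before_seq=5` yields four rows, matching the counts asserted in the tests.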
@@ -709,6 +709,81 @@ class TestToolResultMessage:
|
|||||||
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
|
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
|
||||||
assert tool_end["metadata"]["tool_name"] == "web_search"
|
assert tool_end["metadata"]["tool_name"] == "web_search"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tool_invoke_end_to_end_unwraps_command(self, journal_setup):
|
||||||
|
"""End-to-end: invoke a real LangChain tool that returns Command(update={'messages':[ToolMessage]}).
|
||||||
|
|
||||||
|
This goes through the real LangChain callback path (tool.invoke -> CallbackManager
|
||||||
|
-> on_tool_start/on_tool_end), which is what the production agent uses. Mirrors
|
||||||
|
the ``present_files`` tool shape exactly.
|
||||||
|
"""
|
||||||
|
from langchain_core.callbacks import CallbackManager
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
j, store = journal_setup
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def fake_present_files(filepaths: list[str]) -> Command:
|
||||||
|
"""Fake present_files that returns a Command with an inner ToolMessage."""
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"artifacts": filepaths,
|
||||||
|
"messages": [ToolMessage("Successfully presented files", tool_call_id="tc_123")],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Real LangChain callback dispatch (matches production agent path)
|
||||||
|
cm = CallbackManager(handlers=[j])
|
||||||
|
fake_present_files.invoke(
|
||||||
|
{"filepaths": ["/mnt/user-data/outputs/report.md"]},
|
||||||
|
config={"callbacks": cm, "run_id": uuid4()},
|
||||||
|
)
|
||||||
|
await j.flush()
|
||||||
|
|
||||||
|
messages = await store.list_messages("t1")
|
||||||
|
assert len(messages) == 1, f"expected 1 message event, got {len(messages)}: {messages}"
|
||||||
|
content = messages[0]["content"]
|
||||||
|
assert content["type"] == "tool"
|
||||||
|
# CRITICAL: must be the inner ToolMessage text, not str(Command(...))
|
||||||
|
assert content["content"] == "Successfully presented files", (
|
||||||
|
f"Command unwrap failed; stored content = {content['content']!r}"
|
||||||
|
)
|
||||||
|
assert "Command(update=" not in str(content["content"])
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tool_end_unwraps_command_with_inner_tool_message(self, journal_setup):
|
||||||
|
"""Tools like ``present_files`` return Command(update={'messages': [ToolMessage(...)]}).
|
||||||
|
|
||||||
|
LangGraph unwraps the inner ToolMessage into checkpoint state, so the
|
||||||
|
event store must do the same — otherwise it captures ``str(Command(...))``
|
||||||
|
and the /history response diverges from the real rendered message.
|
||||||
|
"""
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
j, store = journal_setup
|
||||||
|
run_id = uuid4()
|
||||||
|
inner = ToolMessage(
|
||||||
|
content="Successfully presented files",
|
||||||
|
tool_call_id="call_present",
|
||||||
|
name="present_files",
|
||||||
|
status="success",
|
||||||
|
)
|
||||||
|
cmd = Command(update={"artifacts": ["/mnt/user-data/outputs/report.md"], "messages": [inner]})
|
||||||
|
j.on_tool_end(cmd, run_id=run_id)
|
||||||
|
await j.flush()
|
||||||
|
|
||||||
|
messages = await store.list_messages("t1")
|
||||||
|
assert len(messages) == 1
|
||||||
|
content = messages[0]["content"]
|
||||||
|
assert content["type"] == "tool"
|
||||||
|
assert content["content"] == "Successfully presented files"
|
||||||
|
assert content["tool_call_id"] == "call_present"
|
||||||
|
assert content["name"] == "present_files"
|
||||||
|
assert "Command(update=" not in str(content["content"])
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
|
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
|
||||||
"""ToolMessage object fields take priority over kwargs."""
|
"""ToolMessage object fields take priority over kwargs."""
|
||||||
|
|||||||
@@ -73,11 +73,11 @@ class TestRunRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread_owner_filter(self, tmp_path):
|
async def test_list_by_thread_owner_filter(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await repo.list_by_thread("t1", owner_id="alice")
|
rows = await repo.list_by_thread("t1", user_id="alice")
|
||||||
assert len(rows) == 1
|
assert len(rows) == 1
|
||||||
assert rows[0]["owner_id"] == "alice"
|
assert rows[0]["user_id"] == "alice"
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -189,8 +189,8 @@ class TestRunRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_owner_none_returns_all(self, tmp_path):
|
async def test_owner_none_returns_all(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await repo.list_by_thread("t1", owner_id=None)
|
rows = await repo.list_by_thread("t1", user_id=None)
|
||||||
assert len(rows) == 2
|
assert len(rows) == 2
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|||||||
@@ -0,0 +1,243 @@
|
|||||||
|
"""Tests for GET /api/runs/{run_id}/messages and GET /api/runs/{run_id}/feedback endpoints."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from _router_auth_helpers import make_authed_test_app
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.gateway.routers import runs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app(run_store=None, event_store=None, feedback_repo=None):
|
||||||
|
"""Build a test FastAPI app with stub auth and mocked state."""
|
||||||
|
app = make_authed_test_app()
|
||||||
|
app.include_router(runs.router)
|
||||||
|
|
||||||
|
if run_store is not None:
|
||||||
|
app.state.run_store = run_store
|
||||||
|
if event_store is not None:
|
||||||
|
app.state.run_event_store = event_store
|
||||||
|
if feedback_repo is not None:
|
||||||
|
app.state.feedback_repo = feedback_repo
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def _make_run_store(run_record: dict | None):
|
||||||
|
"""Return a MagicMock run store whose async get() returns run_record."""
|
||||||
|
store = MagicMock()
|
||||||
|
store.get = AsyncMock(return_value=run_record)
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
def _make_event_store(rows: list[dict]):
|
||||||
|
"""Return a MagicMock event store whose async list_messages_by_run() returns rows."""
|
||||||
|
store = MagicMock()
|
||||||
|
store.list_messages_by_run = AsyncMock(return_value=rows)
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
def _make_message(seq: int) -> dict:
|
||||||
|
return {"seq": seq, "event_type": "on_chat_model_stream", "category": "message", "content": f"msg-{seq}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------


def test_run_messages_returns_envelope():
    """GET /api/runs/{run_id}/messages returns {data: [...], has_more: bool}."""
    rows = [_make_message(i) for i in range(1, 4)]
    run_record = {"run_id": "run-1", "thread_id": "thread-1"}
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=_make_event_store(rows),
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/run-1/messages")
    assert response.status_code == 200
    body = response.json()
    assert "data" in body
    assert "has_more" in body
    assert body["has_more"] is False
    assert len(body["data"]) == 3


def test_run_messages_404_when_run_not_found():
    """Returns 404 when the run store returns None."""
    app = _make_app(
        run_store=_make_run_store(None),
        event_store=_make_event_store([]),
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/missing-run/messages")
    assert response.status_code == 404
    assert "missing-run" in response.json()["detail"]


def test_run_messages_has_more_true_when_extra_row_returned():
    """has_more=True when event store returns limit+1 rows."""
    # Default limit is 50; provide 51 rows
    rows = [_make_message(i) for i in range(1, 52)]  # 51 rows
    run_record = {"run_id": "run-2", "thread_id": "thread-2"}
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=_make_event_store(rows),
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/run-2/messages")
    assert response.status_code == 200
    body = response.json()
    assert body["has_more"] is True
    assert len(body["data"]) == 50  # trimmed to limit


def test_run_messages_passes_after_seq_to_event_store():
    """after_seq query param is forwarded to event_store.list_messages_by_run."""
    rows = [_make_message(10)]
    run_record = {"run_id": "run-3", "thread_id": "thread-3"}
    event_store = _make_event_store(rows)
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=event_store,
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/run-3/messages?after_seq=5")
    assert response.status_code == 200
    event_store.list_messages_by_run.assert_awaited_once_with(
        "thread-3", "run-3",
        limit=51,  # default limit(50) + 1
        before_seq=None,
        after_seq=5,
    )


def test_run_messages_respects_custom_limit():
    """Custom limit is respected and capped at 200."""
    rows = [_make_message(i) for i in range(1, 6)]
    run_record = {"run_id": "run-4", "thread_id": "thread-4"}
    event_store = _make_event_store(rows)
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=event_store,
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/run-4/messages?limit=10")
    assert response.status_code == 200
    event_store.list_messages_by_run.assert_awaited_once_with(
        "thread-4", "run-4",
        limit=11,  # 10 + 1
        before_seq=None,
        after_seq=None,
    )


def test_run_messages_passes_before_seq_to_event_store():
    """before_seq query param is forwarded to event_store.list_messages_by_run."""
    rows = [_make_message(3)]
    run_record = {"run_id": "run-5", "thread_id": "thread-5"}
    event_store = _make_event_store(rows)
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=event_store,
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/run-5/messages?before_seq=10")
    assert response.status_code == 200
    event_store.list_messages_by_run.assert_awaited_once_with(
        "thread-5", "run-5",
        limit=51,
        before_seq=10,
        after_seq=None,
    )


def test_run_messages_empty_data():
    """Returns empty data list when no messages exist."""
    run_record = {"run_id": "run-6", "thread_id": "thread-6"}
    app = _make_app(
        run_store=_make_run_store(run_record),
        event_store=_make_event_store([]),
    )
    with TestClient(app) as client:
        response = client.get("/api/runs/run-6/messages")
    assert response.status_code == 200
    body = response.json()
    assert body["data"] == []
    assert body["has_more"] is False

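The pagination contract that the message tests above pin down (ask the store for `limit + 1` rows, trim the response to `limit`, and report `has_more` when the extra row came back) can be sketched as follows. This is a minimal illustration under stated assumptions, not the Gateway's actual handler: `paginate_messages`, `_fake_fetch`, `DEFAULT_LIMIT` and `MAX_LIMIT` are hypothetical names, and the default of 50 and the cap of 200 are taken from the test comments and docstring above.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional

DEFAULT_LIMIT = 50  # assumption: default taken from the test comments
MAX_LIMIT = 200     # assumption: cap taken from the test docstring


async def paginate_messages(
    fetch: Callable[..., Awaitable[list]],
    thread_id: str,
    run_id: str,
    limit: int = DEFAULT_LIMIT,
    before_seq: Optional[int] = None,
    after_seq: Optional[int] = None,
) -> dict:
    """Hypothetical helper: wrap a store query in a {data, has_more} envelope."""
    limit = max(1, min(limit, MAX_LIMIT))
    # Request one extra row; its presence signals that another page exists.
    rows = await fetch(
        thread_id,
        run_id,
        limit=limit + 1,
        before_seq=before_seq,
        after_seq=after_seq,
    )
    has_more = len(rows) > limit
    # Trim the probe row so the caller never sees more than `limit` items.
    return {"data": rows[:limit], "has_more": has_more}


async def _fake_fetch(thread_id: str, run_id: str, *, limit: int,
                      before_seq: Any, after_seq: Any) -> list:
    # Stand-in for event_store.list_messages_by_run: 51 rows are available.
    return [{"seq": i} for i in range(1, 52)][:limit]


page = asyncio.run(paginate_messages(_fake_fetch, "thread-2", "run-2"))
small_page = asyncio.run(paginate_messages(_fake_fetch, "thread-2", "run-2", limit=200))
```

Fetching one probe row avoids a separate count query, which is why the tests expect the store to be called with `limit=51` for the default page size and `limit=11` for `?limit=10`.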
def _make_feedback_repo(rows: list[dict]):
    """Return an AsyncMock feedback repo whose list_by_run() returns rows."""
    repo = MagicMock()
    repo.list_by_run = AsyncMock(return_value=rows)
    return repo


def _make_feedback(run_id: str, idx: int) -> dict:
    return {"id": f"fb-{idx}", "run_id": run_id, "thread_id": "thread-x", "value": "up"}


# ---------------------------------------------------------------------------
# TestRunFeedback
# ---------------------------------------------------------------------------


class TestRunFeedback:
    def test_returns_list_of_feedback_dicts(self):
        """GET /api/runs/{run_id}/feedback returns a list of feedback dicts."""
        run_record = {"run_id": "run-fb-1", "thread_id": "thread-fb-1"}
        rows = [_make_feedback("run-fb-1", i) for i in range(3)]
        app = _make_app(
            run_store=_make_run_store(run_record),
            feedback_repo=_make_feedback_repo(rows),
        )
        with TestClient(app) as client:
            response = client.get("/api/runs/run-fb-1/feedback")
        assert response.status_code == 200
        body = response.json()
        assert isinstance(body, list)
        assert len(body) == 3

    def test_404_when_run_not_found(self):
        """Returns 404 when run store returns None."""
        app = _make_app(
            run_store=_make_run_store(None),
            feedback_repo=_make_feedback_repo([]),
        )
        with TestClient(app) as client:
            response = client.get("/api/runs/missing-run/feedback")
        assert response.status_code == 404
        assert "missing-run" in response.json()["detail"]

    def test_empty_list_when_no_feedback(self):
        """Returns empty list when no feedback exists for the run."""
        run_record = {"run_id": "run-fb-2", "thread_id": "thread-fb-2"}
        app = _make_app(
            run_store=_make_run_store(run_record),
            feedback_repo=_make_feedback_repo([]),
        )
        with TestClient(app) as client:
            response = client.get("/api/runs/run-fb-2/feedback")
        assert response.status_code == 200
        assert response.json() == []

    def test_503_when_feedback_repo_not_configured(self):
        """Returns 503 when feedback_repo is None (no DB configured)."""
        run_record = {"run_id": "run-fb-3", "thread_id": "thread-fb-3"}
        app = _make_app(
            run_store=_make_run_store(run_record),
        )
        # Explicitly set feedback_repo to None to simulate missing DB
        app.state.feedback_repo = None
        with TestClient(app) as client:
            response = client.get("/api/runs/run-fb-3/feedback")
        assert response.status_code == 503
@@ -47,7 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
 
     # Bypass the require_permission decorator (which needs request +
-    # thread_meta_repo) — these tests cover the parsing logic.
+    # thread_store) — these tests cover the parsing logic.
     result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
 
     assert result.suggestions == ["Q1", "Q2", "Q3"]
@@ -67,7 +67,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
 
     # Bypass the require_permission decorator (which needs request +
-    # thread_meta_repo) — these tests cover the parsing logic.
+    # thread_store) — these tests cover the parsing logic.
     result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
 
     assert result.suggestions == ["Q1", "Q2"]
@@ -87,7 +87,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
 
     # Bypass the require_permission decorator (which needs request +
-    # thread_meta_repo) — these tests cover the parsing logic.
+    # thread_store) — these tests cover the parsing logic.
     result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
 
     assert result.suggestions == ["Q1", "Q2"]
@@ -104,7 +104,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
 
     # Bypass the require_permission decorator (which needs request +
-    # thread_meta_repo) — these tests cover the parsing logic.
+    # thread_store) — these tests cover the parsing logic.
     result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
 
     assert result.suggestions == []
@@ -43,8 +43,8 @@ class TestThreadMetaRepository:
     @pytest.mark.anyio
     async def test_create_with_owner_and_display_name(self, tmp_path):
         repo = await _make_repo(tmp_path)
-        record = await repo.create("t1", owner_id="user1", display_name="My Thread")
-        assert record["owner_id"] == "user1"
+        record = await repo.create("t1", user_id="user1", display_name="My Thread")
+        assert record["user_id"] == "user1"
         assert record["display_name"] == "My Thread"
         await _cleanup()
 
@@ -61,26 +61,6 @@ class TestThreadMetaRepository:
         assert await repo.get("nonexistent") is None
         await _cleanup()
 
-    @pytest.mark.anyio
-    async def test_list_by_owner(self, tmp_path):
-        repo = await _make_repo(tmp_path)
-        await repo.create("t1", owner_id="user1")
-        await repo.create("t2", owner_id="user1")
-        await repo.create("t3", owner_id="user2")
-        results = await repo.list_by_owner("user1")
-        assert len(results) == 2
-        assert all(r["owner_id"] == "user1" for r in results)
-        await _cleanup()
-
-    @pytest.mark.anyio
-    async def test_list_by_owner_with_limit_and_offset(self, tmp_path):
-        repo = await _make_repo(tmp_path)
-        for i in range(5):
-            await repo.create(f"t{i}", owner_id="user1")
-        results = await repo.list_by_owner("user1", limit=2, offset=1)
-        assert len(results) == 2
-        await _cleanup()
-
     @pytest.mark.anyio
     async def test_check_access_no_record_allows(self, tmp_path):
         repo = await _make_repo(tmp_path)
@@ -90,23 +70,23 @@ class TestThreadMetaRepository:
     @pytest.mark.anyio
     async def test_check_access_owner_matches(self, tmp_path):
         repo = await _make_repo(tmp_path)
-        await repo.create("t1", owner_id="user1")
+        await repo.create("t1", user_id="user1")
         assert await repo.check_access("t1", "user1") is True
         await _cleanup()
 
     @pytest.mark.anyio
     async def test_check_access_owner_mismatch(self, tmp_path):
         repo = await _make_repo(tmp_path)
-        await repo.create("t1", owner_id="user1")
+        await repo.create("t1", user_id="user1")
         assert await repo.check_access("t1", "user2") is False
         await _cleanup()
 
     @pytest.mark.anyio
     async def test_check_access_no_owner_allows_all(self, tmp_path):
         repo = await _make_repo(tmp_path)
-        # Explicit owner_id=None to bypass the new AUTO default that
+        # Explicit user_id=None to bypass the new AUTO default that
         # would otherwise pick up the test user from the autouse fixture.
-        await repo.create("t1", owner_id=None)
+        await repo.create("t1", user_id=None)
         assert await repo.check_access("t1", "anyone") is True
         await _cleanup()
 
@@ -125,27 +105,27 @@ class TestThreadMetaRepository:
     @pytest.mark.anyio
     async def test_check_access_strict_owner_match_allowed(self, tmp_path):
         repo = await _make_repo(tmp_path)
-        await repo.create("t1", owner_id="user1")
+        await repo.create("t1", user_id="user1")
         assert await repo.check_access("t1", "user1", require_existing=True) is True
         await _cleanup()
 
     @pytest.mark.anyio
     async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
         repo = await _make_repo(tmp_path)
-        await repo.create("t1", owner_id="user1")
+        await repo.create("t1", user_id="user1")
         assert await repo.check_access("t1", "user2", require_existing=True) is False
         await _cleanup()
 
     @pytest.mark.anyio
     async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
-        """Even in strict mode, a row with NULL owner_id stays shared.
+        """Even in strict mode, a row with NULL user_id stays shared.
 
         The strict flag tightens the *missing row* case, not the *shared
         row* case — legacy pre-auth rows that survived a clean migration
         without an owner are still everyone's.
         """
         repo = await _make_repo(tmp_path)
-        await repo.create("t1", owner_id=None)
+        await repo.create("t1", user_id=None)
         assert await repo.check_access("t1", "anyone", require_existing=True) is True
         await _cleanup()
 
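The access rules these `check_access` tests assert can be restated as a small pure function. This is a sketch under stated assumptions, not the repository's implementation: the real method is async and looks a thread up by id, whereas the hypothetical `check_access_rule` below takes the already-fetched record (or `None`). The strict-mode denial for a missing row is inferred from the docstring above ("the strict flag tightens the *missing row* case"), since that test is not shown in this part of the diff.

```python
from typing import Optional


def check_access_rule(
    record: Optional[dict],
    user_id: str,
    require_existing: bool = False,
) -> bool:
    """Hypothetical restatement of the rules exercised by the tests."""
    if record is None:
        # Missing row: allowed by default, denied when require_existing=True
        # (the strict case is an inference from the test docstring).
        return not require_existing
    owner = record.get("user_id")
    if owner is None:
        # Shared row without an owner: open to everyone, even in strict mode.
        return True
    # Owned row: only the owner gets access.
    return owner == user_id


owned = {"thread_id": "t1", "user_id": "user1"}
shared = {"thread_id": "t1", "user_id": None}
```

Keeping the rule as a pure function makes the three cases (missing row, shared row, owned row) easy to check without a database fixture.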
Some files were not shown because too many files have changed in this diff.