mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e60621d519 | |||
| f7a6ca8364 | |||
| 2540acd5f7 | |||
| b2704525a0 | |||
| 00e0e9a49a |
+8
-20
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||||
|
|
||||||
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
|
||||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||||
@@ -216,9 +216,6 @@ FastAPI application on port 8001 with health check at `GET /health`.
|
|||||||
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||||
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
||||||
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
||||||
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
|
|
||||||
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
|
||||||
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
|
||||||
|
|
||||||
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
||||||
|
|
||||||
@@ -232,7 +229,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
|||||||
|
|
||||||
**Virtual Path System**:
|
**Virtual Path System**:
|
||||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||||
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||||
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
||||||
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
||||||
|
|
||||||
@@ -272,7 +269,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
|||||||
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
|
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
|
||||||
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
|
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
|
||||||
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
|
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
|
||||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||||
- `image_search/` - Image search via DuckDuckGo
|
- `image_search/` - Image search via DuckDuckGo
|
||||||
|
|
||||||
### MCP System (`packages/harness/deerflow/mcp/`)
|
### MCP System (`packages/harness/deerflow/mcp/`)
|
||||||
@@ -341,27 +338,18 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
|
|||||||
|
|
||||||
**Components**:
|
**Components**:
|
||||||
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
|
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
|
||||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
|
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
|
||||||
- `prompt.py` - Prompt templates for memory updates
|
- `prompt.py` - Prompt templates for memory updates
|
||||||
- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
|
|
||||||
|
|
||||||
**Per-User Isolation**:
|
**Data Structure** (stored in `backend/.deer-flow/memory.json`):
|
||||||
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
|
||||||
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
|
||||||
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
|
||||||
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
|
||||||
- Absolute `storage_path` in config opts out of per-user isolation
|
|
||||||
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
|
||||||
|
|
||||||
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
|
||||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||||
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
|
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
|
||||||
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
|
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
|
||||||
|
|
||||||
**Workflow**:
|
**Workflow**:
|
||||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
|
1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
|
||||||
2. Queue debounces (30s default), batches updates, deduplicates per-thread
|
2. Queue debounces (30s default), batches updates, deduplicates per-thread
|
||||||
3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
|
3. Background thread invokes LLM to extract context updates and facts
|
||||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||||
|
|
||||||
@@ -369,7 +357,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
|||||||
|
|
||||||
**Configuration** (`config.yaml` → `memory`):
|
**Configuration** (`config.yaml` → `memory`):
|
||||||
- `enabled` / `injection_enabled` - Master switches
|
- `enabled` / `injection_enabled` - Master switches
|
||||||
- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
|
- `storage_path` - Path to memory.json
|
||||||
- `debounce_seconds` - Wait time before processing (default: 30)
|
- `debounce_seconds` - Wait time before processing (default: 30)
|
||||||
- `model_name` - LLM for updates (null = default model)
|
- `model_name` - LLM for updates (null = default model)
|
||||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from app.channels.base import Channel
|
|||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -345,9 +344,8 @@ class FeishuChannel(Channel):
|
|||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
user_id = get_effective_user_id()
|
paths.ensure_thread_dirs(thread_id)
|
||||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
|
||||||
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
|
|
||||||
|
|
||||||
ext = "png" if type == "image" else "bin"
|
ext = "png" if type == "image" else "bin"
|
||||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from langgraph_sdk.errors import ConflictError
|
|||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -342,15 +341,14 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
|
|||||||
|
|
||||||
attachments: list[ResolvedAttachment] = []
|
attachments: list[ResolvedAttachment] = []
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
user_id = get_effective_user_id()
|
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
|
||||||
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
|
|
||||||
for virtual_path in artifacts:
|
for virtual_path in artifacts:
|
||||||
# Security: only allow files from the agent outputs directory
|
# Security: only allow files from the agent outputs directory
|
||||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
|
actual = paths.resolve_virtual_path(thread_id, virtual_path)
|
||||||
# Verify the resolved path is actually under the outputs directory
|
# Verify the resolved path is actually under the outputs directory
|
||||||
# (guards against path-traversal even after prefix check)
|
# (guards against path-traversal even after prefix check)
|
||||||
try:
|
try:
|
||||||
|
|||||||
+55
-37
@@ -2,6 +2,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import UTC
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -40,60 +41,77 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||||
"""Startup hook: handle first boot and migrate orphan threads otherwise.
|
"""Auto-create the admin user on first boot if no users exist.
|
||||||
|
|
||||||
After admin creation, migrate orphan threads from the LangGraph
|
After admin creation, migrate orphan threads from the LangGraph
|
||||||
store (metadata.user_id unset) to the admin account. This is the
|
store (metadata.owner_id unset) to the admin account. This is the
|
||||||
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
||||||
authentication have existing LangGraph thread data that needs an
|
authentication have existing LangGraph thread data that needs an
|
||||||
owner assigned.
|
owner assigned.
|
||||||
First boot (no admin exists):
|
|
||||||
- Does NOT create any user accounts automatically.
|
|
||||||
- The operator must visit ``/setup`` to create the first admin.
|
|
||||||
|
|
||||||
Subsequent boots (admin already exists):
|
No SQL persistence migration is needed: the four owner_id columns
|
||||||
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
|
||||||
existing LangGraph thread metadata that has no owner_id.
|
|
||||||
|
|
||||||
No SQL persistence migration is needed: the four user_id columns
|
|
||||||
(threads_meta, runs, run_events, feedback) only come into existence
|
(threads_meta, runs, run_events, feedback) only come into existence
|
||||||
alongside the auth module via create_all, so freshly created tables
|
alongside the auth module via create_all, so freshly created tables
|
||||||
never contain NULL-owner rows.
|
never contain NULL-owner rows. "Existing persistence DB + new auth"
|
||||||
"""
|
is not a supported upgrade path — fresh install or wipe-and-retry.
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
|
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve
|
||||||
|
races during admin creation. Only the worker that successfully
|
||||||
|
creates/updates the admin prints the password; losers silently skip.
|
||||||
|
"""
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from app.gateway.auth.credential_file import write_initial_credentials
|
||||||
from app.gateway.deps import get_local_provider
|
from app.gateway.deps import get_local_provider
|
||||||
from deerflow.persistence.engine import get_session_factory
|
|
||||||
from deerflow.persistence.user.model import UserRow
|
def _announce_credentials(email: str, password: str, *, label: str, headline: str) -> None:
|
||||||
|
"""Write the password to a 0600 file and log the path (never the secret)."""
|
||||||
|
cred_path = write_initial_credentials(email, password, label=label)
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(" %s", headline)
|
||||||
|
logger.info(" Credentials written to: %s (mode 0600)", cred_path)
|
||||||
|
logger.info(" Change it after login: Settings -> Account")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
provider = get_local_provider()
|
provider = get_local_provider()
|
||||||
admin_count = await provider.count_admin_users()
|
user_count = await provider.count_users()
|
||||||
|
|
||||||
if admin_count == 0:
|
admin = None
|
||||||
logger.info("=" * 60)
|
|
||||||
logger.info(" First boot detected — no admin account exists.")
|
|
||||||
logger.info(" Visit /setup to complete admin account creation.")
|
|
||||||
logger.info("=" * 60)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Admin already exists — run orphan thread migration for any
|
if user_count == 0:
|
||||||
# LangGraph thread metadata that pre-dates the auth module.
|
password = secrets.token_urlsafe(16)
|
||||||
sf = get_session_factory()
|
try:
|
||||||
if sf is None:
|
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
|
||||||
return
|
except ValueError:
|
||||||
|
return # Another worker already created the admin.
|
||||||
|
_announce_credentials(admin.email, password, label="initial", headline="Admin account created on first boot")
|
||||||
|
else:
|
||||||
|
# Admin exists but setup never completed — reset password so operator
|
||||||
|
# can always find it in the console without needing the CLI.
|
||||||
|
# Multi-worker guard: if admin was created less than 30s ago, another
|
||||||
|
# worker just created it and will print the password — skip reset.
|
||||||
|
admin = await provider.get_user_by_email("admin@deerflow.dev")
|
||||||
|
if admin and admin.needs_setup:
|
||||||
|
import time
|
||||||
|
|
||||||
async with sf() as session:
|
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
|
||||||
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
if age >= 30:
|
||||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
from app.gateway.auth.password import hash_password_async
|
||||||
|
|
||||||
if row is None:
|
password = secrets.token_urlsafe(16)
|
||||||
return # Should not happen (admin_count > 0 above), but be safe.
|
admin.password_hash = await hash_password_async(password)
|
||||||
|
admin.token_version += 1
|
||||||
|
await provider.update_user(admin)
|
||||||
|
_announce_credentials(admin.email, password, label="reset", headline="Admin account setup incomplete — password reset")
|
||||||
|
|
||||||
admin_id = str(row.id)
|
if admin is None:
|
||||||
|
return # Nothing to bind orphans to.
|
||||||
|
|
||||||
|
admin_id = str(admin.id)
|
||||||
|
|
||||||
# LangGraph store orphan migration — non-fatal.
|
# LangGraph store orphan migration — non-fatal.
|
||||||
# This covers the "no-auth → with-auth" upgrade path for users
|
# This covers the "no-auth → with-auth" upgrade path for users
|
||||||
# whose existing LangGraph thread metadata has no user_id set.
|
# whose existing LangGraph thread metadata has no owner_id set.
|
||||||
store = getattr(app.state, "store", None)
|
store = getattr(app.state, "store", None)
|
||||||
if store is not None:
|
if store is not None:
|
||||||
try:
|
try:
|
||||||
@@ -125,7 +143,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500):
|
|||||||
|
|
||||||
|
|
||||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||||
"""Migrate LangGraph store threads with no user_id to the given admin.
|
"""Migrate LangGraph store threads with no owner_id to the given admin.
|
||||||
|
|
||||||
Uses cursor pagination so all orphans are migrated regardless of
|
Uses cursor pagination so all orphans are migrated regardless of
|
||||||
count. Returns the number of rows migrated.
|
count. Returns the number of rows migrated.
|
||||||
@@ -133,8 +151,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
|||||||
migrated = 0
|
migrated = 0
|
||||||
async for item in _iter_store_items(store, ("threads",)):
|
async for item in _iter_store_items(store, ("threads",)):
|
||||||
metadata = item.value.get("metadata", {})
|
metadata = item.value.get("metadata", {})
|
||||||
if not metadata.get("user_id"):
|
if not metadata.get("owner_id"):
|
||||||
metadata["user_id"] = admin_user_id
|
metadata["owner_id"] = admin_user_id
|
||||||
item.value["metadata"] = metadata
|
item.value["metadata"] = metadata
|
||||||
await store.aput(("threads",), item.key, item.value)
|
await store.aput(("threads",), item.key, item.value)
|
||||||
migrated += 1
|
migrated += 1
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ class AuthErrorCode(StrEnum):
|
|||||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||||
NOT_AUTHENTICATED = "not_authenticated"
|
NOT_AUTHENTICATED = "not_authenticated"
|
||||||
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
|
||||||
|
|
||||||
|
|
||||||
class TokenError(StrEnum):
|
class TokenError(StrEnum):
|
||||||
|
|||||||
@@ -78,10 +78,6 @@ class LocalAuthProvider(AuthProvider):
|
|||||||
"""Return total number of registered users."""
|
"""Return total number of registered users."""
|
||||||
return await self._repo.count_users()
|
return await self._repo.count_users()
|
||||||
|
|
||||||
async def count_admin_users(self) -> int:
|
|
||||||
"""Return number of admin users."""
|
|
||||||
return await self._repo.count_admin_users()
|
|
||||||
|
|
||||||
async def update_user(self, user: User) -> User:
|
async def update_user(self, user: User) -> User:
|
||||||
"""Update an existing user."""
|
"""Update an existing user."""
|
||||||
return await self._repo.update_user(user)
|
return await self._repo.update_user(user)
|
||||||
|
|||||||
@@ -83,11 +83,6 @@ class UserRepository(ABC):
|
|||||||
"""Return total number of registered users."""
|
"""Return total number of registered users."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def count_admin_users(self) -> int:
|
|
||||||
"""Return number of users with system_role == 'admin'."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||||
"""Get user by OAuth provider and ID.
|
"""Get user by OAuth provider and ID.
|
||||||
|
|||||||
@@ -114,11 +114,6 @@ class SQLiteUserRepository(UserRepository):
|
|||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
return await session.scalar(stmt) or 0
|
return await session.scalar(stmt) or 0
|
||||||
|
|
||||||
async def count_admin_users(self) -> int:
|
|
||||||
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
|
|
||||||
async with self._sf() as session:
|
|
||||||
return await session.scalar(stmt) or 0
|
|
||||||
|
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ _PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
|||||||
"/api/v1/auth/register",
|
"/api/v1/auth/register",
|
||||||
"/api/v1/auth/logout",
|
"/api/v1/auth/logout",
|
||||||
"/api/v1/auth/setup-status",
|
"/api/v1/auth/setup-status",
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -233,18 +233,18 @@ def require_permission(
|
|||||||
# (``threads_meta`` table). We verify ownership via
|
# (``threads_meta`` table). We verify ownership via
|
||||||
# ``ThreadMetaStore.check_access``: it returns True for
|
# ``ThreadMetaStore.check_access``: it returns True for
|
||||||
# missing rows (untracked legacy thread) and for rows whose
|
# missing rows (untracked legacy thread) and for rows whose
|
||||||
# ``user_id`` is NULL (shared / pre-auth data), so this is
|
# ``owner_id`` is NULL (shared / pre-auth data), so this is
|
||||||
# strict-deny rather than strict-allow — only an *existing*
|
# strict-deny rather than strict-allow — only an *existing*
|
||||||
# row with a *different* user_id triggers 404.
|
# row with a *different* owner_id triggers 404.
|
||||||
if owner_check:
|
if owner_check:
|
||||||
thread_id = kwargs.get("thread_id")
|
thread_id = kwargs.get("thread_id")
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
|
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
thread_store = get_thread_store(request)
|
thread_meta_repo = get_thread_meta_repo(request)
|
||||||
allowed = await thread_store.check_access(
|
allowed = await thread_meta_repo.check_access(
|
||||||
thread_id,
|
thread_id,
|
||||||
str(auth.user.id),
|
str(auth.user.id),
|
||||||
require_existing=require_existing,
|
require_existing=require_existing,
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ _AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
|||||||
"/api/v1/auth/login/local",
|
"/api/v1/auth/login/local",
|
||||||
"/api/v1/auth/logout",
|
"/api/v1/auth/logout",
|
||||||
"/api/v1/auth/register",
|
"/api/v1/auth/register",
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+26
-35
@@ -1,32 +1,25 @@
|
|||||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||||
|
|
||||||
**Getters** (used by routers): raise 503 when a required dependency is
|
**Getters** (used by routers): raise 503 when a required dependency is
|
||||||
missing, except ``get_store`` which returns ``None``.
|
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
||||||
|
``None``.
|
||||||
|
|
||||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, TypeVar, cast
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from langgraph.types import Checkpointer
|
|
||||||
|
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
from deerflow.runtime import RunContext, RunManager
|
||||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -38,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
async with langgraph_runtime(app):
|
async with langgraph_runtime(app):
|
||||||
yield
|
yield
|
||||||
"""
|
"""
|
||||||
|
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||||
from deerflow.runtime import make_store, make_stream_bridge
|
from deerflow.runtime import make_store, make_stream_bridge
|
||||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
|
||||||
from deerflow.runtime.events.store import make_run_event_store
|
from deerflow.runtime.events.store import make_run_event_store
|
||||||
|
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
@@ -60,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
if sf is not None:
|
if sf is not None:
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
from deerflow.persistence.feedback import FeedbackRepository
|
||||||
from deerflow.persistence.run import RunRepository
|
from deerflow.persistence.run import RunRepository
|
||||||
|
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||||
|
|
||||||
app.state.run_store = RunRepository(sf)
|
app.state.run_store = RunRepository(sf)
|
||||||
app.state.feedback_repo = FeedbackRepository(sf)
|
app.state.feedback_repo = FeedbackRepository(sf)
|
||||||
|
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
||||||
else:
|
else:
|
||||||
|
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
|
||||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
|
||||||
app.state.run_store = MemoryRunStore()
|
app.state.run_store = MemoryRunStore()
|
||||||
app.state.feedback_repo = None
|
app.state.feedback_repo = None
|
||||||
|
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
|
||||||
from deerflow.persistence.thread_meta import make_thread_store
|
|
||||||
|
|
||||||
app.state.thread_store = make_thread_store(sf, app.state.store)
|
|
||||||
|
|
||||||
# Run event store (has its own factory with config-driven backend selection)
|
# Run event store (has its own factory with config-driven backend selection)
|
||||||
run_events_config = getattr(config, "run_events", None)
|
run_events_config = getattr(config, "run_events", None)
|
||||||
@@ -87,29 +80,29 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Getters – called by routers per-request
|
# Getters -- called by routers per-request
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _require(attr: str, label: str) -> Callable[[Request], T]:
|
def _require(attr: str, label: str):
|
||||||
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
||||||
|
|
||||||
def dep(request: Request) -> T:
|
def dep(request: Request):
|
||||||
val = getattr(request.app.state, attr, None)
|
val = getattr(request.app.state, attr, None)
|
||||||
if val is None:
|
if val is None:
|
||||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||||
return cast(T, val)
|
return val
|
||||||
|
|
||||||
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
||||||
return dep
|
return dep
|
||||||
|
|
||||||
|
|
||||||
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
|
get_stream_bridge = _require("stream_bridge", "Stream bridge")
|
||||||
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
|
get_run_manager = _require("run_manager", "Run manager")
|
||||||
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
|
get_checkpointer = _require("checkpointer", "Checkpointer")
|
||||||
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
|
get_run_event_store = _require("run_event_store", "Run event store")
|
||||||
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
|
get_feedback_repo = _require("feedback_repo", "Feedback")
|
||||||
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
|
get_run_store = _require("run_store", "Run store")
|
||||||
|
|
||||||
|
|
||||||
def get_store(request: Request):
|
def get_store(request: Request):
|
||||||
@@ -117,18 +110,16 @@ def get_store(request: Request):
|
|||||||
return getattr(request.app.state, "store", None)
|
return getattr(request.app.state, "store", None)
|
||||||
|
|
||||||
|
|
||||||
def get_thread_store(request: Request) -> ThreadMetaStore:
|
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
|
||||||
"""Return the thread metadata store (SQL or memory-backed)."""
|
|
||||||
val = getattr(request.app.state, "thread_store", None)
|
|
||||||
if val is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def get_run_context(request: Request) -> RunContext:
|
def get_run_context(request: Request) -> RunContext:
|
||||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
||||||
|
|
||||||
Returns a *base* context with infrastructure dependencies.
|
Returns a *base* context with infrastructure dependencies. Callers that
|
||||||
|
need per-run fields (e.g. ``follow_up_to_run_id``) should use
|
||||||
|
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
|
||||||
|
to :func:`run_agent`.
|
||||||
"""
|
"""
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
@@ -137,7 +128,7 @@ def get_run_context(request: Request) -> RunContext:
|
|||||||
store=get_store(request),
|
store=get_store(request),
|
||||||
event_store=get_run_event_store(request),
|
event_store=get_run_event_store(request),
|
||||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||||
thread_store=get_thread_store(request),
|
thread_meta_repo=get_thread_meta_repo(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -93,14 +93,14 @@ async def authenticate(request):
|
|||||||
|
|
||||||
@auth.on
|
@auth.on
|
||||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||||
"""Inject user_id metadata on writes; filter by user_id on reads.
|
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
||||||
|
|
||||||
Gateway stores thread ownership as ``metadata.user_id``.
|
Gateway stores thread ownership as ``metadata.owner_id``.
|
||||||
This handler ensures LangGraph Server enforces the same isolation.
|
This handler ensures LangGraph Server enforces the same isolation.
|
||||||
"""
|
"""
|
||||||
# On create/update: stamp user_id into metadata
|
# On create/update: stamp owner_id into metadata
|
||||||
metadata = value.setdefault("metadata", {})
|
metadata = value.setdefault("metadata", {})
|
||||||
metadata["user_id"] = ctx.user.identity
|
metadata["owner_id"] = ctx.user.identity
|
||||||
|
|
||||||
# Return filter dict — LangGraph applies it to search/read/delete
|
# Return filter dict — LangGraph applies it to search/read/delete
|
||||||
return {"user_id": ctx.user.identity}
|
return {"owner_id": ctx.user.identity}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||||
@@ -23,7 +22,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
|||||||
HTTPException: If the path is invalid or outside allowed directories.
|
HTTPException: If the path is invalid or outside allowed directories.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
|
return get_paths().resolve_virtual_path(thread_id, virtual_path)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
status = 403 if "traversal" in str(e) else 400
|
status = 403 if "traversal" in str(e) else 400
|
||||||
raise HTTPException(status_code=status, detail=str(e))
|
raise HTTPException(status_code=status, detail=str(e))
|
||||||
|
|||||||
@@ -378,50 +378,9 @@ async def get_me(request: Request):
|
|||||||
|
|
||||||
@router.get("/setup-status")
|
@router.get("/setup-status")
|
||||||
async def setup_status():
|
async def setup_status():
|
||||||
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
|
"""Check if admin account exists. Always False after first boot."""
|
||||||
admin_count = await get_local_provider().count_admin_users()
|
user_count = await get_local_provider().count_users()
|
||||||
return {"needs_setup": admin_count == 0}
|
return {"needs_setup": user_count == 0}
|
||||||
|
|
||||||
|
|
||||||
class InitializeAdminRequest(BaseModel):
|
|
||||||
"""Request model for first-boot admin account creation."""
|
|
||||||
|
|
||||||
email: EmailStr
|
|
||||||
password: str = Field(..., min_length=8)
|
|
||||||
|
|
||||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
|
|
||||||
"""Create the first admin account on initial system setup.
|
|
||||||
|
|
||||||
Only callable when no admin exists. Returns 409 Conflict if an admin
|
|
||||||
already exists.
|
|
||||||
|
|
||||||
On success, the admin account is created with ``needs_setup=False`` and
|
|
||||||
the session cookie is set.
|
|
||||||
"""
|
|
||||||
admin_count = await get_local_provider().count_admin_users()
|
|
||||||
if admin_count > 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
|
|
||||||
except ValueError:
|
|
||||||
# DB unique-constraint race: another concurrent request beat us.
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||||
|
|||||||
@@ -30,16 +30,11 @@ class FeedbackCreateRequest(BaseModel):
|
|||||||
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||||
|
|
||||||
|
|
||||||
class FeedbackUpsertRequest(BaseModel):
|
|
||||||
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
|
||||||
comment: str | None = Field(default=None, description="Optional text feedback")
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackResponse(BaseModel):
|
class FeedbackResponse(BaseModel):
|
||||||
feedback_id: str
|
feedback_id: str
|
||||||
run_id: str
|
run_id: str
|
||||||
thread_id: str
|
thread_id: str
|
||||||
user_id: str | None = None
|
owner_id: str | None = None
|
||||||
message_id: str | None = None
|
message_id: str | None = None
|
||||||
rating: int
|
rating: int
|
||||||
comment: str | None = None
|
comment: str | None = None
|
||||||
@@ -58,57 +53,6 @@ class FeedbackStatsResponse(BaseModel):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
|
||||||
async def upsert_feedback(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
body: FeedbackUpsertRequest,
|
|
||||||
request: Request,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Create or update feedback for a run (idempotent)."""
|
|
||||||
if body.rating not in (1, -1):
|
|
||||||
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
|
||||||
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
run = await run_store.get(run_id)
|
|
||||||
if run is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
if run.get("thread_id") != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
|
||||||
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.upsert(
|
|
||||||
run_id=run_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
rating=body.rating,
|
|
||||||
user_id=user_id,
|
|
||||||
comment=body.comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_run_feedback(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete the current user's feedback for a run."""
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
deleted = await feedback_repo.delete_by_run(
|
|
||||||
thread_id=thread_id,
|
|
||||||
run_id=run_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if not deleted:
|
|
||||||
raise HTTPException(status_code=404, detail="No feedback found for this run")
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||||
async def create_feedback(
|
async def create_feedback(
|
||||||
@@ -136,7 +80,7 @@ async def create_feedback(
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
rating=body.rating,
|
rating=body.rating,
|
||||||
user_id=user_id,
|
owner_id=user_id,
|
||||||
message_id=body.message_id,
|
message_id=body.message_id,
|
||||||
comment=body.comment,
|
comment=body.comment,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from deerflow.agents.memory.updater import (
|
|||||||
update_memory_fact,
|
update_memory_fact,
|
||||||
)
|
)
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["memory"])
|
router = APIRouter(prefix="/api", tags=["memory"])
|
||||||
|
|
||||||
@@ -148,7 +147,7 @@ async def get_memory() -> MemoryResponse:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
memory_data = get_memory_data()
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@@ -168,7 +167,7 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
The reloaded memory data.
|
The reloaded memory data.
|
||||||
"""
|
"""
|
||||||
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
memory_data = reload_memory_data()
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@@ -182,7 +181,7 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
async def clear_memory() -> MemoryResponse:
|
async def clear_memory() -> MemoryResponse:
|
||||||
"""Clear all persisted memory data."""
|
"""Clear all persisted memory data."""
|
||||||
try:
|
try:
|
||||||
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
memory_data = clear_memory_data()
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||||
|
|
||||||
@@ -203,7 +202,6 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
|||||||
content=request.content,
|
content=request.content,
|
||||||
category=request.category,
|
category=request.category,
|
||||||
confidence=request.confidence,
|
confidence=request.confidence,
|
||||||
user_id=get_effective_user_id(),
|
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
@@ -223,7 +221,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
|||||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||||
"""Delete a single fact from memory by fact id."""
|
"""Delete a single fact from memory by fact id."""
|
||||||
try:
|
try:
|
||||||
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
memory_data = delete_memory_fact(fact_id)
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
@@ -247,7 +245,6 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
|||||||
content=request.content,
|
content=request.content,
|
||||||
category=request.category,
|
category=request.category,
|
||||||
confidence=request.confidence,
|
confidence=request.confidence,
|
||||||
user_id=get_effective_user_id(),
|
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
@@ -268,7 +265,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
|||||||
)
|
)
|
||||||
async def export_memory() -> MemoryResponse:
|
async def export_memory() -> MemoryResponse:
|
||||||
"""Export the current memory data."""
|
"""Export the current memory data."""
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
memory_data = get_memory_data()
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@@ -282,7 +279,7 @@ async def export_memory() -> MemoryResponse:
|
|||||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||||
"""Import and persist memory data."""
|
"""Import and persist memory data."""
|
||||||
try:
|
try:
|
||||||
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
|
memory_data = import_memory_data(request.model_dump())
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||||
|
|
||||||
@@ -340,7 +337,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
|||||||
Combined memory configuration and current data.
|
Combined memory configuration and current data.
|
||||||
"""
|
"""
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
memory_data = get_memory_data(user_id=get_effective_user_id())
|
memory_data = get_memory_data()
|
||||||
|
|
||||||
return MemoryStatusResponse(
|
return MemoryStatusResponse(
|
||||||
config=MemoryConfigResponse(
|
config=MemoryConfigResponse(
|
||||||
|
|||||||
@@ -11,11 +11,10 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
|
||||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||||
from app.gateway.services import sse_consumer, start_run
|
from app.gateway.services import sse_consumer, start_run
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
@@ -86,58 +85,3 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
|||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
return {"status": record.status.value, "error": record.error}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Run-scoped read endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_run(run_id: str, request: Request) -> dict:
|
|
||||||
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{run_id}/messages")
|
|
||||||
@require_permission("runs", "read")
|
|
||||||
async def run_messages(
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
limit: int = Query(default=50, le=200, ge=1),
|
|
||||||
before_seq: int | None = Query(default=None),
|
|
||||||
after_seq: int | None = Query(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""Return paginated messages for a run (cursor-based).
|
|
||||||
|
|
||||||
Pagination:
|
|
||||||
- after_seq: messages with seq > after_seq (forward)
|
|
||||||
- before_seq: messages with seq < before_seq (backward)
|
|
||||||
- neither: latest messages
|
|
||||||
|
|
||||||
Response: { data: [...], has_more: bool }
|
|
||||||
"""
|
|
||||||
run = await _resolve_run(run_id, request)
|
|
||||||
event_store = get_run_event_store(request)
|
|
||||||
rows = await event_store.list_messages_by_run(
|
|
||||||
run["thread_id"],
|
|
||||||
run_id,
|
|
||||||
limit=limit + 1,
|
|
||||||
before_seq=before_seq,
|
|
||||||
after_seq=after_seq,
|
|
||||||
)
|
|
||||||
has_more = len(rows) > limit
|
|
||||||
data = rows[:limit] if has_more else rows
|
|
||||||
return {"data": data, "has_more": has_more}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{run_id}/feedback")
|
|
||||||
@require_permission("runs", "read")
|
|
||||||
async def run_feedback(run_id: str, request: Request) -> list[dict]:
|
|
||||||
"""Return all feedback for a run."""
|
|
||||||
run = await _resolve_run(run_id, request)
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
return await feedback_repo.list_by_run(run["thread_id"], run_id)
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from fastapi.responses import Response, StreamingResponse
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||||
from app.gateway.services import sse_consumer, start_run
|
from app.gateway.services import sse_consumer, start_run
|
||||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||||
|
|
||||||
@@ -54,6 +54,7 @@ class RunCreateRequest(BaseModel):
|
|||||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||||
|
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
|
||||||
|
|
||||||
|
|
||||||
class RunResponse(BaseModel):
|
class RunResponse(BaseModel):
|
||||||
@@ -290,67 +291,17 @@ async def list_thread_messages(
|
|||||||
before_seq: int | None = Query(default=None),
|
before_seq: int | None = Query(default=None),
|
||||||
after_seq: int | None = Query(default=None),
|
after_seq: int | None = Query(default=None),
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Return displayable messages for a thread (across all runs), with feedback attached."""
|
"""Return displayable messages for a thread (across all runs)."""
|
||||||
event_store = get_run_event_store(request)
|
event_store = get_run_event_store(request)
|
||||||
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
|
||||||
|
|
||||||
# Attach feedback to the last AI message of each run
|
|
||||||
feedback_repo = get_feedback_repo(request)
|
|
||||||
user_id = await get_current_user(request)
|
|
||||||
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
|
|
||||||
|
|
||||||
# Find the last ai_message per run_id
|
|
||||||
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
if msg.get("event_type") == "ai_message":
|
|
||||||
last_ai_per_run[msg["run_id"]] = i
|
|
||||||
|
|
||||||
# Attach feedback field
|
|
||||||
last_ai_indices = set(last_ai_per_run.values())
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
if i in last_ai_indices:
|
|
||||||
run_id = msg["run_id"]
|
|
||||||
fb = feedback_map.get(run_id)
|
|
||||||
msg["feedback"] = (
|
|
||||||
{
|
|
||||||
"feedback_id": fb["feedback_id"],
|
|
||||||
"rating": fb["rating"],
|
|
||||||
"comment": fb.get("comment"),
|
|
||||||
}
|
|
||||||
if fb
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
msg["feedback"] = None
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/messages")
|
@router.get("/{thread_id}/runs/{run_id}/messages")
|
||||||
@require_permission("runs", "read", owner_check=True)
|
@require_permission("runs", "read", owner_check=True)
|
||||||
async def list_run_messages(
|
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
|
||||||
thread_id: str,
|
"""Return displayable messages for a specific run."""
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
limit: int = Query(default=50, le=200, ge=1),
|
|
||||||
before_seq: int | None = Query(default=None),
|
|
||||||
after_seq: int | None = Query(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""Return paginated messages for a specific run.
|
|
||||||
|
|
||||||
Response: { data: [...], has_more: bool }
|
|
||||||
"""
|
|
||||||
event_store = get_run_event_store(request)
|
event_store = get_run_event_store(request)
|
||||||
rows = await event_store.list_messages_by_run(
|
return await event_store.list_messages_by_run(thread_id, run_id)
|
||||||
thread_id,
|
|
||||||
run_id,
|
|
||||||
limit=limit + 1,
|
|
||||||
before_seq=before_seq,
|
|
||||||
after_seq=after_seq,
|
|
||||||
)
|
|
||||||
has_more = len(rows) > limit
|
|
||||||
data = rows[:limit] if has_more else rows
|
|
||||||
return {"data": data, "has_more": has_more}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/events")
|
@router.get("/{thread_id}/runs/{run_id}/events")
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ matching the LangGraph Platform wire format expected by the
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -26,7 +25,6 @@ from app.gateway.deps import get_checkpointer
|
|||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||||
@@ -36,7 +34,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"])
|
|||||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
||||||
# inbound model below so a malicious client cannot reflect a forged
|
# inbound model below so a malicious client cannot reflect a forged
|
||||||
# owner identity through the API surface. Defense-in-depth — the
|
# owner identity through the API surface. Defense-in-depth — the
|
||||||
# row-level invariant is still ``threads_meta.user_id`` populated from
|
# row-level invariant is still ``threads_meta.owner_id`` populated from
|
||||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
# the auth contextvar; this list closes the metadata-blob echo gap.
|
||||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
||||||
|
|
||||||
@@ -144,11 +142,11 @@ class ThreadHistoryRequest(BaseModel):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
|
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||||
"""Delete local persisted filesystem data for a thread."""
|
"""Delete local persisted filesystem data for a thread."""
|
||||||
path_manager = paths or get_paths()
|
path_manager = paths or get_paths()
|
||||||
try:
|
try:
|
||||||
path_manager.delete_thread_dir(thread_id, user_id=user_id)
|
path_manager.delete_thread_dir(thread_id)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -196,10 +194,10 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
|||||||
and removes the thread_meta row from the configured ThreadMetaStore
|
and removes the thread_meta row from the configured ThreadMetaStore
|
||||||
(sqlite or memory).
|
(sqlite or memory).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
# Clean local filesystem
|
# Clean local filesystem
|
||||||
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
|
response = _delete_thread_data(thread_id)
|
||||||
|
|
||||||
# Remove checkpoints (best-effort)
|
# Remove checkpoints (best-effort)
|
||||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||||
@@ -213,8 +211,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
|||||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
# Remove thread_meta row (best-effort) — required for sqlite backend
|
||||||
# so the deleted thread no longer appears in /threads/search.
|
# so the deleted thread no longer appears in /threads/search.
|
||||||
try:
|
try:
|
||||||
thread_store = get_thread_store(request)
|
thread_meta_repo = get_thread_meta_repo(request)
|
||||||
await thread_store.delete(thread_id)
|
await thread_meta_repo.delete(thread_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -229,17 +227,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
and an empty checkpoint (so state endpoints work immediately).
|
and an empty checkpoint (so state endpoints work immediately).
|
||||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
thread_store = get_thread_store(request)
|
thread_meta_repo = get_thread_meta_repo(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = time.time()
|
now = time.time()
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
# Idempotency: return existing record when already present
|
||||||
existing_record = await thread_store.get(thread_id)
|
existing_record = await thread_meta_repo.get(thread_id)
|
||||||
if existing_record is not None:
|
if existing_record is not None:
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -251,7 +249,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
|
|
||||||
# Write thread_meta so the thread appears in /threads/search immediately
|
# Write thread_meta so the thread appears in /threads/search immediately
|
||||||
try:
|
try:
|
||||||
await thread_store.create(
|
await thread_meta_repo.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
assistant_id=getattr(body, "assistant_id", None),
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
@@ -295,9 +293,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
|||||||
Delegates to the configured ThreadMetaStore implementation
|
Delegates to the configured ThreadMetaStore implementation
|
||||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
repo = get_thread_store(request)
|
repo = get_thread_meta_repo(request)
|
||||||
rows = await repo.search(
|
rows = await repo.search(
|
||||||
metadata=body.metadata or None,
|
metadata=body.metadata or None,
|
||||||
status=body.status,
|
status=body.status,
|
||||||
@@ -322,22 +320,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
|||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||||
"""Merge metadata into a thread record."""
|
"""Merge metadata into a thread record."""
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
thread_store = get_thread_store(request)
|
thread_meta_repo = get_thread_meta_repo(request)
|
||||||
record = await thread_store.get(thread_id)
|
record = await thread_meta_repo.get(thread_id)
|
||||||
if record is None:
|
if record is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
||||||
try:
|
try:
|
||||||
await thread_store.update_metadata(thread_id, body.metadata)
|
await thread_meta_repo.update_metadata(thread_id, body.metadata)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
||||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||||
|
|
||||||
# Re-read to get the merged metadata + refreshed updated_at
|
# Re-read to get the merged metadata + refreshed updated_at
|
||||||
record = await thread_store.get(thread_id) or record
|
record = await thread_meta_repo.get(thread_id) or record
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
status=record.get("status", "idle"),
|
status=record.get("status", "idle"),
|
||||||
@@ -356,12 +354,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
execution status from the checkpointer. Falls back to the checkpointer
|
execution status from the checkpointer. Falls back to the checkpointer
|
||||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
thread_store = get_thread_store(request)
|
thread_meta_repo = get_thread_meta_repo(request)
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
|
|
||||||
record: dict | None = await thread_store.get(thread_id)
|
record: dict | None = await thread_meta_repo.get(thread_id)
|
||||||
|
|
||||||
# Derive accurate status from the checkpointer
|
# Derive accurate status from the checkpointer
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
@@ -404,7 +402,6 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||||
@@ -443,10 +440,8 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
|||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||||
|
|
||||||
values = serialize_channel_values(channel_values)
|
|
||||||
|
|
||||||
return ThreadStateResponse(
|
return ThreadStateResponse(
|
||||||
values=values,
|
values=serialize_channel_values(channel_values),
|
||||||
next=next_tasks,
|
next=next_tasks,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||||
@@ -467,10 +462,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
||||||
change immediately in both sqlite and memory backends.
|
change immediately in both sqlite and memory backends.
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_store
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
thread_store = get_thread_store(request)
|
thread_meta_repo = get_thread_meta_repo(request)
|
||||||
|
|
||||||
# checkpoint_ns must be present in the config for aput — default to ""
|
# checkpoint_ns must be present in the config for aput — default to ""
|
||||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||||
@@ -534,7 +529,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
new_title = body.values["title"]
|
new_title = body.values["title"]
|
||||||
if new_title: # Skip empty strings and None
|
if new_title: # Skip empty strings and None
|
||||||
try:
|
try:
|
||||||
await thread_store.update_display_name(thread_id, new_title)
|
await thread_meta_repo.update_display_name(thread_id, new_title)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -587,7 +582,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
if thread_data := channel_values.get("thread_data"):
|
if thread_data := channel_values.get("thread_data"):
|
||||||
values["thread_data"] = thread_data
|
values["thread_data"] = thread_data
|
||||||
|
|
||||||
# Attach messages only to the latest checkpoint entry.
|
# Attach messages from checkpointer only for the latest checkpoint
|
||||||
if is_latest_checkpoint:
|
if is_latest_checkpoint:
|
||||||
messages = channel_values.get("messages")
|
messages = channel_values.get("messages")
|
||||||
if messages:
|
if messages:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
PathTraversalError,
|
PathTraversalError,
|
||||||
@@ -56,7 +55,7 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UploadResponse)
|
@router.post("", response_model=UploadResponse)
|
||||||
@require_permission("threads", "write", owner_check=True, require_existing=False)
|
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||||
async def upload_files(
|
async def upload_files(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -70,7 +69,7 @@ async def upload_files(
|
|||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||||
uploaded_files = []
|
uploaded_files = []
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
sandbox_provider = get_sandbox_provider()
|
||||||
@@ -148,7 +147,7 @@ async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
|||||||
enrich_file_listing(result, thread_id)
|
enrich_file_listing(result, thread_id)
|
||||||
|
|
||||||
# Gateway additionally includes the sandbox-relative path.
|
# Gateway additionally includes the sandbox-relative path.
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||||
for f in result["files"]:
|
for f in result["files"]:
|
||||||
f["path"] = str(sandbox_uploads / f["filename"])
|
f["path"] = str(sandbox_uploads / f["filename"])
|
||||||
|
|
||||||
|
|||||||
@@ -195,6 +195,21 @@ async def start_run(
|
|||||||
|
|
||||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||||
|
|
||||||
|
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
||||||
|
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
||||||
|
if follow_up_to_run_id is None:
|
||||||
|
run_store = get_run_store(request)
|
||||||
|
try:
|
||||||
|
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
|
||||||
|
if recent_runs and recent_runs[0].get("status") == "success":
|
||||||
|
follow_up_to_run_id = recent_runs[0]["run_id"]
|
||||||
|
except Exception:
|
||||||
|
pass # Don't block run creation
|
||||||
|
|
||||||
|
# Enrich base context with per-run field
|
||||||
|
if follow_up_to_run_id:
|
||||||
|
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
record = await run_mgr.create_or_reject(
|
||||||
thread_id,
|
thread_id,
|
||||||
@@ -203,6 +218,7 @@ async def start_run(
|
|||||||
metadata=body.metadata or {},
|
metadata=body.metadata or {},
|
||||||
kwargs={"input": body.input, "config": body.config},
|
kwargs={"input": body.input, "config": body.config},
|
||||||
multitask_strategy=body.multitask_strategy,
|
multitask_strategy=body.multitask_strategy,
|
||||||
|
follow_up_to_run_id=follow_up_to_run_id,
|
||||||
)
|
)
|
||||||
except ConflictError as exc:
|
except ConflictError as exc:
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
@@ -213,15 +229,15 @@ async def start_run(
|
|||||||
# even for threads that were never explicitly created via POST /threads
|
# even for threads that were never explicitly created via POST /threads
|
||||||
# (e.g. stateless runs).
|
# (e.g. stateless runs).
|
||||||
try:
|
try:
|
||||||
existing = await run_ctx.thread_store.get(thread_id)
|
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
await run_ctx.thread_store.create(
|
await run_ctx.thread_meta_repo.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=body.assistant_id,
|
assistant_id=body.assistant_id,
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -269,7 +285,7 @@ async def start_run(
|
|||||||
record.task = task
|
record.task = task
|
||||||
|
|
||||||
# Title sync is handled by worker.py's finally block which reads the
|
# Title sync is handled by worker.py's finally block which reads the
|
||||||
# title from the checkpoint and calls thread_store.update_display_name
|
# title from the checkpoint and calls thread_meta_repo.update_display_name
|
||||||
# after the run completes.
|
# after the run completes.
|
||||||
|
|
||||||
return record
|
return record
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ title:
|
|||||||
# checkpointer.py
|
# checkpointer.py
|
||||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
|
||||||
checkpointer = SqliteSaver.from_conn_string("deerflow.db")
|
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
|
||||||
```
|
```
|
||||||
|
|
||||||
```json
|
```json
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
|
||||||
from .factory import create_deerflow_agent
|
from .factory import create_deerflow_agent
|
||||||
from .features import Next, Prev, RuntimeFeatures
|
from .features import Next, Prev, RuntimeFeatures
|
||||||
from .lead_agent import make_lead_agent
|
from .lead_agent import make_lead_agent
|
||||||
@@ -17,4 +18,7 @@ __all__ = [
|
|||||||
"make_lead_agent",
|
"make_lead_agent",
|
||||||
"SandboxState",
|
"SandboxState",
|
||||||
"ThreadState",
|
"ThreadState",
|
||||||
|
"get_checkpointer",
|
||||||
|
"reset_checkpointer",
|
||||||
|
"make_checkpointer",
|
||||||
]
|
]
|
||||||
|
|||||||
+4
-4
@@ -7,12 +7,12 @@ Supported backends: memory, sqlite, postgres.
|
|||||||
|
|
||||||
Usage (e.g. FastAPI lifespan)::
|
Usage (e.g. FastAPI lifespan)::
|
||||||
|
|
||||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||||
|
|
||||||
async with make_checkpointer() as checkpointer:
|
async with make_checkpointer() as checkpointer:
|
||||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||||
|
|
||||||
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
|
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
|
|||||||
|
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.agents.checkpointer.provider import (
|
||||||
from deerflow.runtime.checkpointer.provider import (
|
|
||||||
POSTGRES_CONN_REQUIRED,
|
POSTGRES_CONN_REQUIRED,
|
||||||
POSTGRES_INSTALL,
|
POSTGRES_INSTALL,
|
||||||
SQLITE_INSTALL,
|
SQLITE_INSTALL,
|
||||||
)
|
)
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
+1
-1
@@ -7,7 +7,7 @@ Supported backends: memory, sqlite, postgres.
|
|||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
|
|
||||||
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
|
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||||
|
|
||||||
# Singleton — reused across calls, closed on process exit
|
# Singleton — reused across calls, closed on process exit
|
||||||
cp = get_checkpointer()
|
cp = get_checkpointer()
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||||
@@ -9,7 +9,6 @@ from deerflow.agents.middlewares.clarification_middleware import ClarificationMi
|
|||||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationMiddleware
|
|
||||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||||
|
|||||||
@@ -519,13 +519,12 @@ def _get_memory_context(agent_name: str | None = None) -> str:
|
|||||||
try:
|
try:
|
||||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
if not config.enabled or not config.injection_enabled:
|
if not config.enabled or not config.injection_enabled:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
memory_data = get_memory_data(agent_name)
|
||||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||||
|
|
||||||
if not memory_content.strip():
|
if not memory_content.strip():
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ class ConversationContext:
|
|||||||
messages: list[Any]
|
messages: list[Any]
|
||||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
agent_name: str | None = None
|
agent_name: str | None = None
|
||||||
user_id: str | None = None
|
|
||||||
correction_detected: bool = False
|
correction_detected: bool = False
|
||||||
reinforcement_detected: bool = False
|
reinforcement_detected: bool = False
|
||||||
|
|
||||||
@@ -45,7 +44,6 @@ class MemoryUpdateQueue:
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
messages: list[Any],
|
messages: list[Any],
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
user_id: str | None = None,
|
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -55,9 +53,6 @@ class MemoryUpdateQueue:
|
|||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
messages: The conversation messages.
|
messages: The conversation messages.
|
||||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||||
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
|
|
||||||
survives the threading.Timer boundary (ContextVar does not propagate across
|
|
||||||
raw threads).
|
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
"""
|
"""
|
||||||
@@ -76,7 +71,6 @@ class MemoryUpdateQueue:
|
|||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
user_id=user_id,
|
|
||||||
correction_detected=merged_correction_detected,
|
correction_detected=merged_correction_detected,
|
||||||
reinforcement_detected=merged_reinforcement_detected,
|
reinforcement_detected=merged_reinforcement_detected,
|
||||||
)
|
)
|
||||||
@@ -142,7 +136,6 @@ class MemoryUpdateQueue:
|
|||||||
agent_name=context.agent_name,
|
agent_name=context.agent_name,
|
||||||
correction_detected=context.correction_detected,
|
correction_detected=context.correction_detected,
|
||||||
reinforcement_detected=context.reinforcement_detected,
|
reinforcement_detected=context.reinforcement_detected,
|
||||||
user_id=context.user_id,
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||||
|
|||||||
@@ -43,17 +43,17 @@ class MemoryStorage(abc.ABC):
|
|||||||
"""Abstract base class for memory storage providers."""
|
"""Abstract base class for memory storage providers."""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data for the given agent."""
|
"""Load memory data for the given agent."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Force reload memory data for the given agent."""
|
"""Force reload memory data for the given agent."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||||
"""Save memory data for the given agent."""
|
"""Save memory data for the given agent."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -63,9 +63,9 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the file memory storage."""
|
"""Initialize the file memory storage."""
|
||||||
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
|
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||||
# Value: (memory_data, file_mtime)
|
# Value: (memory_data, file_mtime)
|
||||||
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
|
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||||
|
|
||||||
def _validate_agent_name(self, agent_name: str) -> None:
|
def _validate_agent_name(self, agent_name: str) -> None:
|
||||||
"""Validate that the agent name is safe to use in filesystem paths.
|
"""Validate that the agent name is safe to use in filesystem paths.
|
||||||
@@ -78,29 +78,21 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
if not AGENT_NAME_PATTERN.match(agent_name):
|
if not AGENT_NAME_PATTERN.match(agent_name):
|
||||||
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
||||||
|
|
||||||
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
|
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
|
||||||
"""Get the path to the memory file."""
|
"""Get the path to the memory file."""
|
||||||
if user_id is not None:
|
|
||||||
if agent_name is not None:
|
|
||||||
self._validate_agent_name(agent_name)
|
|
||||||
return get_paths().user_agent_memory_file(user_id, agent_name)
|
|
||||||
config = get_memory_config()
|
|
||||||
if config.storage_path and Path(config.storage_path).is_absolute():
|
|
||||||
return Path(config.storage_path)
|
|
||||||
return get_paths().user_memory_file(user_id)
|
|
||||||
# Legacy: no user_id
|
|
||||||
if agent_name is not None:
|
if agent_name is not None:
|
||||||
self._validate_agent_name(agent_name)
|
self._validate_agent_name(agent_name)
|
||||||
return get_paths().agent_memory_file(agent_name)
|
return get_paths().agent_memory_file(agent_name)
|
||||||
|
|
||||||
config = get_memory_config()
|
config = get_memory_config()
|
||||||
if config.storage_path:
|
if config.storage_path:
|
||||||
p = Path(config.storage_path)
|
p = Path(config.storage_path)
|
||||||
return p if p.is_absolute() else get_paths().base_dir / p
|
return p if p.is_absolute() else get_paths().base_dir / p
|
||||||
return get_paths().memory_file
|
return get_paths().memory_file
|
||||||
|
|
||||||
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data from file."""
|
"""Load memory data from file."""
|
||||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
file_path = self._get_memory_file_path(agent_name)
|
||||||
|
|
||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
@@ -113,42 +105,40 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
logger.warning("Failed to load memory file: %s", e)
|
logger.warning("Failed to load memory file: %s", e)
|
||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
|
|
||||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Load memory data (cached with file modification time check)."""
|
"""Load memory data (cached with file modification time check)."""
|
||||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
file_path = self._get_memory_file_path(agent_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||||
except OSError:
|
except OSError:
|
||||||
current_mtime = None
|
current_mtime = None
|
||||||
|
|
||||||
cache_key = (user_id, agent_name)
|
cached = self._memory_cache.get(agent_name)
|
||||||
cached = self._memory_cache.get(cache_key)
|
|
||||||
|
|
||||||
if cached is None or cached[1] != current_mtime:
|
if cached is None or cached[1] != current_mtime:
|
||||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
memory_data = self._load_memory_from_file(agent_name)
|
||||||
self._memory_cache[cache_key] = (memory_data, current_mtime)
|
self._memory_cache[agent_name] = (memory_data, current_mtime)
|
||||||
return memory_data
|
return memory_data
|
||||||
|
|
||||||
return cached[0]
|
return cached[0]
|
||||||
|
|
||||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Reload memory data from file, forcing cache invalidation."""
|
"""Reload memory data from file, forcing cache invalidation."""
|
||||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
file_path = self._get_memory_file_path(agent_name)
|
||||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
memory_data = self._load_memory_from_file(agent_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||||
except OSError:
|
except OSError:
|
||||||
mtime = None
|
mtime = None
|
||||||
|
|
||||||
cache_key = (user_id, agent_name)
|
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
|
||||||
return memory_data
|
return memory_data
|
||||||
|
|
||||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||||
"""Save memory data to file and update cache."""
|
"""Save memory data to file and update cache."""
|
||||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
file_path = self._get_memory_file_path(agent_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -165,8 +155,7 @@ class FileMemoryStorage(MemoryStorage):
|
|||||||
except OSError:
|
except OSError:
|
||||||
mtime = None
|
mtime = None
|
||||||
|
|
||||||
cache_key = (user_id, agent_name)
|
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
|
||||||
logger.info("Memory saved to %s", file_path)
|
logger.info("Memory saved to %s", file_path)
|
||||||
return True
|
return True
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
|
|||||||
@@ -27,28 +27,27 @@ def _create_empty_memory() -> dict[str, Any]:
|
|||||||
return create_empty_memory()
|
return create_empty_memory()
|
||||||
|
|
||||||
|
|
||||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||||
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
|
return get_memory_storage().save(memory_data, agent_name)
|
||||||
|
|
||||||
|
|
||||||
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Get the current memory data via storage provider."""
|
"""Get the current memory data via storage provider."""
|
||||||
return get_memory_storage().load(agent_name, user_id=user_id)
|
return get_memory_storage().load(agent_name)
|
||||||
|
|
||||||
|
|
||||||
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Reload memory data via storage provider."""
|
"""Reload memory data via storage provider."""
|
||||||
return get_memory_storage().reload(agent_name, user_id=user_id)
|
return get_memory_storage().reload(agent_name)
|
||||||
|
|
||||||
|
|
||||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Persist imported memory data via storage provider.
|
"""Persist imported memory data via storage provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_data: Full memory payload to persist.
|
memory_data: Full memory payload to persist.
|
||||||
agent_name: If provided, imports into per-agent memory.
|
agent_name: If provided, imports into per-agent memory.
|
||||||
user_id: If provided, scopes memory to a specific user.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The saved memory data after storage normalization.
|
The saved memory data after storage normalization.
|
||||||
@@ -57,15 +56,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
|
|||||||
OSError: If persisting the imported memory fails.
|
OSError: If persisting the imported memory fails.
|
||||||
"""
|
"""
|
||||||
storage = get_memory_storage()
|
storage = get_memory_storage()
|
||||||
if not storage.save(memory_data, agent_name, user_id=user_id):
|
if not storage.save(memory_data, agent_name):
|
||||||
raise OSError("Failed to save imported memory data")
|
raise OSError("Failed to save imported memory data")
|
||||||
return storage.load(agent_name, user_id=user_id)
|
return storage.load(agent_name)
|
||||||
|
|
||||||
|
|
||||||
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Clear all stored memory data and persist an empty structure."""
|
"""Clear all stored memory data and persist an empty structure."""
|
||||||
cleared_memory = create_empty_memory()
|
cleared_memory = create_empty_memory()
|
||||||
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
|
if not _save_memory_to_file(cleared_memory, agent_name):
|
||||||
raise OSError("Failed to save cleared memory data")
|
raise OSError("Failed to save cleared memory data")
|
||||||
return cleared_memory
|
return cleared_memory
|
||||||
|
|
||||||
@@ -82,8 +81,6 @@ def create_memory_fact(
|
|||||||
category: str = "context",
|
category: str = "context",
|
||||||
confidence: float = 0.5,
|
confidence: float = 0.5,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
*,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create a new fact and persist the updated memory data."""
|
"""Create a new fact and persist the updated memory data."""
|
||||||
normalized_content = content.strip()
|
normalized_content = content.strip()
|
||||||
@@ -93,7 +90,7 @@ def create_memory_fact(
|
|||||||
normalized_category = category.strip() or "context"
|
normalized_category = category.strip() or "context"
|
||||||
validated_confidence = _validate_confidence(confidence)
|
validated_confidence = _validate_confidence(confidence)
|
||||||
now = utc_now_iso_z()
|
now = utc_now_iso_z()
|
||||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
memory_data = get_memory_data(agent_name)
|
||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
facts = list(memory_data.get("facts", []))
|
facts = list(memory_data.get("facts", []))
|
||||||
facts.append(
|
facts.append(
|
||||||
@@ -108,15 +105,15 @@ def create_memory_fact(
|
|||||||
)
|
)
|
||||||
updated_memory["facts"] = facts
|
updated_memory["facts"] = facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
if not _save_memory_to_file(updated_memory, agent_name):
|
||||||
raise OSError("Failed to save memory data after creating fact")
|
raise OSError("Failed to save memory data after creating fact")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
|
|
||||||
|
|
||||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
|
||||||
"""Delete a fact by its id and persist the updated memory data."""
|
"""Delete a fact by its id and persist the updated memory data."""
|
||||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
memory_data = get_memory_data(agent_name)
|
||||||
facts = memory_data.get("facts", [])
|
facts = memory_data.get("facts", [])
|
||||||
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
||||||
if len(updated_facts) == len(facts):
|
if len(updated_facts) == len(facts):
|
||||||
@@ -125,7 +122,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id:
|
|||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
updated_memory["facts"] = updated_facts
|
updated_memory["facts"] = updated_facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
if not _save_memory_to_file(updated_memory, agent_name):
|
||||||
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
@@ -137,11 +134,9 @@ def update_memory_fact(
|
|||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
confidence: float | None = None,
|
confidence: float | None = None,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
*,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Update an existing fact and persist the updated memory data."""
|
"""Update an existing fact and persist the updated memory data."""
|
||||||
memory_data = get_memory_data(agent_name, user_id=user_id)
|
memory_data = get_memory_data(agent_name)
|
||||||
updated_memory = dict(memory_data)
|
updated_memory = dict(memory_data)
|
||||||
updated_facts: list[dict[str, Any]] = []
|
updated_facts: list[dict[str, Any]] = []
|
||||||
found = False
|
found = False
|
||||||
@@ -168,7 +163,7 @@ def update_memory_fact(
|
|||||||
|
|
||||||
updated_memory["facts"] = updated_facts
|
updated_memory["facts"] = updated_facts
|
||||||
|
|
||||||
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
|
if not _save_memory_to_file(updated_memory, agent_name):
|
||||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||||
|
|
||||||
return updated_memory
|
return updated_memory
|
||||||
@@ -281,7 +276,6 @@ class MemoryUpdater:
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
user_id: str | None = None,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Update memory based on conversation messages.
|
"""Update memory based on conversation messages.
|
||||||
|
|
||||||
@@ -291,7 +285,6 @@ class MemoryUpdater:
|
|||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
user_id: If provided, scopes memory to a specific user.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if update was successful, False otherwise.
|
True if update was successful, False otherwise.
|
||||||
@@ -305,7 +298,7 @@ class MemoryUpdater:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get current memory
|
# Get current memory
|
||||||
current_memory = get_memory_data(agent_name, user_id=user_id)
|
current_memory = get_memory_data(agent_name)
|
||||||
|
|
||||||
# Format conversation for prompt
|
# Format conversation for prompt
|
||||||
conversation_text = format_conversation_for_update(messages)
|
conversation_text = format_conversation_for_update(messages)
|
||||||
@@ -360,7 +353,7 @@ class MemoryUpdater:
|
|||||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||||
|
|
||||||
# Save
|
# Save
|
||||||
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
|
return get_memory_storage().save(updated_memory, agent_name)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||||
@@ -462,7 +455,6 @@ def update_memory_from_conversation(
|
|||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
correction_detected: bool = False,
|
correction_detected: bool = False,
|
||||||
reinforcement_detected: bool = False,
|
reinforcement_detected: bool = False,
|
||||||
user_id: str | None = None,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Convenience function to update memory from a conversation.
|
"""Convenience function to update memory from a conversation.
|
||||||
|
|
||||||
@@ -472,10 +464,9 @@ def update_memory_from_conversation(
|
|||||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||||
correction_detected: Whether recent turns include an explicit correction signal.
|
correction_detected: Whether recent turns include an explicit correction signal.
|
||||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||||
user_id: If provided, scopes memory to a specific user.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if successful, False otherwise.
|
True if successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
updater = MemoryUpdater()
|
updater = MemoryUpdater()
|
||||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
|
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||||
|
|||||||
@@ -283,7 +283,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
# the conversation; injecting one mid-conversation crashes
|
# the conversation; injecting one mid-conversation crashes
|
||||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||||
# with all providers. See #1299.
|
# with all providers. See #1299.
|
||||||
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
|
return {"messages": [HumanMessage(content=warning)]}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from langgraph.runtime import Runtime
|
|||||||
|
|
||||||
from deerflow.agents.memory.queue import get_memory_queue
|
from deerflow.agents.memory.queue import get_memory_queue
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -237,16 +236,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
|||||||
# Queue the filtered conversation for memory update
|
# Queue the filtered conversation for memory update
|
||||||
correction_detected = detect_correction(filtered_messages)
|
correction_detected = detect_correction(filtered_messages)
|
||||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||||
# Capture user_id at enqueue time while the request context is still alive.
|
|
||||||
# threading.Timer fires on a different thread where ContextVar values are not
|
|
||||||
# propagated, so we must store user_id explicitly in ConversationContext.
|
|
||||||
user_id = get_effective_user_id()
|
|
||||||
queue = get_memory_queue()
|
queue = get_memory_queue()
|
||||||
queue.add(
|
queue.add(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
messages=filtered_messages,
|
messages=filtered_messages,
|
||||||
agent_name=self._agent_name,
|
agent_name=self._agent_name,
|
||||||
user_id=user_id,
|
|
||||||
correction_detected=correction_detected,
|
correction_detected=correction_detected,
|
||||||
reinforcement_detected=reinforcement_detected,
|
reinforcement_detected=reinforcement_detected,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
from typing import override
|
|
||||||
|
|
||||||
from langchain.agents.middleware import SummarizationMiddleware as BaseSummarizationMiddleware
|
|
||||||
from langchain_core.messages.human import HumanMessage
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizationMiddleware(BaseSummarizationMiddleware):
|
|
||||||
@override
|
|
||||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
|
||||||
"""Override the base implementation to let the human message with the special name 'summary'.
|
|
||||||
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
|
||||||
"""
|
|
||||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
|
||||||
@@ -1,16 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
|
||||||
from typing import NotRequired, override
|
from typing import NotRequired, override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langgraph.config import get_config
|
from langgraph.config import get_config
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadDataState
|
from deerflow.agents.thread_state import ThreadDataState
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -49,34 +46,32 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
|||||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||||
self._lazy_init = lazy_init
|
self._lazy_init = lazy_init
|
||||||
|
|
||||||
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||||
"""Get the paths for a thread's data directories.
|
"""Get the paths for a thread's data directories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
user_id: Optional user ID for per-user path isolation.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)),
|
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
|
||||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)),
|
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
|
||||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)),
|
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
||||||
"""Create the thread data directories.
|
"""Create the thread data directories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
user_id: Optional user ID for per-user path isolation.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with the created directory paths.
|
Dictionary with the created directory paths.
|
||||||
"""
|
"""
|
||||||
self._paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
self._paths.ensure_thread_dirs(thread_id)
|
||||||
return self._get_thread_paths(thread_id, user_id=user_id)
|
return self._get_thread_paths(thread_id)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||||
@@ -89,30 +84,16 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
|||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||||
|
|
||||||
user_id = get_effective_user_id()
|
|
||||||
|
|
||||||
if self._lazy_init:
|
if self._lazy_init:
|
||||||
# Lazy initialization: only compute paths, don't create directories
|
# Lazy initialization: only compute paths, don't create directories
|
||||||
paths = self._get_thread_paths(thread_id, user_id=user_id)
|
paths = self._get_thread_paths(thread_id)
|
||||||
else:
|
else:
|
||||||
# Eager initialization: create directories immediately
|
# Eager initialization: create directories immediately
|
||||||
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
paths = self._create_thread_directories(thread_id)
|
||||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||||
|
|
||||||
messages = list(state.get("messages", []))
|
|
||||||
last_message = messages[-1] if messages else None
|
|
||||||
|
|
||||||
if last_message and isinstance(last_message, HumanMessage):
|
|
||||||
messages[-1] = HumanMessage(
|
|
||||||
content=last_message.content,
|
|
||||||
id=last_message.id,
|
|
||||||
name=last_message.name or "user-input",
|
|
||||||
additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"thread_data": {
|
"thread_data": {
|
||||||
**paths,
|
**paths,
|
||||||
},
|
}
|
||||||
"messages": messages,
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from langchain_core.messages import HumanMessage
|
|||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
from deerflow.utils.file_conversion import extract_outline
|
from deerflow.utils.file_conversion import extract_outline
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -222,7 +221,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
|
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||||
|
|
||||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||||
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
||||||
@@ -279,7 +278,6 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
updated_message = HumanMessage(
|
updated_message = HumanMessage(
|
||||||
content=f"{files_message}\n\n{original_content}",
|
content=f"{files_message}\n\n{original_content}",
|
||||||
id=last_message.id,
|
id=last_message.id,
|
||||||
name=last_message.name,
|
|
||||||
additional_kwargs=last_message.additional_kwargs,
|
additional_kwargs=last_message.additional_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from deerflow.config.app_config import get_app_config, reload_app_config
|
|||||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
from deerflow.skills.installer import install_skill_from_archive
|
from deerflow.skills.installer import install_skill_from_archive
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
claim_unique_filename,
|
claim_unique_filename,
|
||||||
@@ -241,7 +240,7 @@ class DeerFlowClient:
|
|||||||
}
|
}
|
||||||
checkpointer = self._checkpointer
|
checkpointer = self._checkpointer
|
||||||
if checkpointer is None:
|
if checkpointer is None:
|
||||||
from deerflow.runtime.checkpointer import get_checkpointer
|
from deerflow.agents.checkpointer import get_checkpointer
|
||||||
|
|
||||||
checkpointer = get_checkpointer()
|
checkpointer = get_checkpointer()
|
||||||
if checkpointer is not None:
|
if checkpointer is not None:
|
||||||
@@ -375,7 +374,7 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
checkpointer = self._checkpointer
|
checkpointer = self._checkpointer
|
||||||
if checkpointer is None:
|
if checkpointer is None:
|
||||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||||
|
|
||||||
checkpointer = get_checkpointer()
|
checkpointer = get_checkpointer()
|
||||||
|
|
||||||
@@ -430,7 +429,7 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
checkpointer = self._checkpointer
|
checkpointer = self._checkpointer
|
||||||
if checkpointer is None:
|
if checkpointer is None:
|
||||||
from deerflow.runtime.checkpointer.provider import get_checkpointer
|
from deerflow.agents.checkpointer.provider import get_checkpointer
|
||||||
|
|
||||||
checkpointer = get_checkpointer()
|
checkpointer = get_checkpointer()
|
||||||
|
|
||||||
@@ -770,19 +769,19 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
from deerflow.agents.memory.updater import get_memory_data
|
from deerflow.agents.memory.updater import get_memory_data
|
||||||
|
|
||||||
return get_memory_data(user_id=get_effective_user_id())
|
return get_memory_data()
|
||||||
|
|
||||||
def export_memory(self) -> dict:
|
def export_memory(self) -> dict:
|
||||||
"""Export current memory data for backup or transfer."""
|
"""Export current memory data for backup or transfer."""
|
||||||
from deerflow.agents.memory.updater import get_memory_data
|
from deerflow.agents.memory.updater import get_memory_data
|
||||||
|
|
||||||
return get_memory_data(user_id=get_effective_user_id())
|
return get_memory_data()
|
||||||
|
|
||||||
def import_memory(self, memory_data: dict) -> dict:
|
def import_memory(self, memory_data: dict) -> dict:
|
||||||
"""Import and persist full memory data."""
|
"""Import and persist full memory data."""
|
||||||
from deerflow.agents.memory.updater import import_memory_data
|
from deerflow.agents.memory.updater import import_memory_data
|
||||||
|
|
||||||
return import_memory_data(memory_data, user_id=get_effective_user_id())
|
return import_memory_data(memory_data)
|
||||||
|
|
||||||
def get_model(self, name: str) -> dict | None:
|
def get_model(self, name: str) -> dict | None:
|
||||||
"""Get a specific model's configuration by name.
|
"""Get a specific model's configuration by name.
|
||||||
@@ -957,13 +956,13 @@ class DeerFlowClient:
|
|||||||
"""
|
"""
|
||||||
from deerflow.agents.memory.updater import reload_memory_data
|
from deerflow.agents.memory.updater import reload_memory_data
|
||||||
|
|
||||||
return reload_memory_data(user_id=get_effective_user_id())
|
return reload_memory_data()
|
||||||
|
|
||||||
def clear_memory(self) -> dict:
|
def clear_memory(self) -> dict:
|
||||||
"""Clear all persisted memory data."""
|
"""Clear all persisted memory data."""
|
||||||
from deerflow.agents.memory.updater import clear_memory_data
|
from deerflow.agents.memory.updater import clear_memory_data
|
||||||
|
|
||||||
return clear_memory_data(user_id=get_effective_user_id())
|
return clear_memory_data()
|
||||||
|
|
||||||
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
|
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
|
||||||
"""Create a single fact manually."""
|
"""Create a single fact manually."""
|
||||||
@@ -1180,7 +1179,7 @@ class DeerFlowClient:
|
|||||||
ValueError: If the path is invalid.
|
ValueError: If the path is invalid.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id())
|
actual = get_paths().resolve_virtual_path(thread_id, path)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
if "traversal" in str(exc):
|
if "traversal" in str(exc):
|
||||||
from deerflow.uploads.manager import PathTraversalError
|
from deerflow.uploads.manager import PathTraversalError
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ except ImportError: # pragma: no cover - Windows fallback
|
|||||||
|
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
from deerflow.sandbox.sandbox import Sandbox
|
from deerflow.sandbox.sandbox import Sandbox
|
||||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||||
|
|
||||||
@@ -261,16 +260,15 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
|
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
|
||||||
"""
|
"""
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
user_id = get_effective_user_id()
|
paths.ensure_thread_dirs(thread_id)
|
||||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
(paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||||
(paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||||
(paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||||
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
||||||
# the ACP subprocess writes from the host side, not from within the container).
|
# the ACP subprocess writes from the host side, not from within the container).
|
||||||
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
|
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -482,9 +480,8 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
across multiple processes, preventing container-name conflicts.
|
across multiple processes, preventing container-name conflicts.
|
||||||
"""
|
"""
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
user_id = get_effective_user_id()
|
paths.ensure_thread_dirs(thread_id)
|
||||||
paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
|
||||||
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
|
|
||||||
|
|
||||||
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
||||||
locked = False
|
locked = False
|
||||||
|
|||||||
@@ -4,12 +4,8 @@ Controls BOTH the LangGraph checkpointer and the DeerFlow application
|
|||||||
persistence layer (runs, threads metadata, users, etc.). The user
|
persistence layer (runs, threads metadata, users, etc.). The user
|
||||||
configures one backend; the system handles physical separation details.
|
configures one backend; the system handles physical separation details.
|
||||||
|
|
||||||
SQLite mode: checkpointer and app share a single .db file
|
SQLite mode: checkpointer and app use different .db files in the same
|
||||||
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
|
directory to avoid write-lock contention. This is automatic.
|
||||||
connection. WAL allows concurrent readers and a single writer without
|
|
||||||
blocking, making a unified file safe for both workloads. Writers
|
|
||||||
that contend for the lock wait via the default 5-second sqlite3
|
|
||||||
busy timeout rather than failing immediately.
|
|
||||||
|
|
||||||
Postgres mode: both use the same database URL but maintain independent
|
Postgres mode: both use the same database URL but maintain independent
|
||||||
connection pools with different lifecycles.
|
connection pools with different lifecycles.
|
||||||
@@ -44,7 +40,7 @@ class DatabaseConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
sqlite_dir: str = Field(
|
sqlite_dir: str = Field(
|
||||||
default=".deer-flow/data",
|
default=".deer-flow/data",
|
||||||
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
|
description=("Directory for SQLite database files. Checkpointer uses {sqlite_dir}/checkpoints.db, application data uses {sqlite_dir}/app.db."),
|
||||||
)
|
)
|
||||||
postgres_url: str = Field(
|
postgres_url: str = Field(
|
||||||
default="",
|
default="",
|
||||||
@@ -73,27 +69,21 @@ class DatabaseConfig(BaseModel):
|
|||||||
|
|
||||||
return str(Path(self.sqlite_dir).resolve())
|
return str(Path(self.sqlite_dir).resolve())
|
||||||
|
|
||||||
@property
|
|
||||||
def sqlite_path(self) -> str:
|
|
||||||
"""Unified SQLite file path shared by checkpointer and app."""
|
|
||||||
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
|
|
||||||
|
|
||||||
# Backward-compatible aliases
|
|
||||||
@property
|
@property
|
||||||
def checkpointer_sqlite_path(self) -> str:
|
def checkpointer_sqlite_path(self) -> str:
|
||||||
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
|
"""SQLite file path for the LangGraph checkpointer."""
|
||||||
return self.sqlite_path
|
return os.path.join(self._resolved_sqlite_dir, "checkpoints.db")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_sqlite_path(self) -> str:
|
def app_sqlite_path(self) -> str:
|
||||||
"""SQLite file path for application ORM data (alias for sqlite_path)."""
|
"""SQLite file path for application ORM data."""
|
||||||
return self.sqlite_path
|
return os.path.join(self._resolved_sqlite_dir, "app.db")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_sqlalchemy_url(self) -> str:
|
def app_sqlalchemy_url(self) -> str:
|
||||||
"""SQLAlchemy async URL for the application ORM engine."""
|
"""SQLAlchemy async URL for the application ORM engine."""
|
||||||
if self.backend == "sqlite":
|
if self.backend == "sqlite":
|
||||||
return f"sqlite+aiosqlite:///{self.sqlite_path}"
|
return f"sqlite+aiosqlite:///{self.app_sqlite_path}"
|
||||||
if self.backend == "postgres":
|
if self.backend == "postgres":
|
||||||
url = self.postgres_url
|
url = self.postgres_url
|
||||||
if url.startswith("postgresql://"):
|
if url.startswith("postgresql://"):
|
||||||
|
|||||||
@@ -14,9 +14,8 @@ class MemoryConfig(BaseModel):
|
|||||||
default="",
|
default="",
|
||||||
description=(
|
description=(
|
||||||
"Path to store memory data. "
|
"Path to store memory data. "
|
||||||
"If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. "
|
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
|
||||||
"Absolute paths are used as-is and opt out of per-user isolation "
|
"Absolute paths are used as-is. "
|
||||||
"(all users share the same file). "
|
|
||||||
"Relative paths are resolved against `Paths.base_dir` "
|
"Relative paths are resolved against `Paths.base_dir` "
|
||||||
"(not the backend working directory). "
|
"(not the backend working directory). "
|
||||||
"Note: if you previously set this to `.deer-flow/memory.json`, "
|
"Note: if you previously set this to `.deer-flow/memory.json`, "
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from pathlib import Path, PureWindowsPath
|
|||||||
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||||
|
|
||||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||||
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
|
||||||
|
|
||||||
|
|
||||||
def _default_local_base_dir() -> Path:
|
def _default_local_base_dir() -> Path:
|
||||||
@@ -23,13 +22,6 @@ def _validate_thread_id(thread_id: str) -> str:
|
|||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
|
|
||||||
def _validate_user_id(user_id: str) -> str:
|
|
||||||
"""Validate a user ID before using it in filesystem paths."""
|
|
||||||
if not _SAFE_USER_ID_RE.match(user_id):
|
|
||||||
raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
|
|
||||||
return user_id
|
|
||||||
|
|
||||||
|
|
||||||
def _join_host_path(base: str, *parts: str) -> str:
|
def _join_host_path(base: str, *parts: str) -> str:
|
||||||
"""Join host filesystem path segments while preserving native style.
|
"""Join host filesystem path segments while preserving native style.
|
||||||
|
|
||||||
@@ -142,63 +134,44 @@ class Paths:
|
|||||||
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||||
return self.agent_dir(name) / "memory.json"
|
return self.agent_dir(name) / "memory.json"
|
||||||
|
|
||||||
def user_dir(self, user_id: str) -> Path:
|
def thread_dir(self, thread_id: str) -> Path:
|
||||||
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
|
|
||||||
return self.base_dir / "users" / _validate_user_id(user_id)
|
|
||||||
|
|
||||||
def user_memory_file(self, user_id: str) -> Path:
|
|
||||||
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
|
||||||
return self.user_dir(user_id) / "memory.json"
|
|
||||||
|
|
||||||
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
|
|
||||||
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
|
|
||||||
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
|
|
||||||
|
|
||||||
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
|
||||||
"""
|
"""
|
||||||
Host path for a thread's data.
|
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
|
||||||
|
|
||||||
When *user_id* is provided:
|
|
||||||
`{base_dir}/users/{user_id}/threads/{thread_id}/`
|
|
||||||
Otherwise (legacy layout):
|
|
||||||
`{base_dir}/threads/{thread_id}/`
|
|
||||||
|
|
||||||
This directory contains a `user-data/` subdirectory that is mounted
|
This directory contains a `user-data/` subdirectory that is mounted
|
||||||
as `/mnt/user-data/` inside the sandbox.
|
as `/mnt/user-data/` inside the sandbox.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `thread_id` or `user_id` contains unsafe characters (path
|
ValueError: If `thread_id` contains unsafe characters (path separators
|
||||||
separators or `..`) that could cause directory traversal.
|
or `..`) that could cause directory traversal.
|
||||||
"""
|
"""
|
||||||
if user_id is not None:
|
|
||||||
return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
|
|
||||||
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
||||||
|
|
||||||
def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
def sandbox_work_dir(self, thread_id: str) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for the agent's workspace directory.
|
Host path for the agent's workspace directory.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
|
||||||
Sandbox: `/mnt/user-data/workspace/`
|
Sandbox: `/mnt/user-data/workspace/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"
|
return self.thread_dir(thread_id) / "user-data" / "workspace"
|
||||||
|
|
||||||
def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
def sandbox_uploads_dir(self, thread_id: str) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for user-uploaded files.
|
Host path for user-uploaded files.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
|
||||||
Sandbox: `/mnt/user-data/uploads/`
|
Sandbox: `/mnt/user-data/uploads/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads"
|
return self.thread_dir(thread_id) / "user-data" / "uploads"
|
||||||
|
|
||||||
def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
def sandbox_outputs_dir(self, thread_id: str) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for agent-generated artifacts.
|
Host path for agent-generated artifacts.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
|
||||||
Sandbox: `/mnt/user-data/outputs/`
|
Sandbox: `/mnt/user-data/outputs/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs"
|
return self.thread_dir(thread_id) / "user-data" / "outputs"
|
||||||
|
|
||||||
def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
def acp_workspace_dir(self, thread_id: str) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for the ACP workspace of a specific thread.
|
Host path for the ACP workspace of a specific thread.
|
||||||
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
|
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
|
||||||
@@ -207,43 +180,41 @@ class Paths:
|
|||||||
Each thread gets its own isolated ACP workspace so that concurrent
|
Each thread gets its own isolated ACP workspace so that concurrent
|
||||||
sessions cannot read each other's ACP agent outputs.
|
sessions cannot read each other's ACP agent outputs.
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace"
|
return self.thread_dir(thread_id) / "acp-workspace"
|
||||||
|
|
||||||
def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
|
def sandbox_user_data_dir(self, thread_id: str) -> Path:
|
||||||
"""
|
"""
|
||||||
Host path for the user-data root.
|
Host path for the user-data root.
|
||||||
Host: `{base_dir}/threads/{thread_id}/user-data/`
|
Host: `{base_dir}/threads/{thread_id}/user-data/`
|
||||||
Sandbox: `/mnt/user-data/`
|
Sandbox: `/mnt/user-data/`
|
||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id, user_id=user_id) / "user-data"
|
return self.thread_dir(thread_id) / "user-data"
|
||||||
|
|
||||||
def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
def host_thread_dir(self, thread_id: str) -> str:
|
||||||
"""Host path for a thread directory, preserving Windows path syntax."""
|
"""Host path for a thread directory, preserving Windows path syntax."""
|
||||||
if user_id is not None:
|
|
||||||
return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
|
|
||||||
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
|
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
|
||||||
|
|
||||||
def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
|
||||||
"""Host path for a thread's user-data root."""
|
"""Host path for a thread's user-data root."""
|
||||||
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data")
|
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
|
||||||
|
|
||||||
def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
def host_sandbox_work_dir(self, thread_id: str) -> str:
|
||||||
"""Host path for the workspace mount source."""
|
"""Host path for the workspace mount source."""
|
||||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace")
|
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
|
||||||
|
|
||||||
def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
|
||||||
"""Host path for the uploads mount source."""
|
"""Host path for the uploads mount source."""
|
||||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads")
|
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
|
||||||
|
|
||||||
def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
|
||||||
"""Host path for the outputs mount source."""
|
"""Host path for the outputs mount source."""
|
||||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs")
|
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
|
||||||
|
|
||||||
def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
|
def host_acp_workspace_dir(self, thread_id: str) -> str:
|
||||||
"""Host path for the ACP workspace mount source."""
|
"""Host path for the ACP workspace mount source."""
|
||||||
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace")
|
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
|
||||||
|
|
||||||
def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None:
|
def ensure_thread_dirs(self, thread_id: str) -> None:
|
||||||
"""Create all standard sandbox directories for a thread.
|
"""Create all standard sandbox directories for a thread.
|
||||||
|
|
||||||
Directories are created with mode 0o777 so that sandbox containers
|
Directories are created with mode 0o777 so that sandbox containers
|
||||||
@@ -257,24 +228,24 @@ class Paths:
|
|||||||
ACP agent invocation.
|
ACP agent invocation.
|
||||||
"""
|
"""
|
||||||
for d in [
|
for d in [
|
||||||
self.sandbox_work_dir(thread_id, user_id=user_id),
|
self.sandbox_work_dir(thread_id),
|
||||||
self.sandbox_uploads_dir(thread_id, user_id=user_id),
|
self.sandbox_uploads_dir(thread_id),
|
||||||
self.sandbox_outputs_dir(thread_id, user_id=user_id),
|
self.sandbox_outputs_dir(thread_id),
|
||||||
self.acp_workspace_dir(thread_id, user_id=user_id),
|
self.acp_workspace_dir(thread_id),
|
||||||
]:
|
]:
|
||||||
d.mkdir(parents=True, exist_ok=True)
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
d.chmod(0o777)
|
d.chmod(0o777)
|
||||||
|
|
||||||
def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None:
|
def delete_thread_dir(self, thread_id: str) -> None:
|
||||||
"""Delete all persisted data for a thread.
|
"""Delete all persisted data for a thread.
|
||||||
|
|
||||||
The operation is idempotent: missing thread directories are ignored.
|
The operation is idempotent: missing thread directories are ignored.
|
||||||
"""
|
"""
|
||||||
thread_dir = self.thread_dir(thread_id, user_id=user_id)
|
thread_dir = self.thread_dir(thread_id)
|
||||||
if thread_dir.exists():
|
if thread_dir.exists():
|
||||||
shutil.rmtree(thread_dir)
|
shutil.rmtree(thread_dir)
|
||||||
|
|
||||||
def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
|
||||||
"""Resolve a sandbox virtual path to the actual host filesystem path.
|
"""Resolve a sandbox virtual path to the actual host filesystem path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -282,7 +253,6 @@ class Paths:
|
|||||||
virtual_path: Virtual path as seen inside the sandbox, e.g.
|
virtual_path: Virtual path as seen inside the sandbox, e.g.
|
||||||
``/mnt/user-data/outputs/report.pdf``.
|
``/mnt/user-data/outputs/report.pdf``.
|
||||||
Leading slashes are stripped before matching.
|
Leading slashes are stripped before matching.
|
||||||
user_id: Optional user ID for user-scoped path resolution.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The resolved absolute host filesystem path.
|
The resolved absolute host filesystem path.
|
||||||
@@ -300,7 +270,7 @@ class Paths:
|
|||||||
raise ValueError(f"Path must start with /{prefix}")
|
raise ValueError(f"Path must start with /{prefix}")
|
||||||
|
|
||||||
relative = stripped[len(prefix) :].lstrip("/")
|
relative = stripped[len(prefix) :].lstrip("/")
|
||||||
base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve()
|
base = self.sandbox_user_data_dir(thread_id).resolve()
|
||||||
actual = (base / relative).resolve()
|
actual = (base / relative).resolve()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -98,11 +98,6 @@ async def init_engine(
|
|||||||
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
||||||
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
||||||
# at WAL checkpoint boundaries instead of every commit.
|
# at WAL checkpoint boundaries instead of every commit.
|
||||||
# Note: we do not set PRAGMA busy_timeout here — Python's sqlite3
|
|
||||||
# driver already defaults to a 5-second busy timeout (see the
|
|
||||||
# ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite /
|
|
||||||
# SQLAlchemy's aiosqlite dialect inherit that default. Setting
|
|
||||||
# it again would be a no-op.
|
|
||||||
@event.listens_for(_engine.sync_engine, "connect")
|
@event.listens_for(_engine.sync_engine, "connect")
|
||||||
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
||||||
cursor = dbapi_conn.cursor()
|
cursor = dbapi_conn.cursor()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import DateTime, String, Text, UniqueConstraint
|
from sqlalchemy import DateTime, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from deerflow.persistence.base import Base
|
from deerflow.persistence.base import Base
|
||||||
@@ -13,14 +13,10 @@ from deerflow.persistence.base import Base
|
|||||||
class FeedbackRow(Base):
|
class FeedbackRow(Base):
|
||||||
__tablename__ = "feedback"
|
__tablename__ = "feedback"
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
|
|
||||||
)
|
|
||||||
|
|
||||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||||
# message_id is an optional RunEventStore event identifier —
|
# message_id is an optional RunEventStore event identifier —
|
||||||
# allows feedback to target a specific message or the entire run
|
# allows feedback to target a specific message or the entire run
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from deerflow.persistence.feedback.model import FeedbackRow
|
from deerflow.persistence.feedback.model import FeedbackRow
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||||
|
|
||||||
|
|
||||||
class FeedbackRepository:
|
class FeedbackRepository:
|
||||||
@@ -33,19 +33,19 @@ class FeedbackRepository:
|
|||||||
run_id: str,
|
run_id: str,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
rating: int,
|
rating: int,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
comment: str | None = None,
|
comment: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Create a feedback record. rating must be +1 or -1."""
|
"""Create a feedback record. rating must be +1 or -1."""
|
||||||
if rating not in (1, -1):
|
if rating not in (1, -1):
|
||||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
||||||
row = FeedbackRow(
|
row = FeedbackRow(
|
||||||
feedback_id=str(uuid.uuid4()),
|
feedback_id=str(uuid.uuid4()),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
user_id=resolved_user_id,
|
owner_id=resolved_owner_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
rating=rating,
|
rating=rating,
|
||||||
comment=comment,
|
comment=comment,
|
||||||
@@ -61,14 +61,14 @@ class FeedbackRepository:
|
|||||||
self,
|
self,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
@@ -78,12 +78,12 @@ class FeedbackRepository:
|
|||||||
run_id: str,
|
run_id: str,
|
||||||
*,
|
*,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -94,12 +94,12 @@ class FeedbackRepository:
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -109,97 +109,19 @@ class FeedbackRepository:
|
|||||||
self,
|
self,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return False
|
return False
|
||||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return False
|
return False
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def upsert(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
run_id: str,
|
|
||||||
thread_id: str,
|
|
||||||
rating: int,
|
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
|
||||||
comment: str | None = None,
|
|
||||||
) -> dict:
|
|
||||||
"""Create or update feedback for (thread_id, run_id, user_id). rating must be +1 or -1."""
|
|
||||||
if rating not in (1, -1):
|
|
||||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert")
|
|
||||||
async with self._sf() as session:
|
|
||||||
stmt = select(FeedbackRow).where(
|
|
||||||
FeedbackRow.thread_id == thread_id,
|
|
||||||
FeedbackRow.run_id == run_id,
|
|
||||||
FeedbackRow.user_id == resolved_user_id,
|
|
||||||
)
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is not None:
|
|
||||||
row.rating = rating
|
|
||||||
row.comment = comment
|
|
||||||
row.created_at = datetime.now(UTC)
|
|
||||||
else:
|
|
||||||
row = FeedbackRow(
|
|
||||||
feedback_id=str(uuid.uuid4()),
|
|
||||||
run_id=run_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
user_id=resolved_user_id,
|
|
||||||
rating=rating,
|
|
||||||
comment=comment,
|
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
)
|
|
||||||
session.add(row)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(row)
|
|
||||||
return self._row_to_dict(row)
|
|
||||||
|
|
||||||
async def delete_by_run(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
|
||||||
) -> bool:
|
|
||||||
"""Delete the current user's feedback for a run. Returns True if a record was deleted."""
|
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run")
|
|
||||||
async with self._sf() as session:
|
|
||||||
stmt = select(FeedbackRow).where(
|
|
||||||
FeedbackRow.thread_id == thread_id,
|
|
||||||
FeedbackRow.run_id == run_id,
|
|
||||||
FeedbackRow.user_id == resolved_user_id,
|
|
||||||
)
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
return False
|
|
||||||
await session.delete(row)
|
|
||||||
await session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def list_by_thread_grouped(
|
|
||||||
self,
|
|
||||||
thread_id: str,
|
|
||||||
*,
|
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
|
||||||
) -> dict[str, dict]:
|
|
||||||
"""Return feedback grouped by run_id for a thread: {run_id: feedback_dict}."""
|
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped")
|
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
|
||||||
if resolved_user_id is not None:
|
|
||||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
return {row.run_id: self._row_to_dict(row) for row in result.scalars()}
|
|
||||||
|
|
||||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
||||||
"""Aggregate feedback stats for a run using database-side counting."""
|
"""Aggregate feedback stats for a run using database-side counting."""
|
||||||
stmt = select(
|
stmt = select(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
script_location = %(here)s
|
script_location = %(here)s
|
||||||
# Default URL for offline mode / autogenerate.
|
# Default URL for offline mode / autogenerate.
|
||||||
# Runtime uses engine from DeerFlow config.
|
# Runtime uses engine from DeerFlow config.
|
||||||
sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db
|
sqlalchemy.url = sqlite+aiosqlite:///./data/app.db
|
||||||
|
|
||||||
[loggers]
|
[loggers]
|
||||||
keys = root,sqlalchemy,alembic
|
keys = root,sqlalchemy,alembic
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class RunEventRow(Base):
|
|||||||
# Owner of the conversation this event belongs to. Nullable for data
|
# Owner of the conversation this event belongs to. Nullable for data
|
||||||
# created before auth was introduced; populated by auth middleware on
|
# created before auth was introduced; populated by auth middleware on
|
||||||
# new writes and by the boot-time orphan migration on existing rows.
|
# new writes and by the boot-time orphan migration on existing rows.
|
||||||
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||||
# "message" | "trace" | "lifecycle"
|
# "message" | "trace" | "lifecycle"
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class RunRow(Base):
|
|||||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.run.model import RunRow
|
from deerflow.persistence.run.model import RunRow
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
from deerflow.runtime.runs.store.base import RunStore
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||||
|
|
||||||
|
|
||||||
class RunRepository(RunStore):
|
class RunRepository(RunStore):
|
||||||
@@ -69,7 +69,7 @@ class RunRepository(RunStore):
|
|||||||
*,
|
*,
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
status="pending",
|
status="pending",
|
||||||
multitask_strategy="reject",
|
multitask_strategy="reject",
|
||||||
metadata=None,
|
metadata=None,
|
||||||
@@ -78,13 +78,13 @@ class RunRepository(RunStore):
|
|||||||
created_at=None,
|
created_at=None,
|
||||||
follow_up_to_run_id=None,
|
follow_up_to_run_id=None,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = RunRow(
|
row = RunRow(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
user_id=resolved_user_id,
|
owner_id=resolved_owner_id,
|
||||||
status=status,
|
status=status,
|
||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata_json=self._safe_json(metadata) or {},
|
metadata_json=self._safe_json(metadata) or {},
|
||||||
@@ -102,14 +102,14 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
@@ -117,13 +117,13 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
limit=100,
|
limit=100,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
||||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(RunRow.user_id == resolved_user_id)
|
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
||||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -141,14 +141,14 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return
|
return
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -1,38 +1,13 @@
|
|||||||
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
|
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||||
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from langgraph.store.base import BaseStore
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MemoryThreadMetaStore",
|
"MemoryThreadMetaStore",
|
||||||
"ThreadMetaRepository",
|
"ThreadMetaRepository",
|
||||||
"ThreadMetaRow",
|
"ThreadMetaRow",
|
||||||
"ThreadMetaStore",
|
"ThreadMetaStore",
|
||||||
"make_thread_store",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_thread_store(
|
|
||||||
session_factory: async_sessionmaker[AsyncSession] | None,
|
|
||||||
store: BaseStore | None = None,
|
|
||||||
) -> ThreadMetaStore:
|
|
||||||
"""Create the appropriate ThreadMetaStore based on available backends.
|
|
||||||
|
|
||||||
Returns a SQL-backed repository when a session factory is available,
|
|
||||||
otherwise falls back to the in-memory LangGraph Store implementation.
|
|
||||||
"""
|
|
||||||
if session_factory is not None:
|
|
||||||
return ThreadMetaRepository(session_factory)
|
|
||||||
if store is None:
|
|
||||||
raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)")
|
|
||||||
return MemoryThreadMetaStore(store)
|
|
||||||
|
|||||||
@@ -3,21 +3,12 @@
|
|||||||
Implementations:
|
Implementations:
|
||||||
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
|
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
|
||||||
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
|
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
|
||||||
|
|
||||||
All mutating and querying methods accept a ``user_id`` parameter with
|
|
||||||
three-state semantics (see :mod:`deerflow.runtime.user_context`):
|
|
||||||
|
|
||||||
- ``AUTO`` (default): resolve from the request-scoped contextvar.
|
|
||||||
- Explicit ``str``: use the provided value verbatim.
|
|
||||||
- Explicit ``None``: bypass owner filtering (migration/CLI only).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadMetaStore(abc.ABC):
|
class ThreadMetaStore(abc.ABC):
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -26,14 +17,14 @@ class ThreadMetaStore(abc.ABC):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None = None,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
async def get(self, thread_id: str) -> dict | None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -44,33 +35,26 @@ class ThreadMetaStore(abc.ABC):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def update_status(self, thread_id: str, status: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||||
"""Merge ``metadata`` into the thread's metadata field.
|
"""Merge ``metadata`` into the thread's metadata field.
|
||||||
|
|
||||||
Existing keys are overwritten by the new values; keys absent from
|
Existing keys are overwritten by the new values; keys absent from
|
||||||
``metadata`` are preserved. No-op if the thread does not exist
|
``metadata`` are preserved. No-op if the thread does not exist.
|
||||||
or the owner check fails.
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
async def delete(self, thread_id: str) -> None:
|
||||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from typing import Any
|
|||||||
from langgraph.store.base import BaseStore
|
from langgraph.store.base import BaseStore
|
||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
|
||||||
|
|
||||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||||
|
|
||||||
@@ -22,37 +21,20 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
def __init__(self, store: BaseStore) -> None:
|
def __init__(self, store: BaseStore) -> None:
|
||||||
self._store = store
|
self._store = store
|
||||||
|
|
||||||
async def _get_owned_record(
|
|
||||||
self,
|
|
||||||
thread_id: str,
|
|
||||||
user_id: str | None | _AutoSentinel,
|
|
||||||
method_name: str,
|
|
||||||
) -> dict | None:
|
|
||||||
"""Fetch a record and verify ownership. Returns a mutable copy, or None."""
|
|
||||||
resolved = resolve_user_id(user_id, method_name=method_name)
|
|
||||||
item = await self._store.aget(THREADS_NS, thread_id)
|
|
||||||
if item is None:
|
|
||||||
return None
|
|
||||||
record = dict(item.value)
|
|
||||||
if resolved is not None and record.get("user_id") != resolved:
|
|
||||||
return None
|
|
||||||
return record
|
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None = None,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
record: dict[str, Any] = {
|
record: dict[str, Any] = {
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"assistant_id": assistant_id,
|
"assistant_id": assistant_id,
|
||||||
"user_id": resolved_user_id,
|
"owner_id": owner_id,
|
||||||
"display_name": display_name,
|
"display_name": display_name,
|
||||||
"status": "idle",
|
"status": "idle",
|
||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
@@ -63,8 +45,9 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
async def get(self, thread_id: str) -> dict | None:
|
||||||
return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get")
|
item = await self._store.aget(THREADS_NS, thread_id)
|
||||||
|
return item.value if item is not None else None
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
@@ -73,16 +56,12 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
|
|
||||||
filter_dict: dict[str, Any] = {}
|
filter_dict: dict[str, Any] = {}
|
||||||
if metadata:
|
if metadata:
|
||||||
filter_dict.update(metadata)
|
filter_dict.update(metadata)
|
||||||
if status:
|
if status:
|
||||||
filter_dict["status"] = status
|
filter_dict["status"] = status
|
||||||
if resolved_user_id is not None:
|
|
||||||
filter_dict["user_id"] = resolved_user_id
|
|
||||||
|
|
||||||
items = await self._store.asearch(
|
items = await self._store.asearch(
|
||||||
THREADS_NS,
|
THREADS_NS,
|
||||||
@@ -92,45 +71,37 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
)
|
)
|
||||||
return [self._item_to_dict(item) for item in items]
|
return [self._item_to_dict(item) for item in items]
|
||||||
|
|
||||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||||
item = await self._store.aget(THREADS_NS, thread_id)
|
item = await self._store.aget(THREADS_NS, thread_id)
|
||||||
if item is None:
|
if item is None:
|
||||||
return not require_existing
|
|
||||||
record_user_id = item.value.get("user_id")
|
|
||||||
if record_user_id is None:
|
|
||||||
return True
|
|
||||||
return record_user_id == user_id
|
|
||||||
|
|
||||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
|
||||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name")
|
|
||||||
if record is None:
|
|
||||||
return
|
return
|
||||||
|
record = dict(item.value)
|
||||||
record["display_name"] = display_name
|
record["display_name"] = display_name
|
||||||
record["updated_at"] = time.time()
|
record["updated_at"] = time.time()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def update_status(self, thread_id: str, status: str) -> None:
|
||||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status")
|
item = await self._store.aget(THREADS_NS, thread_id)
|
||||||
if record is None:
|
if item is None:
|
||||||
return
|
return
|
||||||
|
record = dict(item.value)
|
||||||
record["status"] = status
|
record["status"] = status
|
||||||
record["updated_at"] = time.time()
|
record["updated_at"] = time.time()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
||||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata")
|
"""Merge ``metadata`` into the in-memory record. No-op if absent."""
|
||||||
if record is None:
|
item = await self._store.aget(THREADS_NS, thread_id)
|
||||||
|
if item is None:
|
||||||
return
|
return
|
||||||
|
record = dict(item.value)
|
||||||
merged = dict(record.get("metadata") or {})
|
merged = dict(record.get("metadata") or {})
|
||||||
merged.update(metadata)
|
merged.update(metadata)
|
||||||
record["metadata"] = merged
|
record["metadata"] = merged
|
||||||
record["updated_at"] = time.time()
|
record["updated_at"] = time.time()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def delete(self, thread_id: str) -> None:
|
||||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
|
||||||
if record is None:
|
|
||||||
return
|
|
||||||
await self._store.adelete(THREADS_NS, thread_id)
|
await self._store.adelete(THREADS_NS, thread_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -140,7 +111,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
return {
|
return {
|
||||||
"thread_id": item.key,
|
"thread_id": item.key,
|
||||||
"assistant_id": val.get("assistant_id"),
|
"assistant_id": val.get("assistant_id"),
|
||||||
"user_id": val.get("user_id"),
|
"owner_id": val.get("owner_id"),
|
||||||
"display_name": val.get("display_name"),
|
"display_name": val.get("display_name"),
|
||||||
"status": val.get("status", "idle"),
|
"status": val.get("status", "idle"),
|
||||||
"metadata": val.get("metadata", {}),
|
"metadata": val.get("metadata", {}),
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
|
|||||||
|
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||||
|
|
||||||
|
|
||||||
class ThreadMetaRepository(ThreadMetaStore):
|
class ThreadMetaRepository(ThreadMetaStore):
|
||||||
@@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Auto-resolve user_id from contextvar when AUTO; explicit None
|
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
||||||
# creates an orphan row (used by migration scripts).
|
# creates an orphan row (used by migration scripts).
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = ThreadMetaRow(
|
row = ThreadMetaRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
user_id=resolved_user_id,
|
owner_id=resolved_owner_id,
|
||||||
display_name=display_name,
|
display_name=display_name,
|
||||||
metadata_json=metadata or {},
|
metadata_json=metadata or {},
|
||||||
created_at=now,
|
created_at=now,
|
||||||
@@ -59,34 +59,40 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
# Enforce owner filter unless explicitly bypassed (user_id=None).
|
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
||||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||||
"""Check if ``user_id`` has access to ``thread_id``.
|
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
|
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
|
||||||
|
"""Check if ``owner_id`` has access to ``thread_id``.
|
||||||
|
|
||||||
Two modes — one row, two distinct semantics depending on what
|
Two modes — one row, two distinct semantics depending on what
|
||||||
the caller is about to do:
|
the caller is about to do:
|
||||||
|
|
||||||
- ``require_existing=False`` (default, permissive):
|
- ``require_existing=False`` (default, permissive):
|
||||||
Returns True for: row missing (untracked legacy thread),
|
Returns True for: row missing (untracked legacy thread),
|
||||||
``row.user_id`` is None (shared / pre-auth data),
|
``row.owner_id`` is None (shared / pre-auth data),
|
||||||
or ``row.user_id == user_id``. Use for **read-style**
|
or ``row.owner_id == owner_id``. Use for **read-style**
|
||||||
decorators where treating an untracked thread as accessible
|
decorators where treating an untracked thread as accessible
|
||||||
preserves backward-compat.
|
preserves backward-compat.
|
||||||
|
|
||||||
- ``require_existing=True`` (strict):
|
- ``require_existing=True`` (strict):
|
||||||
Returns True **only** when the row exists AND
|
Returns True **only** when the row exists AND
|
||||||
(``row.user_id == user_id`` OR ``row.user_id is None``).
|
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
|
||||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||||
state-update) so a thread that has *already been deleted*
|
state-update) so a thread that has *already been deleted*
|
||||||
cannot be re-targeted by any caller — closing the
|
cannot be re-targeted by any caller — closing the
|
||||||
@@ -97,9 +103,9 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return not require_existing
|
return not require_existing
|
||||||
if row.user_id is None:
|
if row.owner_id is None:
|
||||||
return True
|
return True
|
||||||
return row.user_id == user_id
|
return row.owner_id == owner_id
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
@@ -108,17 +114,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Search threads with optional metadata and status filters.
|
"""Search threads with optional metadata and status filters.
|
||||||
|
|
||||||
Owner filter is enforced by default: caller must be in a user
|
Owner filter is enforced by default: caller must be in a user
|
||||||
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
||||||
"""
|
"""
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
||||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
||||||
if status:
|
if status:
|
||||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||||
|
|
||||||
@@ -138,24 +144,24 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
||||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||||
if resolved_user_id is None:
|
if resolved_owner_id is None:
|
||||||
return True # explicit bypass
|
return True # explicit bypass
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
return row is not None and row.user_id == resolved_user_id
|
return row is not None and row.owner_id == resolved_owner_id
|
||||||
|
|
||||||
async def update_display_name(
|
async def update_display_name(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
display_name: str,
|
display_name: str,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the display_name (title) for a thread."""
|
"""Update the display_name (title) for a thread."""
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||||
return
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -165,11 +171,11 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
status: str,
|
status: str,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||||
return
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -179,20 +185,20 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
metadata: dict,
|
metadata: dict,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Merge ``metadata`` into ``metadata_json``.
|
"""Merge ``metadata`` into ``metadata_json``.
|
||||||
|
|
||||||
Read-modify-write inside a single session/transaction so concurrent
|
Read-modify-write inside a single session/transaction so concurrent
|
||||||
callers see consistent state. No-op if the row does not exist or
|
callers see consistent state. No-op if the row does not exist or
|
||||||
the user_id check fails.
|
the owner_id check fails.
|
||||||
"""
|
"""
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return
|
return
|
||||||
merged = dict(row.metadata_json or {})
|
merged = dict(row.metadata_json or {})
|
||||||
merged.update(metadata)
|
merged.update(metadata)
|
||||||
@@ -204,14 +210,14 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return
|
return
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -5,18 +5,12 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
|
|||||||
directly from ``deerflow.runtime``.
|
directly from ``deerflow.runtime``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
|
|
||||||
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
||||||
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
||||||
from .store import get_store, make_store, reset_store, store_context
|
from .store import get_store, make_store, reset_store, store_context
|
||||||
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# checkpointer
|
|
||||||
"checkpointer_context",
|
|
||||||
"get_checkpointer",
|
|
||||||
"make_checkpointer",
|
|
||||||
"reset_checkpointer",
|
|
||||||
# runs
|
# runs
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"DisconnectMode",
|
"DisconnectMode",
|
||||||
|
|||||||
@@ -83,18 +83,8 @@ class RunEventStore(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
*,
|
|
||||||
limit: int = 50,
|
|
||||||
before_seq: int | None = None,
|
|
||||||
after_seq: int | None = None,
|
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
|
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending."""
|
||||||
|
|
||||||
Supports bidirectional cursor pagination:
|
|
||||||
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
|
||||||
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
|
||||||
- neither: return the latest ``limit`` records (ascending)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def count_messages(self, thread_id: str) -> int:
|
async def count_messages(self, thread_id: str) -> int:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.models.run_event import RunEventRow
|
from deerflow.persistence.models.run_event import RunEventRow
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -55,22 +55,16 @@ class DbRunEventStore(RunEventStore):
|
|||||||
return content, metadata or {}
|
return content, metadata or {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _user_id_from_context() -> str | None:
|
def _owner_from_context() -> str | None:
|
||||||
"""Soft read of user_id from contextvar for write paths.
|
"""Soft read of owner_id from contextvar for write paths.
|
||||||
|
|
||||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||||
which is the expected case for background worker writes. HTTP
|
which is the expected case for background worker writes. HTTP
|
||||||
request writes will have the contextvar set by auth middleware
|
request writes will have the contextvar set by auth middleware
|
||||||
and get their user_id stamped automatically.
|
and get their user_id stamped automatically.
|
||||||
|
|
||||||
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
|
|
||||||
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
|
|
||||||
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
|
|
||||||
to a VARCHAR column ("type 'UUID' is not supported") — the
|
|
||||||
INSERT would silently roll back and the worker would hang.
|
|
||||||
"""
|
"""
|
||||||
user = get_current_user()
|
user = get_current_user()
|
||||||
return str(user.id) if user is not None else None
|
return user.id if user is not None else None
|
||||||
|
|
||||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||||
"""Write a single event — low-frequency path only.
|
"""Write a single event — low-frequency path only.
|
||||||
@@ -87,7 +81,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||||
else:
|
else:
|
||||||
db_content = content
|
db_content = content
|
||||||
user_id = self._user_id_from_context()
|
owner_id = self._owner_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||||
@@ -98,7 +92,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
user_id=user_id,
|
owner_id=owner_id,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -112,7 +106,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
async def put_batch(self, events):
|
async def put_batch(self, events):
|
||||||
if not events:
|
if not events:
|
||||||
return []
|
return []
|
||||||
user_id = self._user_id_from_context()
|
owner_id = self._owner_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||||
@@ -136,7 +130,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=e["thread_id"],
|
thread_id=e["thread_id"],
|
||||||
run_id=e["run_id"],
|
run_id=e["run_id"],
|
||||||
user_id=e.get("user_id", user_id),
|
owner_id=e.get("owner_id", owner_id),
|
||||||
event_type=e["event_type"],
|
event_type=e["event_type"],
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -155,12 +149,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
limit=50,
|
limit=50,
|
||||||
before_seq=None,
|
before_seq=None,
|
||||||
after_seq=None,
|
after_seq=None,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
if before_seq is not None:
|
if before_seq is not None:
|
||||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||||
if after_seq is not None:
|
if after_seq is not None:
|
||||||
@@ -187,12 +181,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
*,
|
*,
|
||||||
event_types=None,
|
event_types=None,
|
||||||
limit=500,
|
limit=500,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
if event_types:
|
if event_types:
|
||||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||||
@@ -205,46 +199,27 @@ class DbRunEventStore(RunEventStore):
|
|||||||
thread_id,
|
thread_id,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
limit=50,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
before_seq=None,
|
|
||||||
after_seq=None,
|
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||||
stmt = select(RunEventRow).where(
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||||
RunEventRow.thread_id == thread_id,
|
if resolved_owner_id is not None:
|
||||||
RunEventRow.run_id == run_id,
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
RunEventRow.category == "message",
|
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||||
)
|
async with self._sf() as session:
|
||||||
if resolved_user_id is not None:
|
result = await session.execute(stmt)
|
||||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
if before_seq is not None:
|
|
||||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
|
||||||
if after_seq is not None:
|
|
||||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
|
||||||
|
|
||||||
if after_seq is not None:
|
|
||||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
|
||||||
else:
|
|
||||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
rows = list(result.scalars())
|
|
||||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
|
||||||
|
|
||||||
async def count_messages(
|
async def count_messages(
|
||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
||||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
return await session.scalar(stmt) or 0
|
return await session.scalar(stmt) or 0
|
||||||
|
|
||||||
@@ -252,13 +227,13 @@ class DbRunEventStore(RunEventStore):
|
|||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
@@ -271,13 +246,13 @@ class DbRunEventStore(RunEventStore):
|
|||||||
thread_id,
|
thread_id,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
user_id: str | None | _AutoSentinel = AUTO,
|
owner_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||||
if resolved_user_id is not None:
|
if resolved_owner_id is not None:
|
||||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
|
|||||||
@@ -152,17 +152,9 @@ class JsonlRunEventStore(RunEventStore):
|
|||||||
events = [e for e in events if e.get("event_type") in event_types]
|
events = [e for e in events if e.get("event_type") in event_types]
|
||||||
return events[:limit]
|
return events[:limit]
|
||||||
|
|
||||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
async def list_messages_by_run(self, thread_id, run_id):
|
||||||
events = self._read_run_events(thread_id, run_id)
|
events = self._read_run_events(thread_id, run_id)
|
||||||
filtered = [e for e in events if e.get("category") == "message"]
|
return [e for e in events if e.get("category") == "message"]
|
||||||
if before_seq is not None:
|
|
||||||
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
|
||||||
if after_seq is not None:
|
|
||||||
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
|
|
||||||
if after_seq is not None:
|
|
||||||
return filtered[:limit]
|
|
||||||
else:
|
|
||||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
|
||||||
|
|
||||||
async def count_messages(self, thread_id):
|
async def count_messages(self, thread_id):
|
||||||
all_events = self._read_thread_events(thread_id)
|
all_events = self._read_thread_events(thread_id)
|
||||||
|
|||||||
@@ -97,17 +97,9 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
filtered = [e for e in filtered if e["event_type"] in event_types]
|
filtered = [e for e in filtered if e["event_type"] in event_types]
|
||||||
return filtered[:limit]
|
return filtered[:limit]
|
||||||
|
|
||||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
async def list_messages_by_run(self, thread_id, run_id):
|
||||||
all_events = self._events.get(thread_id, [])
|
all_events = self._events.get(thread_id, [])
|
||||||
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
return [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||||
if before_seq is not None:
|
|
||||||
filtered = [e for e in filtered if e["seq"] < before_seq]
|
|
||||||
if after_seq is not None:
|
|
||||||
filtered = [e for e in filtered if e["seq"] > after_seq]
|
|
||||||
if after_seq is not None:
|
|
||||||
return filtered[:limit]
|
|
||||||
else:
|
|
||||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
|
||||||
|
|
||||||
async def count_messages(self, thread_id):
|
async def count_messages(self, thread_id):
|
||||||
all_events = self._events.get(thread_id, [])
|
all_events = self._events.get(thread_id, [])
|
||||||
|
|||||||
@@ -6,10 +6,7 @@ handles token usage accumulation.
|
|||||||
|
|
||||||
Key design decisions:
|
Key design decisions:
|
||||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
|
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
|
||||||
extracts the first human message for run.input, because it is more reliable than
|
|
||||||
on_chain_start (fires on every node) — messages here are fully structured.
|
|
||||||
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
|
||||||
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||||
- Token usage accumulated in memory, written to RunRow on run completion
|
- Token usage accumulated in memory, written to RunRow on run completion
|
||||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||||
@@ -21,12 +18,10 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
|
||||||
from langgraph.types import Command
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
@@ -55,7 +50,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
# Write buffer
|
# Write buffer
|
||||||
self._buffer: list[dict] = []
|
self._buffer: list[dict] = []
|
||||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
|
||||||
|
|
||||||
# Token accumulators
|
# Token accumulators
|
||||||
self._total_input_tokens = 0
|
self._total_input_tokens = 0
|
||||||
@@ -77,39 +71,34 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# LLM request/response tracking
|
# LLM request/response tracking
|
||||||
self._llm_call_index = 0
|
self._llm_call_index = 0
|
||||||
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||||
|
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
|
||||||
|
|
||||||
|
# Tool call ID cache
|
||||||
|
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
||||||
|
|
||||||
# -- Lifecycle callbacks --
|
# -- Lifecycle callbacks --
|
||||||
|
|
||||||
def on_chain_start(
|
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self,
|
if kwargs.get("parent_run_id") is not None:
|
||||||
serialized: dict[str, Any],
|
return
|
||||||
inputs: dict[str, Any],
|
self._put(
|
||||||
*,
|
event_type="run_start",
|
||||||
run_id: UUID,
|
category="lifecycle",
|
||||||
parent_run_id: UUID | None = None,
|
metadata={"input_preview": str(inputs)[:500]},
|
||||||
tags: list[str] | None = None,
|
)
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
caller = self._identify_caller(tags)
|
|
||||||
if parent_run_id is None:
|
|
||||||
# Root graph invocation — emit a single trace event for the run start.
|
|
||||||
chain_name = (serialized or {}).get("name", "unknown")
|
|
||||||
self._put(
|
|
||||||
event_type="run.start",
|
|
||||||
category="trace",
|
|
||||||
content={"chain": chain_name},
|
|
||||||
metadata={"caller": caller, **(metadata or {})},
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
if kwargs.get("parent_run_id") is not None:
|
||||||
|
return
|
||||||
|
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
|
||||||
self._flush_sync()
|
self._flush_sync()
|
||||||
|
|
||||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
|
if kwargs.get("parent_run_id") is not None:
|
||||||
|
return
|
||||||
self._put(
|
self._put(
|
||||||
event_type="run.error",
|
event_type="run_error",
|
||||||
category="error",
|
category="lifecycle",
|
||||||
content=str(error),
|
content=str(error),
|
||||||
metadata={"error_type": type(error).__name__},
|
metadata={"error_type": type(error).__name__},
|
||||||
)
|
)
|
||||||
@@ -117,132 +106,253 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
# -- LLM callbacks --
|
# -- LLM callbacks --
|
||||||
|
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self,
|
"""Capture structured prompt messages for llm_request event."""
|
||||||
serialized: dict,
|
from deerflow.runtime.converters import langchain_messages_to_openai
|
||||||
messages: list[list[BaseMessage]],
|
|
||||||
*,
|
|
||||||
run_id: UUID,
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
"""Capture structured prompt messages for llm_request event.
|
|
||||||
|
|
||||||
This is also the canonical place to extract the first human message:
|
|
||||||
messages are fully structured here, it fires only on real LLM calls,
|
|
||||||
and the content is never compressed by checkpoint trimming.
|
|
||||||
"""
|
|
||||||
rid = str(run_id)
|
rid = str(run_id)
|
||||||
self._llm_start_times[rid] = time.monotonic()
|
self._llm_start_times[rid] = time.monotonic()
|
||||||
self._llm_call_index += 1
|
self._llm_call_index += 1
|
||||||
# Mark this run_id as seen so on_llm_end knows not to increment again.
|
|
||||||
self._cached_prompts[rid] = []
|
|
||||||
|
|
||||||
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
|
model_name = serialized.get("name", "")
|
||||||
|
self._cached_models[rid] = model_name
|
||||||
|
|
||||||
# Capture the first human message sent to any LLM in this run.
|
# Convert the first message list (LangChain passes list-of-lists)
|
||||||
if not self._first_human_msg:
|
prompt_msgs = messages[0] if messages else []
|
||||||
for batch in messages.reversed():
|
openai_msgs = langchain_messages_to_openai(prompt_msgs)
|
||||||
for m in batch.reversed():
|
self._cached_prompts[rid] = openai_msgs
|
||||||
if isinstance(m, HumanMessage) and m.name != "summary":
|
|
||||||
caller = self._identify_caller(tags)
|
|
||||||
self.set_first_human_message(m.text)
|
|
||||||
self._put(
|
|
||||||
event_type="llm.human.input",
|
|
||||||
category="message",
|
|
||||||
content=m.model_dump(),
|
|
||||||
metadata={"caller": caller},
|
|
||||||
)
|
|
||||||
break
|
|
||||||
if self._first_human_msg:
|
|
||||||
break
|
|
||||||
|
|
||||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
caller = self._identify_caller(kwargs)
|
||||||
|
self._put(
|
||||||
|
event_type="llm_request",
|
||||||
|
category="trace",
|
||||||
|
content={"model": model_name, "messages": openai_msgs},
|
||||||
|
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
||||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||||
|
|
||||||
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
|
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
messages: list[AnyMessage] = []
|
from deerflow.runtime.converters import langchain_to_openai_completion
|
||||||
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
|
|
||||||
for generation in response.generations:
|
|
||||||
for gen in generation:
|
|
||||||
if hasattr(gen, "message"):
|
|
||||||
messages.append(gen.message)
|
|
||||||
else:
|
|
||||||
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
|
|
||||||
|
|
||||||
for message in messages:
|
try:
|
||||||
caller = self._identify_caller(tags)
|
message = response.generations[0][0].message
|
||||||
|
except (IndexError, AttributeError):
|
||||||
|
logger.debug("on_llm_end: could not extract message from response")
|
||||||
|
return
|
||||||
|
|
||||||
# Latency
|
caller = self._identify_caller(kwargs)
|
||||||
rid = str(run_id)
|
|
||||||
start = self._llm_start_times.pop(rid, None)
|
|
||||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
|
||||||
|
|
||||||
# Token usage from message
|
# Latency
|
||||||
usage = getattr(message, "usage_metadata", None)
|
rid = str(run_id)
|
||||||
usage_dict = dict(usage) if usage else {}
|
start = self._llm_start_times.pop(rid, None)
|
||||||
|
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||||
|
|
||||||
# Resolve call index
|
# Token usage from message
|
||||||
|
usage = getattr(message, "usage_metadata", None)
|
||||||
|
usage_dict = dict(usage) if usage else {}
|
||||||
|
|
||||||
|
# Resolve call index
|
||||||
|
call_index = self._llm_call_index
|
||||||
|
if rid not in self._cached_prompts:
|
||||||
|
# Fallback: on_chat_model_start was not called
|
||||||
|
self._llm_call_index += 1
|
||||||
call_index = self._llm_call_index
|
call_index = self._llm_call_index
|
||||||
if rid not in self._cached_prompts:
|
|
||||||
# Fallback: on_chat_model_start was not called
|
|
||||||
self._llm_call_index += 1
|
|
||||||
call_index = self._llm_call_index
|
|
||||||
|
|
||||||
# Trace event: llm_response (OpenAI completion format)
|
# Clean up caches
|
||||||
self._put(
|
self._cached_prompts.pop(rid, None)
|
||||||
event_type="llm.ai.response",
|
self._cached_models.pop(rid, None)
|
||||||
category="message",
|
|
||||||
content=message.model_dump(),
|
|
||||||
metadata={
|
|
||||||
"caller": caller,
|
|
||||||
"usage": usage_dict,
|
|
||||||
"latency_ms": latency_ms,
|
|
||||||
"llm_call_index": call_index,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Token accumulation
|
# Trace event: llm_response (OpenAI completion format)
|
||||||
if self._track_tokens:
|
content = getattr(message, "content", "")
|
||||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
self._put(
|
||||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
event_type="llm_response",
|
||||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
category="trace",
|
||||||
if total_tk == 0:
|
content=langchain_to_openai_completion(message),
|
||||||
total_tk = input_tk + output_tk
|
metadata={
|
||||||
if total_tk > 0:
|
"caller": caller,
|
||||||
self._total_input_tokens += input_tk
|
"usage": usage_dict,
|
||||||
self._total_output_tokens += output_tk
|
"latency_ms": latency_ms,
|
||||||
self._total_tokens += total_tk
|
"llm_call_index": call_index,
|
||||||
self._llm_call_count += 1
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Message events: only lead_agent gets message-category events.
|
||||||
|
# Content uses message.model_dump() to align with checkpoint format.
|
||||||
|
tool_calls = getattr(message, "tool_calls", None) or []
|
||||||
|
if caller == "lead_agent":
|
||||||
|
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||||
|
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
||||||
|
if tool_calls:
|
||||||
|
# ai_tool_call: agent decided to use tools
|
||||||
|
self._put(
|
||||||
|
event_type="ai_tool_call",
|
||||||
|
category="message",
|
||||||
|
content=message.model_dump(),
|
||||||
|
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
|
||||||
|
)
|
||||||
|
elif isinstance(content, str) and content:
|
||||||
|
# ai_message: final text reply
|
||||||
|
self._put(
|
||||||
|
event_type="ai_message",
|
||||||
|
category="message",
|
||||||
|
content=message.model_dump(),
|
||||||
|
metadata={"model_name": model_name, "finish_reason": "stop"},
|
||||||
|
)
|
||||||
|
self._last_ai_msg = content
|
||||||
|
self._msg_count += 1
|
||||||
|
|
||||||
|
# Token accumulation
|
||||||
|
if self._track_tokens:
|
||||||
|
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||||
|
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||||
|
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||||
|
if total_tk == 0:
|
||||||
|
total_tk = input_tk + output_tk
|
||||||
|
if total_tk > 0:
|
||||||
|
self._total_input_tokens += input_tk
|
||||||
|
self._total_output_tokens += output_tk
|
||||||
|
self._total_tokens += total_tk
|
||||||
|
self._llm_call_count += 1
|
||||||
|
if caller.startswith("subagent:"):
|
||||||
|
self._subagent_tokens += total_tk
|
||||||
|
elif caller.startswith("middleware:"):
|
||||||
|
self._middleware_tokens += total_tk
|
||||||
|
else:
|
||||||
|
self._lead_agent_tokens += total_tk
|
||||||
|
|
||||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self._llm_start_times.pop(str(run_id), None)
|
self._llm_start_times.pop(str(run_id), None)
|
||||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
self._put(event_type="llm_error", category="trace", content=str(error))
|
||||||
|
|
||||||
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
|
# -- Tool callbacks --
|
||||||
"""Handle tool start event, cache tool call ID for later correlation"""
|
|
||||||
tool_call_id = str(run_id)
|
|
||||||
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
|
|
||||||
|
|
||||||
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
|
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
"""Handle tool end event, append message and clear node data"""
|
tool_call_id = kwargs.get("tool_call_id")
|
||||||
try:
|
if tool_call_id:
|
||||||
if isinstance(output, ToolMessage):
|
self._tool_call_ids[str(run_id)] = tool_call_id
|
||||||
msg = cast(ToolMessage, output)
|
self._put(
|
||||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
event_type="tool_start",
|
||||||
elif isinstance(output, Command):
|
category="trace",
|
||||||
cmd = cast(Command, output)
|
metadata={
|
||||||
messages = cmd.update.get("messages", [])
|
"tool_name": serialized.get("name", ""),
|
||||||
for message in messages:
|
"tool_call_id": tool_call_id,
|
||||||
if isinstance(message, BaseMessage):
|
"args": str(input_str)[:2000],
|
||||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
},
|
||||||
else:
|
)
|
||||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
|
||||||
else:
|
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
|
from langchain_core.messages import ToolMessage
|
||||||
finally:
|
|
||||||
logger.info(f"Tool end for node {run_id}")
|
# Extract fields from ToolMessage object when LangChain provides one.
|
||||||
|
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||||
|
# with tool_call_id, name, status, and artifact — more complete than
|
||||||
|
# what kwargs alone provides.
|
||||||
|
if isinstance(output, ToolMessage):
|
||||||
|
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||||
|
tool_name = output.name or kwargs.get("name", "")
|
||||||
|
status = getattr(output, "status", "success") or "success"
|
||||||
|
content_str = output.content if isinstance(output.content, str) else str(output.content)
|
||||||
|
# Use model_dump() for checkpoint-aligned message content.
|
||||||
|
# Override tool_call_id if it was resolved from cache.
|
||||||
|
msg_content = output.model_dump()
|
||||||
|
if msg_content.get("tool_call_id") != tool_call_id:
|
||||||
|
msg_content["tool_call_id"] = tool_call_id
|
||||||
|
else:
|
||||||
|
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||||
|
tool_name = kwargs.get("name", "")
|
||||||
|
status = "success"
|
||||||
|
content_str = str(output)
|
||||||
|
# Construct checkpoint-aligned dict when output is a plain string.
|
||||||
|
msg_content = ToolMessage(
|
||||||
|
content=content_str,
|
||||||
|
tool_call_id=tool_call_id or "",
|
||||||
|
name=tool_name,
|
||||||
|
status=status,
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
# Trace event (always)
|
||||||
|
self._put(
|
||||||
|
event_type="tool_end",
|
||||||
|
category="trace",
|
||||||
|
content=content_str,
|
||||||
|
metadata={
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"status": status,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Message event: tool_result (checkpoint-aligned model_dump format)
|
||||||
|
self._put(
|
||||||
|
event_type="tool_result",
|
||||||
|
category="message",
|
||||||
|
content=msg_content,
|
||||||
|
metadata={"tool_name": tool_name, "status": status},
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
|
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||||
|
tool_name = kwargs.get("name", "")
|
||||||
|
|
||||||
|
# Trace event
|
||||||
|
self._put(
|
||||||
|
event_type="tool_error",
|
||||||
|
category="trace",
|
||||||
|
content=str(error),
|
||||||
|
metadata={
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Message event: tool_result with error status (checkpoint-aligned)
|
||||||
|
msg_content = ToolMessage(
|
||||||
|
content=str(error),
|
||||||
|
tool_call_id=tool_call_id or "",
|
||||||
|
name=tool_name,
|
||||||
|
status="error",
|
||||||
|
).model_dump()
|
||||||
|
self._put(
|
||||||
|
event_type="tool_result",
|
||||||
|
category="message",
|
||||||
|
content=msg_content,
|
||||||
|
metadata={"tool_name": tool_name, "status": "error"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Custom event callback --
|
||||||
|
|
||||||
|
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
|
from deerflow.runtime.serialization import serialize_lc_object
|
||||||
|
|
||||||
|
if name == "summarization":
|
||||||
|
data_dict = data if isinstance(data, dict) else {}
|
||||||
|
self._put(
|
||||||
|
event_type="summarization",
|
||||||
|
category="trace",
|
||||||
|
content=data_dict.get("summary", ""),
|
||||||
|
metadata={
|
||||||
|
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
|
||||||
|
"replaced_count": data_dict.get("replaced_count", 0),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._put(
|
||||||
|
event_type="middleware:summarize",
|
||||||
|
category="middleware",
|
||||||
|
content={"role": "system", "content": data_dict.get("summary", "")},
|
||||||
|
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
|
||||||
|
self._put(
|
||||||
|
event_type=name,
|
||||||
|
category="trace",
|
||||||
|
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
|
||||||
|
)
|
||||||
|
|
||||||
# -- Internal methods --
|
# -- Internal methods --
|
||||||
|
|
||||||
@@ -271,10 +381,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
"""
|
"""
|
||||||
if not self._buffer:
|
if not self._buffer:
|
||||||
return
|
return
|
||||||
# Skip if a flush is already in flight — avoids concurrent writes
|
|
||||||
# to the same SQLite file from multiple fire-and-forget tasks.
|
|
||||||
if self._pending_flush_tasks:
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@@ -283,7 +389,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
batch = self._buffer.copy()
|
batch = self._buffer.copy()
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
task = loop.create_task(self._flush_async(batch))
|
task = loop.create_task(self._flush_async(batch))
|
||||||
self._pending_flush_tasks.add(task)
|
|
||||||
task.add_done_callback(self._on_flush_done)
|
task.add_done_callback(self._on_flush_done)
|
||||||
|
|
||||||
async def _flush_async(self, batch: list[dict]) -> None:
|
async def _flush_async(self, batch: list[dict]) -> None:
|
||||||
@@ -299,17 +404,16 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# Return failed events to buffer for retry on next flush
|
# Return failed events to buffer for retry on next flush
|
||||||
self._buffer = batch + self._buffer
|
self._buffer = batch + self._buffer
|
||||||
|
|
||||||
def _on_flush_done(self, task: asyncio.Task) -> None:
|
@staticmethod
|
||||||
self._pending_flush_tasks.discard(task)
|
def _on_flush_done(task: asyncio.Task) -> None:
|
||||||
if task.cancelled():
|
if task.cancelled():
|
||||||
return
|
return
|
||||||
exc = task.exception()
|
exc = task.exception()
|
||||||
if exc:
|
if exc:
|
||||||
logger.warning("Journal flush task failed: %s", exc)
|
logger.warning("Journal flush task failed: %s", exc)
|
||||||
|
|
||||||
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
|
def _identify_caller(self, kwargs: dict) -> str:
|
||||||
_tags = tags or kwargs.get("tags", [])
|
for tag in kwargs.get("tags") or []:
|
||||||
for tag in _tags:
|
|
||||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||||
return tag
|
return tag
|
||||||
# Default to lead_agent: the main agent graph does not inject
|
# Default to lead_agent: the main agent graph does not inject
|
||||||
@@ -346,17 +450,10 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def flush(self) -> None:
|
async def flush(self) -> None:
|
||||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||||
if self._pending_flush_tasks:
|
if self._buffer:
|
||||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
batch = self._buffer.copy()
|
||||||
|
self._buffer.clear()
|
||||||
while self._buffer:
|
await self._store.put_batch(batch)
|
||||||
batch = self._buffer[: self._flush_threshold]
|
|
||||||
del self._buffer[: self._flush_threshold]
|
|
||||||
try:
|
|
||||||
await self._store.put_batch(batch)
|
|
||||||
except Exception:
|
|
||||||
self._buffer = batch + self._buffer
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_completion_data(self) -> dict:
|
def get_completion_data(self) -> dict:
|
||||||
"""Return accumulated token and message data for run completion."""
|
"""Return accumulated token and message data for run completion."""
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class RunManager:
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._store = store
|
self._store = store
|
||||||
|
|
||||||
async def _persist_to_store(self, record: RunRecord) -> None:
|
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
|
||||||
"""Best-effort persist run record to backing store."""
|
"""Best-effort persist run record to backing store."""
|
||||||
if self._store is None:
|
if self._store is None:
|
||||||
return
|
return
|
||||||
@@ -68,6 +68,7 @@ class RunManager:
|
|||||||
metadata=record.metadata or {},
|
metadata=record.metadata or {},
|
||||||
kwargs=record.kwargs or {},
|
kwargs=record.kwargs or {},
|
||||||
created_at=record.created_at,
|
created_at=record.created_at,
|
||||||
|
follow_up_to_run_id=follow_up_to_run_id,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||||
@@ -89,6 +90,7 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
|
follow_up_to_run_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Create a new pending run and register it."""
|
"""Create a new pending run and register it."""
|
||||||
run_id = str(uuid.uuid4())
|
run_id = str(uuid.uuid4())
|
||||||
@@ -107,7 +109,7 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
await self._persist_to_store(record)
|
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -120,7 +122,7 @@ class RunManager:
|
|||||||
async with self._lock:
|
async with self._lock:
|
||||||
# Dict insertion order matches creation order, so reversing it gives
|
# Dict insertion order matches creation order, so reversing it gives
|
||||||
# us deterministic newest-first results even when timestamps tie.
|
# us deterministic newest-first results even when timestamps tie.
|
||||||
return [r for r in self._runs.values() if r.thread_id == thread_id]
|
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
|
||||||
|
|
||||||
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||||
"""Transition a run to a new status."""
|
"""Transition a run to a new status."""
|
||||||
@@ -174,6 +176,7 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
|
follow_up_to_run_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Atomically check for inflight runs and create a new one.
|
"""Atomically check for inflight runs and create a new one.
|
||||||
|
|
||||||
@@ -227,7 +230,7 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
|
|
||||||
await self._persist_to_store(record)
|
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
|
|||||||
- MemoryRunStore: in-memory dict (development, tests)
|
- MemoryRunStore: in-memory dict (development, tests)
|
||||||
- Future: RunRepository backed by SQLAlchemy ORM
|
- Future: RunRepository backed by SQLAlchemy ORM
|
||||||
|
|
||||||
All methods accept an optional user_id for user isolation.
|
All methods accept an optional owner_id for user isolation.
|
||||||
When user_id is None, no user filtering is applied (single-user mode).
|
When owner_id is None, no user filtering is applied (single-user mode).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -22,13 +22,14 @@ class RunStore(abc.ABC):
|
|||||||
*,
|
*,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
user_id: str | None = None,
|
owner_id: str | None = None,
|
||||||
status: str = "pending",
|
status: str = "pending",
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
kwargs: dict[str, Any] | None = None,
|
kwargs: dict[str, Any] | None = None,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
created_at: str | None = None,
|
created_at: str | None = None,
|
||||||
|
follow_up_to_run_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -41,7 +42,7 @@ class RunStore(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
user_id: str | None = None,
|
owner_id: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -21,25 +21,27 @@ class MemoryRunStore(RunStore):
|
|||||||
*,
|
*,
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
user_id=None,
|
owner_id=None,
|
||||||
status="pending",
|
status="pending",
|
||||||
multitask_strategy="reject",
|
multitask_strategy="reject",
|
||||||
metadata=None,
|
metadata=None,
|
||||||
kwargs=None,
|
kwargs=None,
|
||||||
error=None,
|
error=None,
|
||||||
created_at=None,
|
created_at=None,
|
||||||
|
follow_up_to_run_id=None,
|
||||||
):
|
):
|
||||||
now = datetime.now(UTC).isoformat()
|
now = datetime.now(UTC).isoformat()
|
||||||
self._runs[run_id] = {
|
self._runs[run_id] = {
|
||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"assistant_id": assistant_id,
|
"assistant_id": assistant_id,
|
||||||
"user_id": user_id,
|
"owner_id": owner_id,
|
||||||
"status": status,
|
"status": status,
|
||||||
"multitask_strategy": multitask_strategy,
|
"multitask_strategy": multitask_strategy,
|
||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
"kwargs": kwargs or {},
|
"kwargs": kwargs or {},
|
||||||
"error": error,
|
"error": error,
|
||||||
|
"follow_up_to_run_id": follow_up_to_run_id,
|
||||||
"created_at": created_at or now,
|
"created_at": created_at or now,
|
||||||
"updated_at": now,
|
"updated_at": now,
|
||||||
}
|
}
|
||||||
@@ -47,8 +49,8 @@ class MemoryRunStore(RunStore):
|
|||||||
async def get(self, run_id):
|
async def get(self, run_id):
|
||||||
return self._runs.get(run_id)
|
return self._runs.get(run_id)
|
||||||
|
|
||||||
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
|
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
||||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
|
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
|
||||||
results.sort(key=lambda r: r["created_at"], reverse=True)
|
results.sort(key=lambda r: r["created_at"], reverse=True)
|
||||||
return results[:limit]
|
return results[:limit]
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,8 @@ class RunContext:
|
|||||||
store: Any | None = field(default=None)
|
store: Any | None = field(default=None)
|
||||||
event_store: Any | None = field(default=None)
|
event_store: Any | None = field(default=None)
|
||||||
run_events_config: Any | None = field(default=None)
|
run_events_config: Any | None = field(default=None)
|
||||||
thread_store: Any | None = field(default=None)
|
thread_meta_repo: Any | None = field(default=None)
|
||||||
|
follow_up_to_run_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
async def run_agent(
|
async def run_agent(
|
||||||
@@ -74,7 +75,8 @@ async def run_agent(
|
|||||||
store = ctx.store
|
store = ctx.store
|
||||||
event_store = ctx.event_store
|
event_store = ctx.event_store
|
||||||
run_events_config = ctx.run_events_config
|
run_events_config = ctx.run_events_config
|
||||||
thread_store = ctx.thread_store
|
thread_meta_repo = ctx.thread_meta_repo
|
||||||
|
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||||
|
|
||||||
run_id = record.run_id
|
run_id = record.run_id
|
||||||
thread_id = record.thread_id
|
thread_id = record.thread_id
|
||||||
@@ -83,9 +85,34 @@ async def run_agent(
|
|||||||
pre_run_snapshot: dict[str, Any] | None = None
|
pre_run_snapshot: dict[str, Any] | None = None
|
||||||
snapshot_capture_failed = False
|
snapshot_capture_failed = False
|
||||||
|
|
||||||
|
# Initialize RunJournal for event capture
|
||||||
journal = None
|
journal = None
|
||||||
|
if event_store is not None:
|
||||||
|
from deerflow.runtime.journal import RunJournal
|
||||||
|
|
||||||
journal = None
|
journal = RunJournal(
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
event_store=event_store,
|
||||||
|
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write human_message event (model_dump format, aligned with checkpoint)
|
||||||
|
human_msg = _extract_human_message(graph_input)
|
||||||
|
if human_msg is not None:
|
||||||
|
msg_metadata = {}
|
||||||
|
if follow_up_to_run_id:
|
||||||
|
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||||
|
await event_store.put(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
event_type="human_message",
|
||||||
|
category="message",
|
||||||
|
content=human_msg.model_dump(),
|
||||||
|
metadata=msg_metadata or None,
|
||||||
|
)
|
||||||
|
content = human_msg.content
|
||||||
|
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||||
|
|
||||||
# Track whether "events" was requested but skipped
|
# Track whether "events" was requested but skipped
|
||||||
if "events" in requested_modes:
|
if "events" in requested_modes:
|
||||||
@@ -95,22 +122,6 @@ async def run_agent(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize RunJournal + write human_message event.
|
|
||||||
# These are inside the try block so any exception (e.g. a DB
|
|
||||||
# error writing the event) flows through the except/finally
|
|
||||||
# path that publishes an "end" event to the SSE bridge —
|
|
||||||
# otherwise a failure here would leave the stream hanging
|
|
||||||
# with no terminator.
|
|
||||||
if event_store is not None:
|
|
||||||
from deerflow.runtime.journal import RunJournal
|
|
||||||
|
|
||||||
journal = RunJournal(
|
|
||||||
run_id=run_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
event_store=event_store,
|
|
||||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. Mark running
|
# 1. Mark running
|
||||||
await run_manager.set_status(run_id, RunStatus.running)
|
await run_manager.set_status(run_id, RunStatus.running)
|
||||||
|
|
||||||
@@ -148,13 +159,12 @@ async def run_agent(
|
|||||||
|
|
||||||
# Inject runtime context so middlewares can access thread_id
|
# Inject runtime context so middlewares can access thread_id
|
||||||
# (langgraph-cli does this automatically; we must do it manually)
|
# (langgraph-cli does this automatically; we must do it manually)
|
||||||
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
|
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||||
# prefers it over ``configurable`` for thread-level data), make
|
# prefers it over ``configurable`` for thread-level data), make
|
||||||
# sure ``thread_id`` is available there too.
|
# sure ``thread_id`` is available there too.
|
||||||
if "context" in config and isinstance(config["context"], dict):
|
if "context" in config and isinstance(config["context"], dict):
|
||||||
config["context"].setdefault("thread_id", thread_id)
|
config["context"].setdefault("thread_id", thread_id)
|
||||||
config["context"].setdefault("run_id", run_id)
|
|
||||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||||
|
|
||||||
# Inject RunJournal as a LangChain callback handler.
|
# Inject RunJournal as a LangChain callback handler.
|
||||||
@@ -295,15 +305,12 @@ async def run_agent(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||||
|
|
||||||
try:
|
# Persist token usage + convenience fields to RunStore
|
||||||
# Persist token usage + convenience fields to RunStore
|
completion = journal.get_completion_data()
|
||||||
completion = journal.get_completion_data()
|
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
||||||
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)
|
|
||||||
|
|
||||||
# Sync title from checkpoint to threads_meta.display_name
|
# Sync title from checkpoint to threads_meta.display_name
|
||||||
if checkpointer is not None and thread_store is not None:
|
if checkpointer is not None:
|
||||||
try:
|
try:
|
||||||
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
||||||
@@ -311,17 +318,16 @@ async def run_agent(
|
|||||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||||
title = ckpt.get("channel_values", {}).get("title")
|
title = ckpt.get("channel_values", {}).get("title")
|
||||||
if title:
|
if title:
|
||||||
await thread_store.update_display_name(thread_id, title)
|
await thread_meta_repo.update_display_name(thread_id, title)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||||
|
|
||||||
# Update threads_meta status based on run outcome
|
# Update threads_meta status based on run outcome
|
||||||
if thread_store is not None:
|
try:
|
||||||
try:
|
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
await thread_meta_repo.update_status(thread_id, final_status)
|
||||||
await thread_store.update_status(thread_id, final_status)
|
except Exception:
|
||||||
except Exception:
|
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
|
||||||
|
|
||||||
await bridge.publish_end(run_id)
|
await bridge.publish_end(run_id)
|
||||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ async def make_store() -> AsyncIterator[BaseStore]:
|
|||||||
configured checkpointer.
|
configured checkpointer.
|
||||||
|
|
||||||
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
||||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer` so
|
||||||
that both singletons always use the same persistence technology::
|
that both singletons always use the same persistence technology::
|
||||||
|
|
||||||
async with make_store() as store:
|
async with make_store() as store:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Async stream bridge factory.
|
"""Async stream bridge factory.
|
||||||
|
|
||||||
Provides an **async context manager** aligned with
|
Provides an **async context manager** aligned with
|
||||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer`.
|
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer`.
|
||||||
|
|
||||||
Usage (e.g. FastAPI lifespan)::
|
Usage (e.g. FastAPI lifespan)::
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Request-scoped user context for user-based authorization.
|
"""Request-scoped user context for owner-based authorization.
|
||||||
|
|
||||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||||
auth middleware sets after a successful authentication. Repository
|
auth middleware sets after a successful authentication. Repository
|
||||||
methods read the contextvar via a sentinel default parameter, letting
|
methods read the contextvar via a sentinel default parameter, letting
|
||||||
routers stay free of ``user_id`` boilerplate.
|
routers stay free of ``owner_id`` boilerplate.
|
||||||
|
|
||||||
Three-state semantics for the repository ``user_id`` parameter (the
|
Three-state semantics for the repository ``owner_id`` parameter (the
|
||||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||||
|
|
||||||
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||||
@@ -91,35 +91,16 @@ def require_current_user() -> CurrentUser:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Effective user_id helpers (filesystem isolation)
|
# Sentinel-based owner_id resolution
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
DEFAULT_USER_ID: Final[str] = "default"
|
|
||||||
|
|
||||||
|
|
||||||
def get_effective_user_id() -> str:
|
|
||||||
"""Return the current user's id as a string, or DEFAULT_USER_ID if unset.
|
|
||||||
|
|
||||||
Unlike :func:`require_current_user` this never raises — it is designed
|
|
||||||
for filesystem-path resolution where a valid user bucket is always needed.
|
|
||||||
"""
|
|
||||||
user = _current_user.get()
|
|
||||||
if user is None:
|
|
||||||
return DEFAULT_USER_ID
|
|
||||||
return str(user.id)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Sentinel-based user_id resolution
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
# Repository methods accept a ``user_id`` keyword-only argument that
|
# Repository methods accept an ``owner_id`` keyword-only argument that
|
||||||
# defaults to ``AUTO``. The three possible values drive distinct
|
# defaults to ``AUTO``. The three possible values drive distinct
|
||||||
# behaviours; see the docstring on :func:`resolve_user_id`.
|
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
||||||
|
|
||||||
|
|
||||||
class _AutoSentinel:
|
class _AutoSentinel:
|
||||||
"""Singleton marker meaning 'resolve user_id from contextvar'."""
|
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
||||||
|
|
||||||
_instance: _AutoSentinel | None = None
|
_instance: _AutoSentinel | None = None
|
||||||
|
|
||||||
@@ -135,12 +116,12 @@ class _AutoSentinel:
|
|||||||
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||||
|
|
||||||
|
|
||||||
def resolve_user_id(
|
def resolve_owner_id(
|
||||||
value: str | None | _AutoSentinel,
|
value: str | None | _AutoSentinel,
|
||||||
*,
|
*,
|
||||||
method_name: str = "repository method",
|
method_name: str = "repository method",
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Resolve the user_id parameter passed to a repository method.
|
"""Resolve the owner_id parameter passed to a repository method.
|
||||||
|
|
||||||
Three-state semantics:
|
Three-state semantics:
|
||||||
|
|
||||||
@@ -150,16 +131,16 @@ def resolve_user_id(
|
|||||||
- Explicit ``str``: use the provided id verbatim, overriding any
|
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||||
contextvar value. Useful for tests and admin-override flows.
|
contextvar value. Useful for tests and admin-override flows.
|
||||||
- Explicit ``None``: no filter — the repository should skip the
|
- Explicit ``None``: no filter — the repository should skip the
|
||||||
user_id WHERE clause entirely. Reserved for migration scripts
|
owner_id WHERE clause entirely. Reserved for migration scripts
|
||||||
and CLI tools that intentionally bypass isolation.
|
and CLI tools that intentionally bypass isolation.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, _AutoSentinel):
|
if isinstance(value, _AutoSentinel):
|
||||||
user = _current_user.get()
|
user = _current_user.get()
|
||||||
if user is None:
|
if user is None:
|
||||||
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
|
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
|
||||||
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
||||||
# ``UUID`` for the API surface, but the persistence layer
|
# ``UUID`` for the API surface, but the persistence layer
|
||||||
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
|
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
|
||||||
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
||||||
# not supported"). Honour the documented return type here
|
# not supported"). Honour the documented return type here
|
||||||
# rather than ripple a type change through every caller.
|
# rather than ripple a type change through every caller.
|
||||||
|
|||||||
@@ -200,9 +200,8 @@ def _get_acp_workspace_host_path(thread_id: str | None = None) -> str | None:
|
|||||||
if thread_id is not None:
|
if thread_id is not None:
|
||||||
try:
|
try:
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
host_path = get_paths().acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
host_path = get_paths().acp_workspace_dir(thread_id)
|
||||||
if host_path.exists():
|
if host_path.exists():
|
||||||
return str(host_path)
|
return str(host_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -33,12 +33,11 @@ def _get_work_dir(thread_id: str | None) -> str:
|
|||||||
An absolute physical filesystem path to use as the working directory.
|
An absolute physical filesystem path to use as the working directory.
|
||||||
"""
|
"""
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
if thread_id:
|
if thread_id:
|
||||||
try:
|
try:
|
||||||
work_dir = paths.acp_workspace_dir(thread_id, user_id=get_effective_user_id())
|
work_dir = paths.acp_workspace_dir(thread_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id)
|
logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id)
|
||||||
work_dir = paths.base_dir / "acp-workspace"
|
work_dir = paths.base_dir / "acp-workspace"
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from langgraph.typing import ContextT
|
|||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||||
|
|
||||||
@@ -48,7 +47,7 @@ def _normalize_presented_filepath(
|
|||||||
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
||||||
|
|
||||||
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
|
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
|
||||||
actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
|
actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
|
||||||
else:
|
else:
|
||||||
actual_path = Path(filepath).expanduser().resolve()
|
actual_path = Path(filepath).expanduser().resolve()
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from pathlib import Path
|
|||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
|
|
||||||
class PathTraversalError(ValueError):
|
class PathTraversalError(ValueError):
|
||||||
@@ -34,7 +33,7 @@ def validate_thread_id(thread_id: str) -> None:
|
|||||||
def get_uploads_dir(thread_id: str) -> Path:
|
def get_uploads_dir(thread_id: str) -> Path:
|
||||||
"""Return the uploads directory path for a thread (no side effects)."""
|
"""Return the uploads directory path for a thread (no side effects)."""
|
||||||
validate_thread_id(thread_id)
|
validate_thread_id(thread_id)
|
||||||
return get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
return get_paths().sandbox_uploads_dir(thread_id)
|
||||||
|
|
||||||
|
|
||||||
def ensure_uploads_dir(thread_id: str) -> Path:
|
def ensure_uploads_dir(thread_id: str) -> Path:
|
||||||
|
|||||||
@@ -39,13 +39,13 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
ollama = ["langchain-ollama>=0.3.0"]
|
|
||||||
postgres = [
|
postgres = [
|
||||||
"asyncpg>=0.29",
|
"asyncpg>=0.29",
|
||||||
"langgraph-checkpoint-postgres>=3.0.5",
|
"langgraph-checkpoint-postgres>=3.0.5",
|
||||||
"psycopg[binary]>=3.3.3",
|
"psycopg[binary]>=3.3.3",
|
||||||
"psycopg-pool>=3.3.0",
|
"psycopg-pool>=3.3.0",
|
||||||
]
|
]
|
||||||
|
ollama = ["langchain-ollama>=0.3.0"]
|
||||||
pymupdf = ["pymupdf4llm>=0.0.17"]
|
pymupdf = ["pymupdf4llm>=0.0.17"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -23,7 +23,9 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
postgres = ["deerflow-harness[postgres]"]
|
postgres = [
|
||||||
|
"deerflow-harness[postgres]",
|
||||||
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
||||||
|
|||||||
@@ -1,160 +0,0 @@
|
|||||||
"""One-time migration: move legacy thread dirs and memory into per-user layout.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
PYTHONPATH=. python scripts/migrate_user_isolation.py [--dry-run]
|
|
||||||
|
|
||||||
The script is idempotent — re-running it after a successful migration is a no-op.
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from deerflow.config.paths import Paths, get_paths
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_thread_dirs(
|
|
||||||
paths: Paths,
|
|
||||||
thread_owner_map: dict[str, str],
|
|
||||||
*,
|
|
||||||
dry_run: bool = False,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Move legacy thread directories into per-user layout.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
paths: Paths instance.
|
|
||||||
thread_owner_map: Mapping of thread_id -> user_id from threads_meta table.
|
|
||||||
dry_run: If True, only log what would happen.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of migration report entries.
|
|
||||||
"""
|
|
||||||
report: list[dict] = []
|
|
||||||
legacy_threads = paths.base_dir / "threads"
|
|
||||||
if not legacy_threads.exists():
|
|
||||||
logger.info("No legacy threads directory found — nothing to migrate.")
|
|
||||||
return report
|
|
||||||
|
|
||||||
for thread_dir in sorted(legacy_threads.iterdir()):
|
|
||||||
if not thread_dir.is_dir():
|
|
||||||
continue
|
|
||||||
thread_id = thread_dir.name
|
|
||||||
user_id = thread_owner_map.get(thread_id, "default")
|
|
||||||
dest = paths.base_dir / "users" / user_id / "threads" / thread_id
|
|
||||||
|
|
||||||
entry = {"thread_id": thread_id, "user_id": user_id, "action": ""}
|
|
||||||
|
|
||||||
if dest.exists():
|
|
||||||
conflicts_dir = paths.base_dir / "migration-conflicts" / thread_id
|
|
||||||
entry["action"] = f"conflict -> {conflicts_dir}"
|
|
||||||
if not dry_run:
|
|
||||||
conflicts_dir.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
shutil.move(str(thread_dir), str(conflicts_dir))
|
|
||||||
logger.warning("Conflict for thread %s: moved to %s", thread_id, conflicts_dir)
|
|
||||||
else:
|
|
||||||
entry["action"] = f"moved -> {dest}"
|
|
||||||
if not dry_run:
|
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
shutil.move(str(thread_dir), str(dest))
|
|
||||||
logger.info("Migrated thread %s -> user %s", thread_id, user_id)
|
|
||||||
|
|
||||||
report.append(entry)
|
|
||||||
|
|
||||||
# Clean up empty legacy threads dir
|
|
||||||
if not dry_run and legacy_threads.exists() and not any(legacy_threads.iterdir()):
|
|
||||||
legacy_threads.rmdir()
|
|
||||||
|
|
||||||
return report
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_memory(
|
|
||||||
paths: Paths,
|
|
||||||
user_id: str = "default",
|
|
||||||
*,
|
|
||||||
dry_run: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Move legacy global memory.json into per-user layout.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
paths: Paths instance.
|
|
||||||
user_id: Target user to receive the legacy memory.
|
|
||||||
dry_run: If True, only log.
|
|
||||||
"""
|
|
||||||
legacy_mem = paths.base_dir / "memory.json"
|
|
||||||
if not legacy_mem.exists():
|
|
||||||
logger.info("No legacy memory.json found — nothing to migrate.")
|
|
||||||
return
|
|
||||||
|
|
||||||
dest = paths.user_memory_file(user_id)
|
|
||||||
if dest.exists():
|
|
||||||
legacy_backup = paths.base_dir / "memory.legacy.json"
|
|
||||||
logger.warning("Destination %s exists; renaming legacy to %s", dest, legacy_backup)
|
|
||||||
if not dry_run:
|
|
||||||
legacy_mem.rename(legacy_backup)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Migrating memory.json -> %s", dest)
|
|
||||||
if not dry_run:
|
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
shutil.move(str(legacy_mem), str(dest))
|
|
||||||
|
|
||||||
|
|
||||||
def _build_owner_map_from_db(paths: Paths) -> dict[str, str]:
|
|
||||||
"""Query threads_meta table for thread_id -> user_id mapping.
|
|
||||||
|
|
||||||
Uses raw sqlite3 to avoid async dependencies.
|
|
||||||
"""
|
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
db_path = paths.base_dir / "deer-flow.db"
|
|
||||||
if not db_path.exists():
|
|
||||||
logger.info("No database found at %s — using empty owner map.", db_path)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
conn = sqlite3.connect(str(db_path))
|
|
||||||
try:
|
|
||||||
cursor = conn.execute("SELECT thread_id, user_id FROM threads_meta WHERE user_id IS NOT NULL")
|
|
||||||
return {row[0]: row[1] for row in cursor.fetchall()}
|
|
||||||
except sqlite3.OperationalError as e:
|
|
||||||
logger.warning("Failed to query threads_meta: %s", e)
|
|
||||||
return {}
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description="Migrate DeerFlow data to per-user layout")
|
|
||||||
parser.add_argument("--dry-run", action="store_true", help="Log actions without making changes")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
|
||||||
|
|
||||||
paths = get_paths()
|
|
||||||
logger.info("Base directory: %s", paths.base_dir)
|
|
||||||
logger.info("Dry run: %s", args.dry_run)
|
|
||||||
|
|
||||||
owner_map = _build_owner_map_from_db(paths)
|
|
||||||
logger.info("Found %d thread ownership records in DB", len(owner_map))
|
|
||||||
|
|
||||||
report = migrate_thread_dirs(paths, owner_map, dry_run=args.dry_run)
|
|
||||||
migrate_memory(paths, user_id="default", dry_run=args.dry_run)
|
|
||||||
|
|
||||||
if report:
|
|
||||||
logger.info("Migration report:")
|
|
||||||
for entry in report:
|
|
||||||
logger.info(" thread=%s user=%s action=%s", entry["thread_id"], entry["user_id"], entry["action"])
|
|
||||||
else:
|
|
||||||
logger.info("No threads to migrate.")
|
|
||||||
|
|
||||||
unowned = [e for e in report if e["user_id"] == "default"]
|
|
||||||
if unowned:
|
|
||||||
logger.warning("%d thread(s) had no owner and were assigned to 'default':", len(unowned))
|
|
||||||
for e in unowned:
|
|
||||||
logger.warning(" %s", e["thread_id"])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -3,16 +3,16 @@
|
|||||||
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
||||||
ahead of every router, plus ``@require_permission(owner_check=True)``
|
ahead of every router, plus ``@require_permission(owner_check=True)``
|
||||||
decorators that read ``request.state.auth`` and call
|
decorators that read ``request.state.auth`` and call
|
||||||
``thread_store.check_access``. Router-level unit tests construct
|
``thread_meta_repo.check_access``. Router-level unit tests construct
|
||||||
**bare** FastAPI apps that include only one router — they have neither
|
**bare** FastAPI apps that include only one router — they have neither
|
||||||
the auth middleware nor a real thread_store, so the decorators raise
|
the auth middleware nor a real thread_meta_repo, so the decorators raise
|
||||||
401 (TestClient path) or ValueError (direct-call path).
|
401 (TestClient path) or ValueError (direct-call path).
|
||||||
|
|
||||||
This module provides two surfaces:
|
This module provides two surfaces:
|
||||||
|
|
||||||
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
||||||
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
||||||
request, plus a permissive ``thread_store`` mock on
|
request, plus a permissive ``thread_meta_repo`` mock on
|
||||||
``app.state``. Use from TestClient-based router tests.
|
``app.state``. Use from TestClient-based router tests.
|
||||||
|
|
||||||
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
||||||
@@ -86,20 +86,20 @@ def make_authed_test_app(
|
|||||||
user_factory: Callable[[], User] | None = None,
|
user_factory: Callable[[], User] | None = None,
|
||||||
owner_check_passes: bool = True,
|
owner_check_passes: bool = True,
|
||||||
) -> FastAPI:
|
) -> FastAPI:
|
||||||
"""Build a FastAPI test app with stub auth + permissive thread_store.
|
"""Build a FastAPI test app with stub auth + permissive thread_meta_repo.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_factory: Override the default test user. Must return a fully
|
user_factory: Override the default test user. Must return a fully
|
||||||
populated :class:`User`. Useful for cross-user isolation tests
|
populated :class:`User`. Useful for cross-user isolation tests
|
||||||
that need a stable id across requests.
|
that need a stable id across requests.
|
||||||
owner_check_passes: When True (default), ``thread_store.check_access``
|
owner_check_passes: When True (default), ``thread_meta_repo.check_access``
|
||||||
returns True for every call so ``@require_permission(owner_check=True)``
|
returns True for every call so ``@require_permission(owner_check=True)``
|
||||||
never blocks the route under test. Pass False to verify that
|
never blocks the route under test. Pass False to verify that
|
||||||
permission failures surface correctly.
|
permission failures surface correctly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ``FastAPI`` app with the stub middleware installed and
|
A ``FastAPI`` app with the stub middleware installed and
|
||||||
``app.state.thread_store`` set to a permissive mock. The
|
``app.state.thread_meta_repo`` set to a permissive mock. The
|
||||||
caller is still responsible for ``app.include_router(...)``.
|
caller is still responsible for ``app.include_router(...)``.
|
||||||
"""
|
"""
|
||||||
factory = user_factory or _make_stub_user
|
factory = user_factory or _make_stub_user
|
||||||
@@ -108,7 +108,7 @@ def make_authed_test_app(
|
|||||||
|
|
||||||
repo = MagicMock()
|
repo = MagicMock()
|
||||||
repo.check_access = AsyncMock(return_value=owner_check_passes)
|
repo.check_access = AsyncMock(return_value=owner_check_passes)
|
||||||
app.state.thread_store = repo
|
app.state.thread_meta_repo = repo
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|||||||
+19
-19
@@ -38,29 +38,11 @@ _executor_mock.get_background_task_result = MagicMock()
|
|||||||
sys.modules["deerflow.subagents.executor"] = _executor_mock
|
sys.modules["deerflow.subagents.executor"] = _executor_mock
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def provisioner_module():
|
|
||||||
"""Load docker/provisioner/app.py as an importable test module.
|
|
||||||
|
|
||||||
Shared by test_provisioner_kubeconfig and test_provisioner_pvc_volumes so
|
|
||||||
that any change to the provisioner entry-point path or module name only
|
|
||||||
needs to be updated in one place.
|
|
||||||
"""
|
|
||||||
repo_root = Path(__file__).resolve().parents[2]
|
|
||||||
module_path = repo_root / "docker" / "provisioner" / "app.py"
|
|
||||||
spec = importlib.util.spec_from_file_location("provisioner_app_test", module_path)
|
|
||||||
assert spec is not None
|
|
||||||
assert spec.loader is not None
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Auto-set user context for every test unless marked no_auto_user
|
# Auto-set user context for every test unless marked no_auto_user
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
# Repository methods read ``user_id`` from a contextvar by default
|
# Repository methods read ``owner_id`` from a contextvar by default
|
||||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||||
# pre-existing persistence test would raise RuntimeError because the
|
# pre-existing persistence test would raise RuntimeError because the
|
||||||
# contextvar is unset. The fixture sets a default test user on every
|
# contextvar is unset. The fixture sets a default test user on every
|
||||||
@@ -95,3 +77,21 @@ def _auto_user_context(request):
|
|||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
reset_current_user(token)
|
reset_current_user(token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def provisioner_module():
|
||||||
|
"""Load docker/provisioner/app.py as an importable test module.
|
||||||
|
|
||||||
|
Shared by test_provisioner_kubeconfig and test_provisioner_pvc_volumes so
|
||||||
|
that any change to the provisioner entry-point path or module name only
|
||||||
|
needs to be updated in one place.
|
||||||
|
"""
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
module_path = repo_root / "docker" / "provisioner" / "app.py"
|
||||||
|
spec = importlib.util.spec_from_file_location("provisioner_app_test", module_path)
|
||||||
|
assert spec is not None
|
||||||
|
assert spec.loader is not None
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ def test_get_thread_mounts_includes_acp_workspace(tmp_path, monkeypatch):
|
|||||||
"""_get_thread_mounts must include /mnt/acp-workspace (read-only) for docker sandbox."""
|
"""_get_thread_mounts must include /mnt/acp-workspace (read-only) for docker sandbox."""
|
||||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
|
||||||
|
|
||||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-3")
|
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-3")
|
||||||
|
|
||||||
@@ -96,7 +95,6 @@ def test_get_thread_mounts_preserves_windows_host_path_style(tmp_path, monkeypat
|
|||||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||||
monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow")
|
monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow")
|
||||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
|
||||||
|
|
||||||
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10")
|
mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10")
|
||||||
|
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ class TestResolveAttachments:
|
|||||||
mock_paths = MagicMock()
|
mock_paths = MagicMock()
|
||||||
mock_paths.sandbox_outputs_dir.return_value = outputs_dir
|
mock_paths.sandbox_outputs_dir.return_value = outputs_dir
|
||||||
|
|
||||||
def resolve_side_effect(tid, vpath, *, user_id=None):
|
def resolve_side_effect(tid, vpath):
|
||||||
if "data.csv" in vpath:
|
if "data.csv" in vpath:
|
||||||
return good_file
|
return good_file
|
||||||
return tmp_path / "missing.txt"
|
return tmp_path / "missing.txt"
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import deerflow.config.app_config as app_config_module
|
import deerflow.config.app_config as app_config_module
|
||||||
|
from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer
|
||||||
from deerflow.config.checkpointer_config import (
|
from deerflow.config.checkpointer_config import (
|
||||||
CheckpointerConfig,
|
CheckpointerConfig,
|
||||||
get_checkpointer_config,
|
get_checkpointer_config,
|
||||||
load_checkpointer_config_from_dict,
|
load_checkpointer_config_from_dict,
|
||||||
set_checkpointer_config,
|
set_checkpointer_config,
|
||||||
)
|
)
|
||||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -78,7 +78,7 @@ class TestGetCheckpointer:
|
|||||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||||
cp = get_checkpointer()
|
cp = get_checkpointer()
|
||||||
assert cp is not None
|
assert cp is not None
|
||||||
assert isinstance(cp, InMemorySaver)
|
assert isinstance(cp, InMemorySaver)
|
||||||
@@ -178,7 +178,7 @@ class TestAsyncCheckpointer:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||||
"""Async SQLite setup should move mkdir off the event loop."""
|
"""Async SQLite setup should move mkdir off the event loop."""
|
||||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||||
|
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db")
|
||||||
@@ -195,11 +195,11 @@ class TestAsyncCheckpointer:
|
|||||||
mock_module.AsyncSqliteSaver = mock_saver_cls
|
mock_module.AsyncSqliteSaver = mock_saver_cls
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||||
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}),
|
||||||
patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread,
|
||||||
patch(
|
patch(
|
||||||
"deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str",
|
"deerflow.agents.checkpointer.async_provider.resolve_sqlite_conn_str",
|
||||||
return_value="/tmp/resolved/test.db",
|
return_value="/tmp/resolved/test.db",
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ class TestCheckpointerNoneFix:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self):
|
||||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||||
|
|
||||||
# Mock get_app_config to return a config with checkpointer=None and database=None
|
# Mock get_app_config to return a config with checkpointer=None and database=None
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.checkpointer = None
|
mock_config.checkpointer = None
|
||||||
mock_config.database = None
|
mock_config.database = None
|
||||||
|
|
||||||
with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||||
async with make_checkpointer() as checkpointer:
|
async with make_checkpointer() as checkpointer:
|
||||||
# Should return InMemorySaver, not None
|
# Should return InMemorySaver, not None
|
||||||
assert checkpointer is not None
|
assert checkpointer is not None
|
||||||
@@ -36,13 +36,13 @@ class TestCheckpointerNoneFix:
|
|||||||
|
|
||||||
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self):
|
||||||
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
"""checkpointer_context should return InMemorySaver when config.checkpointer is None."""
|
||||||
from deerflow.runtime.checkpointer.provider import checkpointer_context
|
from deerflow.agents.checkpointer.provider import checkpointer_context
|
||||||
|
|
||||||
# Mock get_app_config to return a config with checkpointer=None
|
# Mock get_app_config to return a config with checkpointer=None
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.checkpointer = None
|
mock_config.checkpointer = None
|
||||||
|
|
||||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config):
|
with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config):
|
||||||
with checkpointer_context() as checkpointer:
|
with checkpointer_context() as checkpointer:
|
||||||
# Should return InMemorySaver, not None
|
# Should return InMemorySaver, not None
|
||||||
assert checkpointer is not None
|
assert checkpointer is not None
|
||||||
|
|||||||
@@ -817,7 +817,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._agent_name = "custom-agent"
|
client._agent_name = "custom-agent"
|
||||||
client._available_skills = {"test_skill"}
|
client._available_skills = {"test_skill"}
|
||||||
@@ -842,7 +842,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=mock_checkpointer),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
@@ -867,7 +867,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
@@ -886,7 +886,7 @@ class TestEnsureAgent:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None),
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=None),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
|
|
||||||
@@ -1015,7 +1015,7 @@ class TestThreadQueries:
|
|||||||
mock_checkpointer = MagicMock()
|
mock_checkpointer = MagicMock()
|
||||||
mock_checkpointer.list.return_value = []
|
mock_checkpointer.list.return_value = []
|
||||||
|
|
||||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||||
# No internal checkpointer, should fetch from provider
|
# No internal checkpointer, should fetch from provider
|
||||||
result = client.list_threads()
|
result = client.list_threads()
|
||||||
|
|
||||||
@@ -1069,7 +1069,7 @@ class TestThreadQueries:
|
|||||||
mock_checkpointer = MagicMock()
|
mock_checkpointer = MagicMock()
|
||||||
mock_checkpointer.list.return_value = []
|
mock_checkpointer.list.return_value = []
|
||||||
|
|
||||||
with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer):
|
||||||
result = client.get_thread("t99")
|
result = client.get_thread("t99")
|
||||||
|
|
||||||
assert result["thread_id"] == "t99"
|
assert result["thread_id"] == "t99"
|
||||||
@@ -1241,10 +1241,7 @@ class TestMemoryManagement:
|
|||||||
with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import:
|
with patch("deerflow.agents.memory.updater.import_memory_data", return_value=imported) as mock_import:
|
||||||
result = client.import_memory(imported)
|
result = client.import_memory(imported)
|
||||||
|
|
||||||
assert mock_import.call_count == 1
|
mock_import.assert_called_once_with(imported)
|
||||||
call_args = mock_import.call_args
|
|
||||||
assert call_args.args == (imported,)
|
|
||||||
assert "user_id" in call_args.kwargs
|
|
||||||
assert result == imported
|
assert result == imported
|
||||||
|
|
||||||
def test_reload_memory(self, client):
|
def test_reload_memory(self, client):
|
||||||
@@ -1490,12 +1487,9 @@ class TestUploads:
|
|||||||
|
|
||||||
class TestArtifacts:
|
class TestArtifacts:
|
||||||
def test_get_artifact(self, client):
|
def test_get_artifact(self, client):
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
user_id = get_effective_user_id()
|
outputs = paths.sandbox_outputs_dir("t1")
|
||||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
|
||||||
outputs.mkdir(parents=True)
|
outputs.mkdir(parents=True)
|
||||||
(outputs / "result.txt").write_text("artifact content")
|
(outputs / "result.txt").write_text("artifact content")
|
||||||
|
|
||||||
@@ -1506,12 +1500,9 @@ class TestArtifacts:
|
|||||||
assert "text" in mime
|
assert "text" in mime
|
||||||
|
|
||||||
def test_get_artifact_not_found(self, client):
|
def test_get_artifact_not_found(self, client):
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
user_id = get_effective_user_id()
|
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
@@ -1522,12 +1513,9 @@ class TestArtifacts:
|
|||||||
client.get_artifact("t1", "bad/path/file.txt")
|
client.get_artifact("t1", "bad/path/file.txt")
|
||||||
|
|
||||||
def test_get_artifact_path_traversal(self, client):
|
def test_get_artifact_path_traversal(self, client):
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
user_id = get_effective_user_id()
|
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
with pytest.raises(PathTraversalError):
|
with pytest.raises(PathTraversalError):
|
||||||
@@ -1711,16 +1699,13 @@ class TestScenarioFileLifecycle:
|
|||||||
|
|
||||||
def test_upload_then_read_artifact(self, client):
|
def test_upload_then_read_artifact(self, client):
|
||||||
"""Upload a file, simulate agent producing artifact, read it back."""
|
"""Upload a file, simulate agent producing artifact, read it back."""
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
uploads_dir = tmp_path / "uploads"
|
uploads_dir = tmp_path / "uploads"
|
||||||
uploads_dir.mkdir()
|
uploads_dir.mkdir()
|
||||||
|
|
||||||
paths = Paths(base_dir=tmp_path)
|
paths = Paths(base_dir=tmp_path)
|
||||||
user_id = get_effective_user_id()
|
outputs_dir = paths.sandbox_outputs_dir("t-artifact")
|
||||||
outputs_dir = paths.sandbox_outputs_dir("t-artifact", user_id=user_id)
|
|
||||||
outputs_dir.mkdir(parents=True)
|
outputs_dir.mkdir(parents=True)
|
||||||
|
|
||||||
# Upload phase
|
# Upload phase
|
||||||
@@ -1859,7 +1844,7 @@ class TestScenarioAgentRecreation:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config_a)
|
client._ensure_agent(config_a)
|
||||||
first_agent = client._agent
|
first_agent = client._agent
|
||||||
@@ -1887,7 +1872,7 @@ class TestScenarioAgentRecreation:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
@@ -1912,7 +1897,7 @@ class TestScenarioAgentRecreation:
|
|||||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||||
patch.object(client, "_get_tools", return_value=[]),
|
patch.object(client, "_get_tools", return_value=[]),
|
||||||
patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()),
|
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||||
):
|
):
|
||||||
client._ensure_agent(config)
|
client._ensure_agent(config)
|
||||||
client.reset_agent()
|
client.reset_agent()
|
||||||
@@ -1970,14 +1955,11 @@ class TestScenarioThreadIsolation:
|
|||||||
|
|
||||||
def test_artifacts_isolated_per_thread(self, client):
|
def test_artifacts_isolated_per_thread(self, client):
|
||||||
"""Artifacts in thread-A are not accessible from thread-B."""
|
"""Artifacts in thread-A are not accessible from thread-B."""
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
user_id = get_effective_user_id()
|
outputs_a = paths.sandbox_outputs_dir("thread-a")
|
||||||
outputs_a = paths.sandbox_outputs_dir("thread-a", user_id=user_id)
|
|
||||||
outputs_a.mkdir(parents=True)
|
outputs_a.mkdir(parents=True)
|
||||||
paths.sandbox_outputs_dir("thread-b", user_id=user_id).mkdir(parents=True)
|
paths.sandbox_user_data_dir("thread-b").mkdir(parents=True)
|
||||||
(outputs_a / "result.txt").write_text("thread-a artifact")
|
(outputs_a / "result.txt").write_text("thread-a artifact")
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
@@ -2882,12 +2864,9 @@ class TestUploadDeleteSymlink:
|
|||||||
class TestArtifactHardening:
|
class TestArtifactHardening:
|
||||||
def test_artifact_directory_rejected(self, client):
|
def test_artifact_directory_rejected(self, client):
|
||||||
"""get_artifact rejects paths that resolve to a directory."""
|
"""get_artifact rejects paths that resolve to a directory."""
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
user_id = get_effective_user_id()
|
subdir = paths.sandbox_outputs_dir("t1") / "subdir"
|
||||||
subdir = paths.sandbox_outputs_dir("t1", user_id=user_id) / "subdir"
|
|
||||||
subdir.mkdir(parents=True)
|
subdir.mkdir(parents=True)
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
@@ -2896,12 +2875,9 @@ class TestArtifactHardening:
|
|||||||
|
|
||||||
def test_artifact_leading_slash_stripped(self, client):
|
def test_artifact_leading_slash_stripped(self, client):
|
||||||
"""Paths with leading slash are handled correctly."""
|
"""Paths with leading slash are handled correctly."""
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
user_id = get_effective_user_id()
|
outputs = paths.sandbox_outputs_dir("t1")
|
||||||
outputs = paths.sandbox_outputs_dir("t1", user_id=user_id)
|
|
||||||
outputs.mkdir(parents=True)
|
outputs.mkdir(parents=True)
|
||||||
(outputs / "file.txt").write_text("content")
|
(outputs / "file.txt").write_text("content")
|
||||||
|
|
||||||
@@ -3015,12 +2991,9 @@ class TestBugArtifactPrefixMatchTooLoose:
|
|||||||
|
|
||||||
def test_exact_prefix_without_subpath_accepted(self, client):
|
def test_exact_prefix_without_subpath_accepted(self, client):
|
||||||
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
"""Bare 'mnt/user-data' is accepted (will later fail as directory, not at prefix)."""
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
paths = Paths(base_dir=tmp)
|
paths = Paths(base_dir=tmp)
|
||||||
user_id = get_effective_user_id()
|
paths.sandbox_user_data_dir("t1").mkdir(parents=True)
|
||||||
paths.sandbox_outputs_dir("t1", user_id=user_id).mkdir(parents=True)
|
|
||||||
|
|
||||||
with patch("deerflow.client.get_paths", return_value=paths):
|
with patch("deerflow.client.get_paths", return_value=paths):
|
||||||
# Accepted at prefix check, but fails because it's a directory.
|
# Accepted at prefix check, but fails because it's a directory.
|
||||||
|
|||||||
@@ -262,9 +262,8 @@ class TestFileUploadIntegration:
|
|||||||
|
|
||||||
# Physically exists
|
# Physically exists
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
assert (get_paths().sandbox_uploads_dir(tid, user_id=get_effective_user_id()) / "readme.txt").exists()
|
assert (get_paths().sandbox_uploads_dir(tid) / "readme.txt").exists()
|
||||||
|
|
||||||
def test_upload_duplicate_rename(self, e2e_env, tmp_path):
|
def test_upload_duplicate_rename(self, e2e_env, tmp_path):
|
||||||
"""Uploading two files with the same name auto-renames the second."""
|
"""Uploading two files with the same name auto-renames the second."""
|
||||||
@@ -473,13 +472,12 @@ class TestArtifactAccess:
|
|||||||
def test_get_artifact_happy_path(self, e2e_env):
|
def test_get_artifact_happy_path(self, e2e_env):
|
||||||
"""Write a file to outputs, then read it back via get_artifact()."""
|
"""Write a file to outputs, then read it back via get_artifact()."""
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||||
tid = str(uuid.uuid4())
|
tid = str(uuid.uuid4())
|
||||||
|
|
||||||
# Create an output file in the thread's outputs directory
|
# Create an output file in the thread's outputs directory
|
||||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||||
outputs_dir.mkdir(parents=True, exist_ok=True)
|
outputs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
(outputs_dir / "result.txt").write_text("hello artifact")
|
(outputs_dir / "result.txt").write_text("hello artifact")
|
||||||
|
|
||||||
@@ -490,12 +488,11 @@ class TestArtifactAccess:
|
|||||||
def test_get_artifact_nested_path(self, e2e_env):
|
def test_get_artifact_nested_path(self, e2e_env):
|
||||||
"""Artifacts in subdirectories are accessible."""
|
"""Artifacts in subdirectories are accessible."""
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
|
||||||
|
|
||||||
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
c = DeerFlowClient(checkpointer=None, thinking_enabled=False)
|
||||||
tid = str(uuid.uuid4())
|
tid = str(uuid.uuid4())
|
||||||
|
|
||||||
outputs_dir = get_paths().sandbox_outputs_dir(tid, user_id=get_effective_user_id())
|
outputs_dir = get_paths().sandbox_outputs_dir(tid)
|
||||||
sub = outputs_dir / "charts"
|
sub = outputs_dir / "charts"
|
||||||
sub.mkdir(parents=True, exist_ok=True)
|
sub.mkdir(parents=True, exist_ok=True)
|
||||||
(sub / "data.json").write_text('{"x": 1}')
|
(sub / "data.json").write_text('{"x": 1}')
|
||||||
|
|||||||
+134
-111
@@ -1,19 +1,21 @@
|
|||||||
"""Tests for _ensure_admin_user() in app.py.
|
"""Tests for _ensure_admin_user() in app.py.
|
||||||
|
|
||||||
Covers: first-boot no-op (admin creation removed), orphan migration
|
Covers: first-boot admin creation, auto-reset on needs_setup=True,
|
||||||
when admin exists, no-op on no admin found, and edge cases.
|
no-op on needs_setup=False, migration, and edge cases.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||||
|
|
||||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
|
||||||
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||||
|
|
||||||
@@ -33,85 +35,53 @@ def _make_app_stub(store=None):
|
|||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(admin_count=0):
|
def _make_provider(user_count=0, admin_user=None):
|
||||||
p = AsyncMock()
|
p = AsyncMock()
|
||||||
p.count_users = AsyncMock(return_value=admin_count)
|
p.count_users = AsyncMock(return_value=user_count)
|
||||||
p.count_admin_users = AsyncMock(return_value=admin_count)
|
p.create_user = AsyncMock(
|
||||||
p.create_user = AsyncMock()
|
side_effect=lambda **kw: User(
|
||||||
|
email=kw["email"],
|
||||||
|
password_hash="hashed",
|
||||||
|
system_role=kw.get("system_role", "user"),
|
||||||
|
needs_setup=kw.get("needs_setup", False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
p.get_user_by_email = AsyncMock(return_value=admin_user)
|
||||||
p.update_user = AsyncMock(side_effect=lambda u: u)
|
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
def _make_session_factory(admin_row=None):
|
# ── First boot: no users ─────────────────────────────────────────────────
|
||||||
"""Build a mock async session factory that returns a row from execute()."""
|
|
||||||
row_result = MagicMock()
|
|
||||||
row_result.scalar_one_or_none.return_value = admin_row
|
|
||||||
|
|
||||||
execute_result = MagicMock()
|
|
||||||
execute_result.scalar_one_or_none.return_value = admin_row
|
|
||||||
|
|
||||||
session = AsyncMock()
|
|
||||||
session.execute = AsyncMock(return_value=execute_result)
|
|
||||||
|
|
||||||
# Async context manager
|
|
||||||
session_cm = AsyncMock()
|
|
||||||
session_cm.__aenter__ = AsyncMock(return_value=session)
|
|
||||||
session_cm.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
|
|
||||||
sf = MagicMock()
|
|
||||||
sf.return_value = session_cm
|
|
||||||
return sf
|
|
||||||
|
|
||||||
|
|
||||||
# ── First boot: no admin → return early ──────────────────────────────────
|
def test_first_boot_creates_admin():
|
||||||
|
"""count_users==0 → create admin with needs_setup=True."""
|
||||||
|
provider = _make_provider(user_count=0)
|
||||||
def test_first_boot_does_not_create_admin():
|
|
||||||
"""admin_count==0 → do NOT create admin automatically."""
|
|
||||||
provider = _make_provider(admin_count=0)
|
|
||||||
app = _make_app_stub()
|
app = _make_app_stub()
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
from app.gateway.app import _ensure_admin_user
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
provider.create_user.assert_not_called()
|
provider.create_user.assert_called_once()
|
||||||
|
call_kwargs = provider.create_user.call_args[1]
|
||||||
|
assert call_kwargs["email"] == "admin@deerflow.dev"
|
||||||
|
assert call_kwargs["system_role"] == "admin"
|
||||||
|
assert call_kwargs["needs_setup"] is True
|
||||||
|
assert len(call_kwargs["password"]) > 10 # random password generated
|
||||||
|
|
||||||
|
|
||||||
def test_first_boot_skips_migration():
|
def test_first_boot_triggers_migration_if_store_present():
|
||||||
"""No admin → return early before any migration attempt."""
|
"""First boot with store → _migrate_orphaned_threads called."""
|
||||||
provider = _make_provider(admin_count=0)
|
provider = _make_provider(user_count=0)
|
||||||
store = AsyncMock()
|
store = AsyncMock()
|
||||||
store.asearch = AsyncMock(return_value=[])
|
store.asearch = AsyncMock(return_value=[])
|
||||||
app = _make_app_stub(store=store)
|
app = _make_app_stub(store=store)
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
from app.gateway.app import _ensure_admin_user
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
|
||||||
|
|
||||||
store.asearch.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Admin exists: migration runs when admin row found ────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_admin_exists_triggers_migration():
|
|
||||||
"""Admin exists and admin row found → _migrate_orphaned_threads called."""
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
admin_row = MagicMock()
|
|
||||||
admin_row.id = uuid4()
|
|
||||||
|
|
||||||
provider = _make_provider(admin_count=1)
|
|
||||||
sf = _make_session_factory(admin_row=admin_row)
|
|
||||||
store = AsyncMock()
|
|
||||||
store.asearch = AsyncMock(return_value=[])
|
|
||||||
app = _make_app_stub(store=store)
|
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
|
||||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
@@ -119,87 +89,140 @@ def test_admin_exists_triggers_migration():
|
|||||||
store.asearch.assert_called_once()
|
store.asearch.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_admin_exists_no_admin_row_skips_migration():
|
def test_first_boot_no_store_skips_migration():
|
||||||
"""Admin count > 0 but DB row missing (edge case) → skip migration gracefully."""
|
"""First boot without store → no crash, migration skipped."""
|
||||||
provider = _make_provider(admin_count=2)
|
provider = _make_provider(user_count=0)
|
||||||
sf = _make_session_factory(admin_row=None)
|
|
||||||
store = AsyncMock()
|
|
||||||
app = _make_app_stub(store=store)
|
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
|
||||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
|
||||||
from app.gateway.app import _ensure_admin_user
|
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
|
||||||
|
|
||||||
store.asearch.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_admin_exists_no_store_skips_migration():
|
|
||||||
"""Admin exists, row found, but no store → no crash, no migration."""
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
admin_row = MagicMock()
|
|
||||||
admin_row.id = uuid4()
|
|
||||||
|
|
||||||
provider = _make_provider(admin_count=1)
|
|
||||||
sf = _make_session_factory(admin_row=admin_row)
|
|
||||||
app = _make_app_stub(store=None)
|
app = _make_app_stub(store=None)
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
# No assertion needed — just verify no crash
|
provider.create_user.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_admin_exists_session_factory_none_skips_migration():
|
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
|
||||||
"""get_session_factory() returns None → return early, no crash."""
|
|
||||||
provider = _make_provider(admin_count=1)
|
|
||||||
store = AsyncMock()
|
def test_needs_setup_true_resets_password():
|
||||||
app = _make_app_stub(store=store)
|
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
|
||||||
|
admin = User(
|
||||||
|
email="admin@deerflow.dev",
|
||||||
|
password_hash="old-hash",
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=True,
|
||||||
|
token_version=0,
|
||||||
|
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||||
|
)
|
||||||
|
provider = _make_provider(user_count=1, admin_user=admin)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=None):
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
store.asearch.assert_not_called()
|
# Password was reset
|
||||||
|
provider.update_user.assert_called_once()
|
||||||
|
updated = provider.update_user.call_args[0][0]
|
||||||
|
assert updated.password_hash == "new-hash"
|
||||||
|
assert updated.token_version == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_needs_setup_true_consecutive_resets_increment_version():
|
||||||
|
"""Two boots with needs_setup=True → token_version increments each time."""
|
||||||
|
admin = User(
|
||||||
|
email="admin@deerflow.dev",
|
||||||
|
password_hash="hash",
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=True,
|
||||||
|
token_version=3,
|
||||||
|
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||||
|
)
|
||||||
|
provider = _make_provider(user_count=1, admin_user=admin)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
updated = provider.update_user.call_args[0][0]
|
||||||
|
assert updated.token_version == 4
|
||||||
|
|
||||||
|
|
||||||
|
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_needs_setup_false_no_reset():
|
||||||
|
"""Admin with needs_setup=False → no password reset, no update."""
|
||||||
|
admin = User(
|
||||||
|
email="admin@deerflow.dev",
|
||||||
|
password_hash="stable-hash",
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=False,
|
||||||
|
token_version=2,
|
||||||
|
)
|
||||||
|
provider = _make_provider(user_count=1, admin_user=admin)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.update_user.assert_not_called()
|
||||||
|
assert admin.password_hash == "stable-hash"
|
||||||
|
assert admin.token_version == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── Edge cases ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_admin_email_found_no_crash():
|
||||||
|
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
|
||||||
|
provider = _make_provider(user_count=3, admin_user=None)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.update_user.assert_not_called()
|
||||||
|
provider.create_user.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_migration_failure_is_non_fatal():
|
def test_migration_failure_is_non_fatal():
|
||||||
"""_migrate_orphaned_threads exception is caught and logged."""
|
"""_migrate_orphaned_threads exception is caught and logged."""
|
||||||
from uuid import uuid4
|
provider = _make_provider(user_count=0)
|
||||||
|
|
||||||
admin_row = MagicMock()
|
|
||||||
admin_row.id = uuid4()
|
|
||||||
|
|
||||||
provider = _make_provider(admin_count=1)
|
|
||||||
sf = _make_session_factory(admin_row=admin_row)
|
|
||||||
store = AsyncMock()
|
store = AsyncMock()
|
||||||
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
||||||
app = _make_app_stub(store=store)
|
app = _make_app_stub(store=store)
|
||||||
|
|
||||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
with patch("deerflow.persistence.engine.get_session_factory", return_value=sf):
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
from app.gateway.app import _ensure_admin_user
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
# Should not raise
|
# Should not raise
|
||||||
asyncio.run(_ensure_admin_user(app))
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.create_user.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
||||||
"""First boot finds Store-only legacy threads → stamps admin's id.
|
"""First boot finds Store-only legacy threads → stamps admin's id.
|
||||||
|
|
||||||
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
||||||
(no auth) accumulates threads in the LangGraph Store namespace
|
(no auth) accumulates threads in the LangGraph Store namespace
|
||||||
``("threads",)`` with no ``metadata.user_id``. After upgrading to
|
``("threads",)`` with no ``metadata.owner_id``. After upgrading to
|
||||||
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
||||||
rewrite each unowned item with the freshly created admin's id.
|
rewrite each unowned item with the freshly created admin's id.
|
||||||
"""
|
"""
|
||||||
@@ -210,7 +233,7 @@ def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
|||||||
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
||||||
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
||||||
SimpleNamespace(key="t3", value={"metadata": {}}),
|
SimpleNamespace(key="t3", value={"metadata": {}}),
|
||||||
SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
|
SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}),
|
||||||
]
|
]
|
||||||
store = AsyncMock()
|
store = AsyncMock()
|
||||||
# asearch returns the entire batch on first call, then an empty page
|
# asearch returns the entire batch on first call, then an empty page
|
||||||
@@ -230,11 +253,11 @@ def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
|||||||
assert len(aput_calls) == 3
|
assert len(aput_calls) == 3
|
||||||
rewritten_keys = {call[1] for call in aput_calls}
|
rewritten_keys = {call[1] for call in aput_calls}
|
||||||
assert rewritten_keys == {"t1", "t2", "t3"}
|
assert rewritten_keys == {"t1", "t2", "t3"}
|
||||||
# Each rewrite carries the new user_id; titles preserved where present.
|
# Each rewrite carries the new owner_id; titles preserved where present.
|
||||||
by_key = {call[1]: call[2] for call in aput_calls}
|
by_key = {call[1]: call[2] for call in aput_calls}
|
||||||
assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
|
assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42"
|
||||||
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
||||||
assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
|
assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42"
|
||||||
# The pre-owned item must NOT have been rewritten.
|
# The pre-owned item must NOT have been rewritten.
|
||||||
assert "t4" not in rewritten_keys
|
assert "t4" not in rewritten_keys
|
||||||
|
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ class TestFeedbackRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_create_with_owner(self, tmp_path):
|
async def test_create_with_owner(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
|
||||||
assert record["user_id"] == "user-1"
|
assert record["owner_id"] == "user-1"
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -97,10 +97,10 @@ class TestFeedbackRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_run(self, tmp_path):
|
async def test_list_by_run(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
|
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||||
await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
|
await repo.create(run_id="r2", thread_id="t1", rating=1)
|
||||||
results = await repo.list_by_run("t1", "r1", user_id=None)
|
results = await repo.list_by_run("t1", "r1")
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert all(r["run_id"] == "r1" for r in results)
|
assert all(r["run_id"] == "r1" for r in results)
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
@@ -135,9 +135,9 @@ class TestFeedbackRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_aggregate_by_run(self, tmp_path):
|
async def test_aggregate_by_run(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
|
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
|
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||||
stats = await repo.aggregate_by_run("t1", "r1")
|
stats = await repo.aggregate_by_run("t1", "r1")
|
||||||
assert stats["total"] == 3
|
assert stats["total"] == 3
|
||||||
assert stats["positive"] == 2
|
assert stats["positive"] == 2
|
||||||
@@ -154,80 +154,6 @@ class TestFeedbackRepository:
|
|||||||
assert stats["negative"] == 0
|
assert stats["negative"] == 0
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_upsert_creates_new(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
|
||||||
assert record["rating"] == 1
|
|
||||||
assert record["feedback_id"]
|
|
||||||
assert record["user_id"] == "u1"
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_upsert_updates_existing(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
|
||||||
second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind")
|
|
||||||
assert second["feedback_id"] == first["feedback_id"]
|
|
||||||
assert second["rating"] == -1
|
|
||||||
assert second["comment"] == "changed my mind"
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_upsert_different_users_separate(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
|
||||||
r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2")
|
|
||||||
assert r1["feedback_id"] != r2["feedback_id"]
|
|
||||||
assert r1["rating"] == 1
|
|
||||||
assert r2["rating"] == -1
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_upsert_invalid_rating(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_delete_by_run(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
|
||||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
|
||||||
assert deleted is True
|
|
||||||
results = await repo.list_by_run("t1", "r1", user_id="u1")
|
|
||||||
assert len(results) == 0
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_delete_by_run_nonexistent(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
|
||||||
assert deleted is False
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_by_thread_grouped(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
|
||||||
await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
|
|
||||||
await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
|
|
||||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
|
||||||
assert "r1" in grouped
|
|
||||||
assert "r2" in grouped
|
|
||||||
assert "r3" not in grouped
|
|
||||||
assert grouped["r1"]["rating"] == 1
|
|
||||||
assert grouped["r2"]["rating"] == -1
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_by_thread_grouped_empty(self, tmp_path):
|
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
|
||||||
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
|
||||||
assert grouped == {}
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
# -- Follow-up association --
|
# -- Follow-up association --
|
||||||
|
|
||||||
|
|||||||
@@ -1,165 +0,0 @@
|
|||||||
"""Tests for the POST /api/v1/auth/initialize endpoint.
|
|
||||||
|
|
||||||
Covers: first-boot admin creation, rejection when system already
|
|
||||||
initialized, password strength validation,
|
|
||||||
and public accessibility (no auth cookie required).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-initialize-admin-min-32")
|
|
||||||
|
|
||||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
|
||||||
|
|
||||||
_TEST_SECRET = "test-secret-key-initialize-admin-min-32"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _setup_auth(tmp_path):
|
|
||||||
"""Fresh SQLite engine + auth config per test."""
|
|
||||||
from app.gateway import deps
|
|
||||||
from deerflow.persistence.engine import close_engine, init_engine
|
|
||||||
|
|
||||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
|
||||||
url = f"sqlite+aiosqlite:///{tmp_path}/init_admin.db"
|
|
||||||
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
|
|
||||||
deps._cached_local_provider = None
|
|
||||||
deps._cached_repo = None
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
deps._cached_local_provider = None
|
|
||||||
deps._cached_repo = None
|
|
||||||
asyncio.run(close_engine())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def client(_setup_auth):
|
|
||||||
from app.gateway.app import create_app
|
|
||||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
|
||||||
|
|
||||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
|
||||||
app = create_app()
|
|
||||||
# Do NOT use TestClient as a context manager — that would trigger the
|
|
||||||
# full lifespan which requires config.yaml. The auth endpoints work
|
|
||||||
# without the lifespan (persistence engine is set up by _setup_auth).
|
|
||||||
yield TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
def _init_payload(**extra):
|
|
||||||
"""Build a valid /initialize payload."""
|
|
||||||
return {
|
|
||||||
"email": "admin@example.com",
|
|
||||||
"password": "Str0ng!Pass99",
|
|
||||||
**extra,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Happy path ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_creates_admin_and_sets_cookie(client):
|
|
||||||
"""POST /initialize when no admin exists → 201, session cookie set."""
|
|
||||||
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
|
|
||||||
assert resp.status_code == 201
|
|
||||||
data = resp.json()
|
|
||||||
assert data["email"] == "admin@example.com"
|
|
||||||
assert data["system_role"] == "admin"
|
|
||||||
assert "access_token" in resp.cookies
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_needs_setup_false(client):
|
|
||||||
"""Newly created admin via /initialize has needs_setup=False."""
|
|
||||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
|
||||||
me = client.get("/api/v1/auth/me")
|
|
||||||
assert me.status_code == 200
|
|
||||||
assert me.json()["needs_setup"] is False
|
|
||||||
|
|
||||||
|
|
||||||
# ── Rejection when already initialized ───────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_rejected_when_admin_exists(client):
|
|
||||||
"""Second call to /initialize after admin exists → 409 system_already_initialized."""
|
|
||||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
|
||||||
resp2 = client.post(
|
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
json={**_init_payload(), "email": "other@example.com"},
|
|
||||||
)
|
|
||||||
assert resp2.status_code == 409
|
|
||||||
body = resp2.json()
|
|
||||||
assert body["detail"]["code"] == "system_already_initialized"
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_register_does_not_block_initialization(client):
|
|
||||||
"""/register creating a user before /initialize doesn't block admin creation."""
|
|
||||||
# Register a regular user first
|
|
||||||
client.post("/api/v1/auth/register", json={"email": "regular@example.com", "password": "Tr0ub4dor3a"})
|
|
||||||
# /initialize should still succeed (checks admin_count, not total user_count)
|
|
||||||
resp = client.post("/api/v1/auth/initialize", json=_init_payload())
|
|
||||||
assert resp.status_code == 201
|
|
||||||
assert resp.json()["system_role"] == "admin"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Endpoint is public (no cookie required) ───────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_accessible_without_cookie(client):
|
|
||||||
"""No access_token cookie needed for /initialize."""
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
json=_init_payload(),
|
|
||||||
cookies={},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 201
|
|
||||||
|
|
||||||
|
|
||||||
# ── Password validation ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_rejects_short_password(client):
|
|
||||||
"""Password shorter than 8 chars → 422."""
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
json={**_init_payload(), "password": "short"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 422
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_rejects_common_password(client):
|
|
||||||
"""Common password → 422."""
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
json={**_init_payload(), "password": "password123"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 422
|
|
||||||
|
|
||||||
|
|
||||||
# ── setup-status reflects initialization ─────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_setup_status_before_initialization(client):
|
|
||||||
"""setup-status returns needs_setup=True before /initialize is called."""
|
|
||||||
resp = client.get("/api/v1/auth/setup-status")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["needs_setup"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_setup_status_after_initialization(client):
|
|
||||||
"""setup-status returns needs_setup=False after /initialize succeeds."""
|
|
||||||
client.post("/api/v1/auth/initialize", json=_init_payload())
|
|
||||||
resp = client.get("/api/v1/auth/setup-status")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["needs_setup"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_setup_status_false_when_only_regular_user_exists(client):
|
|
||||||
"""setup-status returns needs_setup=True even when regular users exist (no admin)."""
|
|
||||||
client.post("/api/v1/auth/register", json={"email": "regular@example.com", "password": "Tr0ub4dor3a"})
|
|
||||||
resp = client.get("/api/v1/auth/setup-status")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["needs_setup"] is True
|
|
||||||
@@ -152,10 +152,8 @@ def test_get_work_dir_uses_base_dir_when_no_thread_id(monkeypatch, tmp_path):
|
|||||||
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
def test_get_work_dir_uses_per_thread_path_when_thread_id_given(monkeypatch, tmp_path):
|
||||||
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
"""P1.1: _get_work_dir(thread_id) uses {base_dir}/threads/{thread_id}/acp-workspace/."""
|
||||||
from deerflow.config import paths as paths_module
|
from deerflow.config import paths as paths_module
|
||||||
from deerflow.runtime import user_context as uc_module
|
|
||||||
|
|
||||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
|
||||||
result = _get_work_dir("thread-abc-123")
|
result = _get_work_dir("thread-abc-123")
|
||||||
expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace"
|
expected = tmp_path / "threads" / "thread-abc-123" / "acp-workspace"
|
||||||
assert result == str(expected)
|
assert result == str(expected)
|
||||||
@@ -312,10 +310,8 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
|
|||||||
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
async def test_invoke_acp_agent_uses_per_thread_workspace_when_thread_id_in_config(monkeypatch, tmp_path):
|
||||||
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
"""P1.1: When thread_id is in the RunnableConfig, ACP agent uses per-thread workspace."""
|
||||||
from deerflow.config import paths as paths_module
|
from deerflow.config import paths as paths_module
|
||||||
from deerflow.runtime import user_context as uc_module
|
|
||||||
|
|
||||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||||
monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None)
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||||
|
|||||||
@@ -175,46 +175,46 @@ def _make_ctx(user_id):
|
|||||||
def test_filter_injects_user_id():
|
def test_filter_injects_user_id():
|
||||||
value = {}
|
value = {}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
assert value["metadata"]["user_id"] == "user-a"
|
assert value["metadata"]["owner_id"] == "user-a"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_preserves_existing_metadata():
|
def test_filter_preserves_existing_metadata():
|
||||||
value = {"metadata": {"title": "hello"}}
|
value = {"metadata": {"title": "hello"}}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
assert value["metadata"]["user_id"] == "user-a"
|
assert value["metadata"]["owner_id"] == "user-a"
|
||||||
assert value["metadata"]["title"] == "hello"
|
assert value["metadata"]["title"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_returns_user_id_dict():
|
def test_filter_returns_user_id_dict():
|
||||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||||
assert result == {"user_id": "user-x"}
|
assert result == {"owner_id": "user-x"}
|
||||||
|
|
||||||
|
|
||||||
def test_filter_read_write_consistency():
|
def test_filter_read_write_consistency():
|
||||||
value = {}
|
value = {}
|
||||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||||
assert value["metadata"]["user_id"] == filter_dict["user_id"]
|
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
||||||
|
|
||||||
|
|
||||||
def test_different_users_different_filters():
|
def test_different_users_different_filters():
|
||||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||||
assert f_a["user_id"] != f_b["user_id"]
|
assert f_a["owner_id"] != f_b["owner_id"]
|
||||||
|
|
||||||
|
|
||||||
def test_filter_overrides_conflicting_user_id():
|
def test_filter_overrides_conflicting_user_id():
|
||||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||||
value = {"metadata": {"user_id": "attacker"}}
|
value = {"metadata": {"owner_id": "attacker"}}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||||
assert value["metadata"]["user_id"] == "real-owner"
|
assert value["metadata"]["owner_id"] == "real-owner"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_with_empty_metadata():
|
def test_filter_with_empty_metadata():
|
||||||
"""Explicit empty metadata dict is fine."""
|
"""Explicit empty metadata dict is fine."""
|
||||||
value = {"metadata": {}}
|
value = {"metadata": {}}
|
||||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||||
assert value["metadata"]["user_id"] == "user-z"
|
assert value["metadata"]["owner_id"] == "user-z"
|
||||||
assert result == {"user_id": "user-z"}
|
assert result == {"owner_id": "user-z"}
|
||||||
|
|
||||||
|
|
||||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
|||||||
agent_name="lead_agent",
|
agent_name="lead_agent",
|
||||||
correction_detected=True,
|
correction_detected=True,
|
||||||
reinforcement_detected=False,
|
reinforcement_detected=False,
|
||||||
user_id=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -89,5 +88,4 @@ def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
|||||||
agent_name="lead_agent",
|
agent_name="lead_agent",
|
||||||
correction_detected=False,
|
correction_detected=False,
|
||||||
reinforcement_detected=True,
|
reinforcement_detected=True,
|
||||||
user_id=None,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
"""Tests for user_id propagation through memory queue."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
|
|
||||||
|
|
||||||
|
|
||||||
def test_conversation_context_has_user_id():
|
|
||||||
ctx = ConversationContext(thread_id="t1", messages=[], user_id="alice")
|
|
||||||
assert ctx.user_id == "alice"
|
|
||||||
|
|
||||||
|
|
||||||
def test_conversation_context_user_id_default_none():
|
|
||||||
ctx = ConversationContext(thread_id="t1", messages=[])
|
|
||||||
assert ctx.user_id is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_queue_add_stores_user_id():
|
|
||||||
q = MemoryUpdateQueue()
|
|
||||||
with patch.object(q, "_reset_timer"):
|
|
||||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
|
||||||
assert len(q._queue) == 1
|
|
||||||
assert q._queue[0].user_id == "alice"
|
|
||||||
q.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def test_queue_process_passes_user_id_to_updater():
|
|
||||||
q = MemoryUpdateQueue()
|
|
||||||
with patch.object(q, "_reset_timer"):
|
|
||||||
q.add(thread_id="t1", messages=["msg"], user_id="alice")
|
|
||||||
|
|
||||||
mock_updater = MagicMock()
|
|
||||||
mock_updater.update_memory.return_value = True
|
|
||||||
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
|
||||||
q._process_queue()
|
|
||||||
|
|
||||||
mock_updater.update_memory.assert_called_once()
|
|
||||||
call_kwargs = mock_updater.update_memory.call_args.kwargs
|
|
||||||
assert call_kwargs["user_id"] == "alice"
|
|
||||||
@@ -258,13 +258,12 @@ def test_update_memory_fact_route_preserves_omitted_fields() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert update_fact.call_count == 1
|
update_fact.assert_called_once_with(
|
||||||
call_kwargs = update_fact.call_args.kwargs
|
fact_id="fact_edit",
|
||||||
assert call_kwargs.get("fact_id") == "fact_edit"
|
content="User prefers spaces",
|
||||||
assert call_kwargs.get("content") == "User prefers spaces"
|
category=None,
|
||||||
assert call_kwargs.get("category") is None
|
confidence=None,
|
||||||
assert call_kwargs.get("confidence") is None
|
)
|
||||||
assert "user_id" in call_kwargs
|
|
||||||
assert response.json()["facts"] == updated_memory["facts"]
|
assert response.json()["facts"] == updated_memory["facts"]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,150 +0,0 @@
|
|||||||
"""Tests for per-user memory storage isolation."""
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from deerflow.agents.memory.storage import FileMemoryStorage, create_empty_memory
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def base_dir(tmp_path: Path) -> Path:
|
|
||||||
return tmp_path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def storage() -> FileMemoryStorage:
|
|
||||||
return FileMemoryStorage()
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserIsolatedStorage:
|
|
||||||
def test_save_and_load_per_user(self, storage: FileMemoryStorage, base_dir: Path):
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
memory_a = create_empty_memory()
|
|
||||||
memory_a["user"]["workContext"]["summary"] = "User A context"
|
|
||||||
storage.save(memory_a, user_id="alice")
|
|
||||||
|
|
||||||
memory_b = create_empty_memory()
|
|
||||||
memory_b["user"]["workContext"]["summary"] = "User B context"
|
|
||||||
storage.save(memory_b, user_id="bob")
|
|
||||||
|
|
||||||
loaded_a = storage.load(user_id="alice")
|
|
||||||
loaded_b = storage.load(user_id="bob")
|
|
||||||
|
|
||||||
assert loaded_a["user"]["workContext"]["summary"] == "User A context"
|
|
||||||
assert loaded_b["user"]["workContext"]["summary"] == "User B context"
|
|
||||||
|
|
||||||
def test_user_memory_file_location(self, base_dir: Path):
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
s = FileMemoryStorage()
|
|
||||||
memory = create_empty_memory()
|
|
||||||
s.save(memory, user_id="alice")
|
|
||||||
expected_path = base_dir / "users" / "alice" / "memory.json"
|
|
||||||
assert expected_path.exists()
|
|
||||||
|
|
||||||
def test_cache_isolated_per_user(self, base_dir: Path):
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
s = FileMemoryStorage()
|
|
||||||
memory_a = create_empty_memory()
|
|
||||||
memory_a["user"]["workContext"]["summary"] = "A"
|
|
||||||
s.save(memory_a, user_id="alice")
|
|
||||||
|
|
||||||
memory_b = create_empty_memory()
|
|
||||||
memory_b["user"]["workContext"]["summary"] = "B"
|
|
||||||
s.save(memory_b, user_id="bob")
|
|
||||||
|
|
||||||
loaded_a = s.load(user_id="alice")
|
|
||||||
assert loaded_a["user"]["workContext"]["summary"] == "A"
|
|
||||||
|
|
||||||
def test_no_user_id_uses_legacy_path(self, base_dir: Path):
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
|
||||||
s = FileMemoryStorage()
|
|
||||||
memory = create_empty_memory()
|
|
||||||
s.save(memory, user_id=None)
|
|
||||||
expected_path = base_dir / "memory.json"
|
|
||||||
assert expected_path.exists()
|
|
||||||
|
|
||||||
def test_user_and_legacy_do_not_interfere(self, base_dir: Path):
|
|
||||||
"""user_id=None (legacy) and user_id='alice' must use different files and caches."""
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
from deerflow.config.memory_config import MemoryConfig
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")):
|
|
||||||
s = FileMemoryStorage()
|
|
||||||
|
|
||||||
legacy_mem = create_empty_memory()
|
|
||||||
legacy_mem["user"]["workContext"]["summary"] = "legacy"
|
|
||||||
s.save(legacy_mem, user_id=None)
|
|
||||||
|
|
||||||
user_mem = create_empty_memory()
|
|
||||||
user_mem["user"]["workContext"]["summary"] = "alice"
|
|
||||||
s.save(user_mem, user_id="alice")
|
|
||||||
|
|
||||||
assert s.load(user_id=None)["user"]["workContext"]["summary"] == "legacy"
|
|
||||||
assert s.load(user_id="alice")["user"]["workContext"]["summary"] == "alice"
|
|
||||||
|
|
||||||
def test_user_agent_memory_file_location(self, base_dir: Path):
|
|
||||||
"""Per-user per-agent memory uses the user_agent_memory_file path."""
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
s = FileMemoryStorage()
|
|
||||||
memory = create_empty_memory()
|
|
||||||
memory["user"]["workContext"]["summary"] = "agent scoped"
|
|
||||||
s.save(memory, "test-agent", user_id="alice")
|
|
||||||
expected_path = base_dir / "users" / "alice" / "agents" / "test-agent" / "memory.json"
|
|
||||||
assert expected_path.exists()
|
|
||||||
|
|
||||||
def test_cache_key_is_user_agent_tuple(self, base_dir: Path):
|
|
||||||
"""Cache keys must be (user_id, agent_name) tuples, not bare agent names."""
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
s = FileMemoryStorage()
|
|
||||||
memory = create_empty_memory()
|
|
||||||
s.save(memory, user_id="alice")
|
|
||||||
# After save, cache should have tuple key
|
|
||||||
assert ("alice", None) in s._memory_cache
|
|
||||||
|
|
||||||
def test_reload_with_user_id(self, base_dir: Path):
|
|
||||||
"""reload() with user_id should force re-read from the user-scoped file."""
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
paths = Paths(base_dir)
|
|
||||||
with patch("deerflow.agents.memory.storage.get_paths", return_value=paths):
|
|
||||||
s = FileMemoryStorage()
|
|
||||||
memory = create_empty_memory()
|
|
||||||
memory["user"]["workContext"]["summary"] = "initial"
|
|
||||||
s.save(memory, user_id="alice")
|
|
||||||
|
|
||||||
# Load once to prime cache
|
|
||||||
s.load(user_id="alice")
|
|
||||||
|
|
||||||
# Write updated content directly to file
|
|
||||||
user_file = base_dir / "users" / "alice" / "memory.json"
|
|
||||||
import json
|
|
||||||
|
|
||||||
updated = create_empty_memory()
|
|
||||||
updated["user"]["workContext"]["summary"] = "updated"
|
|
||||||
user_file.write_text(json.dumps(updated))
|
|
||||||
|
|
||||||
# reload should pick up the new content
|
|
||||||
reloaded = s.reload(user_id="alice")
|
|
||||||
assert reloaded["user"]["workContext"]["summary"] == "updated"
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
"""Owner isolation tests for MemoryThreadMetaStore.
|
|
||||||
|
|
||||||
Mirrors the SQL-backed tests in test_owner_isolation.py but exercises
|
|
||||||
the in-memory LangGraph Store backend used when database.backend=memory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from langgraph.store.memory import InMemoryStore
|
|
||||||
|
|
||||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
|
||||||
|
|
||||||
USER_A = SimpleNamespace(id="user-a", email="a@test.local")
|
|
||||||
USER_B = SimpleNamespace(id="user-b", email="b@test.local")
|
|
||||||
|
|
||||||
|
|
||||||
def _as_user(user):
|
|
||||||
class _Ctx:
|
|
||||||
def __enter__(self):
|
|
||||||
self._token = set_current_user(user)
|
|
||||||
return user
|
|
||||||
|
|
||||||
def __exit__(self, *exc):
|
|
||||||
reset_current_user(self._token)
|
|
||||||
|
|
||||||
return _Ctx()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def store():
|
|
||||||
return MemoryThreadMetaStore(InMemoryStore())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_search_isolation(store):
|
|
||||||
"""search() returns only threads owned by the current user."""
|
|
||||||
with _as_user(USER_A):
|
|
||||||
await store.create("t-alpha", display_name="A's thread")
|
|
||||||
with _as_user(USER_B):
|
|
||||||
await store.create("t-beta", display_name="B's thread")
|
|
||||||
|
|
||||||
with _as_user(USER_A):
|
|
||||||
results = await store.search()
|
|
||||||
assert [r["thread_id"] for r in results] == ["t-alpha"]
|
|
||||||
|
|
||||||
with _as_user(USER_B):
|
|
||||||
results = await store.search()
|
|
||||||
assert [r["thread_id"] for r in results] == ["t-beta"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_get_isolation(store):
|
|
||||||
"""get() returns None for threads owned by another user."""
|
|
||||||
with _as_user(USER_A):
|
|
||||||
await store.create("t-alpha", display_name="A's thread")
|
|
||||||
|
|
||||||
with _as_user(USER_B):
|
|
||||||
assert await store.get("t-alpha") is None
|
|
||||||
|
|
||||||
with _as_user(USER_A):
|
|
||||||
result = await store.get("t-alpha")
|
|
||||||
assert result is not None
|
|
||||||
assert result["display_name"] == "A's thread"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_update_display_name_denied(store):
|
|
||||||
"""User B cannot rename User A's thread."""
|
|
||||||
with _as_user(USER_A):
|
|
||||||
await store.create("t-alpha", display_name="original")
|
|
||||||
|
|
||||||
with _as_user(USER_B):
|
|
||||||
await store.update_display_name("t-alpha", "hacked")
|
|
||||||
|
|
||||||
with _as_user(USER_A):
|
|
||||||
row = await store.get("t-alpha")
|
|
||||||
assert row is not None
|
|
||||||
assert row["display_name"] == "original"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_update_status_denied(store):
|
|
||||||
"""User B cannot change status of User A's thread."""
|
|
||||||
with _as_user(USER_A):
|
|
||||||
await store.create("t-alpha")
|
|
||||||
|
|
||||||
with _as_user(USER_B):
|
|
||||||
await store.update_status("t-alpha", "error")
|
|
||||||
|
|
||||||
with _as_user(USER_A):
|
|
||||||
row = await store.get("t-alpha")
|
|
||||||
assert row is not None
|
|
||||||
assert row["status"] == "idle"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_update_metadata_denied(store):
|
|
||||||
"""User B cannot modify metadata of User A's thread."""
|
|
||||||
with _as_user(USER_A):
|
|
||||||
await store.create("t-alpha", metadata={"key": "original"})
|
|
||||||
|
|
||||||
with _as_user(USER_B):
|
|
||||||
await store.update_metadata("t-alpha", {"key": "hacked"})
|
|
||||||
|
|
||||||
with _as_user(USER_A):
|
|
||||||
row = await store.get("t-alpha")
|
|
||||||
assert row is not None
|
|
||||||
assert row["metadata"]["key"] == "original"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_delete_denied(store):
|
|
||||||
"""User B cannot delete User A's thread."""
|
|
||||||
with _as_user(USER_A):
|
|
||||||
await store.create("t-alpha")
|
|
||||||
|
|
||||||
with _as_user(USER_B):
|
|
||||||
await store.delete("t-alpha")
|
|
||||||
|
|
||||||
with _as_user(USER_A):
|
|
||||||
row = await store.get("t-alpha")
|
|
||||||
assert row is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_no_context_raises(store):
|
|
||||||
"""Calling methods without user context raises RuntimeError."""
|
|
||||||
with pytest.raises(RuntimeError, match="no user context is set"):
|
|
||||||
await store.search()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.no_auto_user
|
|
||||||
async def test_explicit_none_bypasses_filter(store):
|
|
||||||
"""user_id=None bypasses isolation (migration/CLI escape hatch)."""
|
|
||||||
with _as_user(USER_A):
|
|
||||||
await store.create("t-alpha")
|
|
||||||
with _as_user(USER_B):
|
|
||||||
await store.create("t-beta")
|
|
||||||
|
|
||||||
all_rows = await store.search(user_id=None)
|
|
||||||
assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"}
|
|
||||||
|
|
||||||
row = await store.get("t-alpha", user_id=None)
|
|
||||||
assert row is not None
|
|
||||||
@@ -301,8 +301,8 @@ def test_import_memory_data_saves_and_returns_imported_memory() -> None:
|
|||||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
||||||
result = import_memory_data(imported_memory)
|
result = import_memory_data(imported_memory)
|
||||||
|
|
||||||
mock_storage.save.assert_called_once_with(imported_memory, None, user_id=None)
|
mock_storage.save.assert_called_once_with(imported_memory, None)
|
||||||
mock_storage.load.assert_called_once_with(None, user_id=None)
|
mock_storage.load.assert_called_once_with(None)
|
||||||
assert result == imported_memory
|
assert result == imported_memory
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
"""Tests for user_id propagation in memory updater."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from deerflow.agents.memory.updater import get_memory_data, clear_memory_data, _save_memory_to_file
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_memory_data_passes_user_id():
|
|
||||||
mock_storage = MagicMock()
|
|
||||||
mock_storage.load.return_value = {"version": "1.0"}
|
|
||||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
|
||||||
get_memory_data(user_id="alice")
|
|
||||||
mock_storage.load.assert_called_once_with(None, user_id="alice")
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_memory_passes_user_id():
|
|
||||||
mock_storage = MagicMock()
|
|
||||||
mock_storage.save.return_value = True
|
|
||||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
|
||||||
_save_memory_to_file({"version": "1.0"}, user_id="bob")
|
|
||||||
mock_storage.save.assert_called_once_with({"version": "1.0"}, None, user_id="bob")
|
|
||||||
|
|
||||||
|
|
||||||
def test_clear_memory_data_passes_user_id():
|
|
||||||
mock_storage = MagicMock()
|
|
||||||
mock_storage.save.return_value = True
|
|
||||||
with patch("deerflow.agents.memory.updater.get_memory_storage", return_value=mock_storage):
|
|
||||||
clear_memory_data(user_id="charlie")
|
|
||||||
# Verify save was called with user_id
|
|
||||||
assert mock_storage.save.call_args.kwargs["user_id"] == "charlie"
|
|
||||||
@@ -1,116 +0,0 @@
|
|||||||
"""Tests for per-user data migration."""
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def base_dir(tmp_path: Path) -> Path:
|
|
||||||
return tmp_path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def paths(base_dir: Path) -> Paths:
|
|
||||||
return Paths(base_dir)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMigrateThreadDirs:
|
|
||||||
def test_moves_thread_to_user_dir(self, base_dir: Path, paths: Paths):
|
|
||||||
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
|
||||||
legacy.mkdir(parents=True)
|
|
||||||
(legacy / "file.txt").write_text("hello")
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
|
||||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
|
||||||
|
|
||||||
expected = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace" / "file.txt"
|
|
||||||
assert expected.exists()
|
|
||||||
assert expected.read_text() == "hello"
|
|
||||||
assert not (base_dir / "threads" / "t1").exists()
|
|
||||||
|
|
||||||
def test_unowned_thread_goes_to_default(self, base_dir: Path, paths: Paths):
|
|
||||||
legacy = base_dir / "threads" / "t2" / "user-data" / "workspace"
|
|
||||||
legacy.mkdir(parents=True)
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
|
||||||
migrate_thread_dirs(paths, thread_owner_map={})
|
|
||||||
|
|
||||||
expected = base_dir / "users" / "default" / "threads" / "t2"
|
|
||||||
assert expected.exists()
|
|
||||||
|
|
||||||
def test_idempotent_skip_already_migrated(self, base_dir: Path, paths: Paths):
|
|
||||||
new_dir = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
|
||||||
new_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
|
||||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
|
||||||
assert new_dir.exists()
|
|
||||||
|
|
||||||
def test_conflict_preserved(self, base_dir: Path, paths: Paths):
|
|
||||||
legacy = base_dir / "threads" / "t1" / "user-data" / "workspace"
|
|
||||||
legacy.mkdir(parents=True)
|
|
||||||
(legacy / "old.txt").write_text("old")
|
|
||||||
|
|
||||||
dest = base_dir / "users" / "alice" / "threads" / "t1" / "user-data" / "workspace"
|
|
||||||
dest.mkdir(parents=True)
|
|
||||||
(dest / "new.txt").write_text("new")
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
|
||||||
migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"})
|
|
||||||
|
|
||||||
assert (dest / "new.txt").read_text() == "new"
|
|
||||||
conflicts = base_dir / "migration-conflicts" / "t1"
|
|
||||||
assert conflicts.exists()
|
|
||||||
|
|
||||||
def test_cleans_up_empty_legacy_dir(self, base_dir: Path, paths: Paths):
|
|
||||||
legacy = base_dir / "threads" / "t1" / "user-data"
|
|
||||||
legacy.mkdir(parents=True)
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
|
||||||
migrate_thread_dirs(paths, thread_owner_map={})
|
|
||||||
|
|
||||||
assert not (base_dir / "threads").exists()
|
|
||||||
|
|
||||||
def test_dry_run_does_not_move(self, base_dir: Path, paths: Paths):
|
|
||||||
legacy = base_dir / "threads" / "t1" / "user-data"
|
|
||||||
legacy.mkdir(parents=True)
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_thread_dirs
|
|
||||||
report = migrate_thread_dirs(paths, thread_owner_map={"t1": "alice"}, dry_run=True)
|
|
||||||
|
|
||||||
assert len(report) == 1
|
|
||||||
assert (base_dir / "threads" / "t1").exists() # not moved
|
|
||||||
assert not (base_dir / "users" / "alice" / "threads" / "t1").exists()
|
|
||||||
|
|
||||||
|
|
||||||
class TestMigrateMemory:
|
|
||||||
def test_moves_global_memory(self, base_dir: Path, paths: Paths):
|
|
||||||
legacy_mem = base_dir / "memory.json"
|
|
||||||
legacy_mem.write_text(json.dumps({"version": "1.0", "facts": []}))
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_memory
|
|
||||||
migrate_memory(paths, user_id="default")
|
|
||||||
|
|
||||||
expected = base_dir / "users" / "default" / "memory.json"
|
|
||||||
assert expected.exists()
|
|
||||||
assert not legacy_mem.exists()
|
|
||||||
|
|
||||||
def test_skips_if_destination_exists(self, base_dir: Path, paths: Paths):
|
|
||||||
legacy_mem = base_dir / "memory.json"
|
|
||||||
legacy_mem.write_text(json.dumps({"version": "old"}))
|
|
||||||
|
|
||||||
dest = base_dir / "users" / "default" / "memory.json"
|
|
||||||
dest.parent.mkdir(parents=True)
|
|
||||||
dest.write_text(json.dumps({"version": "new"}))
|
|
||||||
|
|
||||||
from scripts.migrate_user_isolation import migrate_memory
|
|
||||||
migrate_memory(paths, user_id="default")
|
|
||||||
|
|
||||||
assert json.loads(dest.read_text())["version"] == "new"
|
|
||||||
assert (base_dir / "memory.legacy.json").exists()
|
|
||||||
|
|
||||||
def test_no_legacy_memory_is_noop(self, base_dir: Path, paths: Paths):
|
|
||||||
from scripts.migrate_user_isolation import migrate_memory
|
|
||||||
migrate_memory(paths, user_id="default") # should not raise
|
|
||||||
@@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer
|
|||||||
owner filter directly by switching the ``user_context`` contextvar
|
owner filter directly by switching the ``user_context`` contextvar
|
||||||
between two users. The safety property under test is:
|
between two users. The safety property under test is:
|
||||||
|
|
||||||
After a repository write with user_id=A, a subsequent read with
|
After a repository write with owner_id=A, a subsequent read with
|
||||||
user_id=B must not return the row, and vice versa.
|
owner_id=B must not return the row, and vice versa.
|
||||||
|
|
||||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||||
that a request cookie reaches the ``set_current_user`` call. Together
|
that a request cookie reaches the ``set_current_user`` call. Together
|
||||||
@@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path):
|
|||||||
await cleanup()
|
await cleanup()
|
||||||
|
|
||||||
|
|
||||||
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ──
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.no_auto_user
|
@pytest.mark.no_auto_user
|
||||||
async def test_explicit_none_bypasses_filter(tmp_path):
|
async def test_explicit_none_bypasses_filter(tmp_path):
|
||||||
"""Migration scripts pass user_id=None to see all rows regardless of owner."""
|
"""Migration scripts pass owner_id=None to see all rows regardless of owner."""
|
||||||
from deerflow.persistence.engine import get_session_factory
|
from deerflow.persistence.engine import get_session_factory
|
||||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||||
|
|
||||||
@@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path):
|
|||||||
await repo.create("t-beta")
|
await repo.create("t-beta")
|
||||||
|
|
||||||
# Migration-style read: no contextvar, explicit None bypass.
|
# Migration-style read: no contextvar, explicit None bypass.
|
||||||
all_rows = await repo.search(user_id=None)
|
all_rows = await repo.search(owner_id=None)
|
||||||
thread_ids = {r["thread_id"] for r in all_rows}
|
thread_ids = {r["thread_id"] for r in all_rows}
|
||||||
assert thread_ids == {"t-alpha", "t-beta"}
|
assert thread_ids == {"t-alpha", "t-beta"}
|
||||||
|
|
||||||
# Explicit get with None does not apply the filter either.
|
# Explicit get with None does not apply the filter either.
|
||||||
row_a = await repo.get("t-alpha", user_id=None)
|
row_a = await repo.get("t-alpha", owner_id=None)
|
||||||
assert row_a is not None
|
assert row_a is not None
|
||||||
row_b = await repo.get("t-beta", user_id=None)
|
row_b = await repo.get("t-beta", owner_id=None)
|
||||||
assert row_b is not None
|
assert row_b is not None
|
||||||
finally:
|
finally:
|
||||||
await cleanup()
|
await cleanup()
|
||||||
|
|||||||
@@ -1,167 +0,0 @@
|
|||||||
"""Tests for user-scoped path resolution in Paths."""
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from deerflow.config.paths import Paths
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def paths(tmp_path: Path) -> Paths:
|
|
||||||
return Paths(tmp_path)
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidateUserId:
|
|
||||||
def test_valid_user_id(self, paths: Paths):
|
|
||||||
d = paths.user_dir("u-abc-123")
|
|
||||||
assert d == paths.base_dir / "users" / "u-abc-123"
|
|
||||||
|
|
||||||
def test_rejects_path_traversal(self, paths: Paths):
|
|
||||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
|
||||||
paths.user_dir("../escape")
|
|
||||||
|
|
||||||
def test_rejects_slash(self, paths: Paths):
|
|
||||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
|
||||||
paths.user_dir("foo/bar")
|
|
||||||
|
|
||||||
def test_rejects_empty(self, paths: Paths):
|
|
||||||
with pytest.raises(ValueError, match="Invalid user_id"):
|
|
||||||
paths.user_dir("")
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserDir:
|
|
||||||
def test_user_dir(self, paths: Paths):
|
|
||||||
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserMemoryFile:
|
|
||||||
def test_user_memory_file(self, paths: Paths):
|
|
||||||
assert paths.user_memory_file("bob") == paths.base_dir / "users" / "bob" / "memory.json"
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserAgentMemoryFile:
|
|
||||||
def test_user_agent_memory_file(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
|
||||||
assert paths.user_agent_memory_file("bob", "myagent") == expected
|
|
||||||
|
|
||||||
def test_user_agent_memory_file_lowercases_name(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "bob" / "agents" / "myagent" / "memory.json"
|
|
||||||
assert paths.user_agent_memory_file("bob", "MyAgent") == expected
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserThreadDir:
|
|
||||||
def test_user_thread_dir(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1"
|
|
||||||
assert paths.thread_dir("t1", user_id="u1") == expected
|
|
||||||
|
|
||||||
def test_thread_dir_no_user_id_falls_back_to_legacy(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "threads" / "t1"
|
|
||||||
assert paths.thread_dir("t1") == expected
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserSandboxDirs:
|
|
||||||
def test_sandbox_work_dir(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "workspace"
|
|
||||||
assert paths.sandbox_work_dir("t1", user_id="u1") == expected
|
|
||||||
|
|
||||||
def test_sandbox_uploads_dir(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "uploads"
|
|
||||||
assert paths.sandbox_uploads_dir("t1", user_id="u1") == expected
|
|
||||||
|
|
||||||
def test_sandbox_outputs_dir(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data" / "outputs"
|
|
||||||
assert paths.sandbox_outputs_dir("t1", user_id="u1") == expected
|
|
||||||
|
|
||||||
def test_sandbox_user_data_dir(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "user-data"
|
|
||||||
assert paths.sandbox_user_data_dir("t1", user_id="u1") == expected
|
|
||||||
|
|
||||||
def test_acp_workspace_dir(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "users" / "u1" / "threads" / "t1" / "acp-workspace"
|
|
||||||
assert paths.acp_workspace_dir("t1", user_id="u1") == expected
|
|
||||||
|
|
||||||
def test_legacy_sandbox_work_dir(self, paths: Paths):
|
|
||||||
expected = paths.base_dir / "threads" / "t1" / "user-data" / "workspace"
|
|
||||||
assert paths.sandbox_work_dir("t1") == expected
|
|
||||||
|
|
||||||
|
|
||||||
class TestHostPathsWithUserId:
|
|
||||||
def test_host_thread_dir_with_user_id(self, paths: Paths):
|
|
||||||
result = paths.host_thread_dir("t1", user_id="u1")
|
|
||||||
assert "users" in result
|
|
||||||
assert "u1" in result
|
|
||||||
assert "threads" in result
|
|
||||||
assert "t1" in result
|
|
||||||
|
|
||||||
def test_host_thread_dir_legacy(self, paths: Paths):
|
|
||||||
result = paths.host_thread_dir("t1")
|
|
||||||
assert "threads" in result
|
|
||||||
assert "t1" in result
|
|
||||||
assert "users" not in result
|
|
||||||
|
|
||||||
def test_host_sandbox_user_data_dir_with_user_id(self, paths: Paths):
|
|
||||||
result = paths.host_sandbox_user_data_dir("t1", user_id="u1")
|
|
||||||
assert "users" in result
|
|
||||||
assert "user-data" in result
|
|
||||||
|
|
||||||
def test_host_sandbox_work_dir_with_user_id(self, paths: Paths):
|
|
||||||
result = paths.host_sandbox_work_dir("t1", user_id="u1")
|
|
||||||
assert "workspace" in result
|
|
||||||
|
|
||||||
def test_host_sandbox_uploads_dir_with_user_id(self, paths: Paths):
|
|
||||||
result = paths.host_sandbox_uploads_dir("t1", user_id="u1")
|
|
||||||
assert "uploads" in result
|
|
||||||
|
|
||||||
def test_host_sandbox_outputs_dir_with_user_id(self, paths: Paths):
|
|
||||||
result = paths.host_sandbox_outputs_dir("t1", user_id="u1")
|
|
||||||
assert "outputs" in result
|
|
||||||
|
|
||||||
def test_host_acp_workspace_dir_with_user_id(self, paths: Paths):
|
|
||||||
result = paths.host_acp_workspace_dir("t1", user_id="u1")
|
|
||||||
assert "acp-workspace" in result
|
|
||||||
|
|
||||||
|
|
||||||
class TestEnsureAndDeleteWithUserId:
|
|
||||||
def test_ensure_thread_dirs_creates_user_scoped(self, paths: Paths):
|
|
||||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
|
||||||
assert paths.sandbox_work_dir("t1", user_id="u1").is_dir()
|
|
||||||
assert paths.sandbox_uploads_dir("t1", user_id="u1").is_dir()
|
|
||||||
assert paths.sandbox_outputs_dir("t1", user_id="u1").is_dir()
|
|
||||||
assert paths.acp_workspace_dir("t1", user_id="u1").is_dir()
|
|
||||||
|
|
||||||
def test_delete_thread_dir_removes_user_scoped(self, paths: Paths):
|
|
||||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
|
||||||
assert paths.thread_dir("t1", user_id="u1").exists()
|
|
||||||
paths.delete_thread_dir("t1", user_id="u1")
|
|
||||||
assert not paths.thread_dir("t1", user_id="u1").exists()
|
|
||||||
|
|
||||||
def test_delete_thread_dir_idempotent(self, paths: Paths):
|
|
||||||
paths.delete_thread_dir("nonexistent", user_id="u1") # should not raise
|
|
||||||
|
|
||||||
def test_ensure_thread_dirs_legacy_still_works(self, paths: Paths):
|
|
||||||
paths.ensure_thread_dirs("t1")
|
|
||||||
assert paths.sandbox_work_dir("t1").is_dir()
|
|
||||||
|
|
||||||
def test_user_scoped_and_legacy_are_independent(self, paths: Paths):
|
|
||||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
|
||||||
paths.ensure_thread_dirs("t1")
|
|
||||||
# Both exist independently
|
|
||||||
assert paths.thread_dir("t1", user_id="u1").exists()
|
|
||||||
assert paths.thread_dir("t1").exists()
|
|
||||||
# Delete one doesn't affect the other
|
|
||||||
paths.delete_thread_dir("t1", user_id="u1")
|
|
||||||
assert not paths.thread_dir("t1", user_id="u1").exists()
|
|
||||||
assert paths.thread_dir("t1").exists()
|
|
||||||
|
|
||||||
|
|
||||||
class TestResolveVirtualPathWithUserId:
|
|
||||||
def test_resolve_virtual_path_with_user_id(self, paths: Paths):
|
|
||||||
paths.ensure_thread_dirs("t1", user_id="u1")
|
|
||||||
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt", user_id="u1")
|
|
||||||
expected_base = paths.sandbox_user_data_dir("t1", user_id="u1").resolve()
|
|
||||||
assert str(result).startswith(str(expected_base))
|
|
||||||
|
|
||||||
def test_resolve_virtual_path_legacy(self, paths: Paths):
|
|
||||||
paths.ensure_thread_dirs("t1")
|
|
||||||
result = paths.resolve_virtual_path("t1", "/mnt/user-data/workspace/file.txt")
|
|
||||||
expected_base = paths.sandbox_user_data_dir("t1").resolve()
|
|
||||||
assert str(result).startswith(str(expected_base))
|
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Tests:
|
Tests:
|
||||||
1. DatabaseConfig property derivation (paths, URLs)
|
1. DatabaseConfig property derivation (paths, URLs)
|
||||||
2. MemoryRunStore CRUD + user_id filtering
|
2. MemoryRunStore CRUD + owner_id filtering
|
||||||
3. Base.to_dict() via inspect mixin
|
3. Base.to_dict() via inspect mixin
|
||||||
4. Engine init/close lifecycle (memory + SQLite)
|
4. Engine init/close lifecycle (memory + SQLite)
|
||||||
5. Postgres missing-dep error message
|
5. Postgres missing-dep error message
|
||||||
@@ -24,19 +24,18 @@ class TestDatabaseConfig:
|
|||||||
assert c.backend == "memory"
|
assert c.backend == "memory"
|
||||||
assert c.pool_size == 5
|
assert c.pool_size == 5
|
||||||
|
|
||||||
def test_sqlite_paths_unified(self):
|
def test_sqlite_paths_are_different(self):
|
||||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
|
c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata")
|
||||||
assert c.sqlite_path.endswith("deerflow.db")
|
assert c.checkpointer_sqlite_path.endswith("checkpoints.db")
|
||||||
assert "mydata" in c.sqlite_path
|
assert c.app_sqlite_path.endswith("app.db")
|
||||||
# Backward-compatible aliases point to the same file
|
assert "mydata" in c.checkpointer_sqlite_path
|
||||||
assert c.checkpointer_sqlite_path == c.sqlite_path
|
assert c.checkpointer_sqlite_path != c.app_sqlite_path
|
||||||
assert c.app_sqlite_path == c.sqlite_path
|
|
||||||
|
|
||||||
def test_app_sqlalchemy_url_sqlite(self):
|
def test_app_sqlalchemy_url_sqlite(self):
|
||||||
c = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
|
c = DatabaseConfig(backend="sqlite", sqlite_dir="./data")
|
||||||
url = c.app_sqlalchemy_url
|
url = c.app_sqlalchemy_url
|
||||||
assert url.startswith("sqlite+aiosqlite:///")
|
assert url.startswith("sqlite+aiosqlite:///")
|
||||||
assert "deerflow.db" in url
|
assert "app.db" in url
|
||||||
|
|
||||||
def test_app_sqlalchemy_url_postgres(self):
|
def test_app_sqlalchemy_url_postgres(self):
|
||||||
c = DatabaseConfig(
|
c = DatabaseConfig(
|
||||||
@@ -106,17 +105,17 @@ class TestMemoryRunStore:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread_owner_filter(self, store):
|
async def test_list_by_thread_owner_filter(self, store):
|
||||||
await store.put("r1", thread_id="t1", user_id="alice")
|
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||||
await store.put("r2", thread_id="t1", user_id="bob")
|
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||||
rows = await store.list_by_thread("t1", user_id="alice")
|
rows = await store.list_by_thread("t1", owner_id="alice")
|
||||||
assert len(rows) == 1
|
assert len(rows) == 1
|
||||||
assert rows[0]["user_id"] == "alice"
|
assert rows[0]["owner_id"] == "alice"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_owner_none_returns_all(self, store):
|
async def test_owner_none_returns_all(self, store):
|
||||||
await store.put("r1", thread_id="t1", user_id="alice")
|
await store.put("r1", thread_id="t1", owner_id="alice")
|
||||||
await store.put("r2", thread_id="t1", user_id="bob")
|
await store.put("r2", thread_id="t1", owner_id="bob")
|
||||||
rows = await store.list_by_thread("t1", user_id=None)
|
rows = await store.list_by_thread("t1", owner_id=None)
|
||||||
assert len(rows) == 2
|
assert len(rows) == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def test_present_files_keeps_virtual_outputs_path(tmp_path, monkeypatch):
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
present_file_tool_module,
|
present_file_tool_module,
|
||||||
"get_paths",
|
"get_paths",
|
||||||
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path, *, user_id=None: artifact_path),
|
lambda: SimpleNamespace(resolve_virtual_path=lambda thread_id, path: artifact_path),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = present_file_tool_module.present_file_tool.func(
|
result = present_file_tool_module.present_file_tool.func(
|
||||||
|
|||||||
@@ -1,107 +0,0 @@
|
|||||||
"""Tests for paginated list_messages_by_run across all RunEventStore backends."""
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def base_store():
|
|
||||||
return MemoryRunEventStore()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_messages_by_run_default_returns_all(base_store):
|
|
||||||
store = base_store
|
|
||||||
for i in range(7):
|
|
||||||
await store.put(
|
|
||||||
thread_id="t1", run_id="run-a",
|
|
||||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
|
||||||
category="message", content=f"msg-a-{i}",
|
|
||||||
)
|
|
||||||
for i in range(3):
|
|
||||||
await store.put(
|
|
||||||
thread_id="t1", run_id="run-b",
|
|
||||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
|
||||||
)
|
|
||||||
await store.put(thread_id="t1", run_id="run-a", event_type="tool_call", category="trace", content="trace")
|
|
||||||
|
|
||||||
msgs = await store.list_messages_by_run("t1", "run-a")
|
|
||||||
assert len(msgs) == 7
|
|
||||||
assert all(m["category"] == "message" for m in msgs)
|
|
||||||
assert all(m["run_id"] == "run-a" for m in msgs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_messages_by_run_with_limit(base_store):
|
|
||||||
store = base_store
|
|
||||||
for i in range(7):
|
|
||||||
await store.put(
|
|
||||||
thread_id="t1", run_id="run-a",
|
|
||||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
|
||||||
category="message", content=f"msg-a-{i}",
|
|
||||||
)
|
|
||||||
|
|
||||||
msgs = await store.list_messages_by_run("t1", "run-a", limit=3)
|
|
||||||
assert len(msgs) == 3
|
|
||||||
seqs = [m["seq"] for m in msgs]
|
|
||||||
assert seqs == sorted(seqs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_messages_by_run_after_seq(base_store):
|
|
||||||
store = base_store
|
|
||||||
for i in range(7):
|
|
||||||
await store.put(
|
|
||||||
thread_id="t1", run_id="run-a",
|
|
||||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
|
||||||
category="message", content=f"msg-a-{i}",
|
|
||||||
)
|
|
||||||
|
|
||||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
|
||||||
cursor_seq = all_msgs[2]["seq"]
|
|
||||||
msgs = await store.list_messages_by_run("t1", "run-a", after_seq=cursor_seq, limit=50)
|
|
||||||
assert all(m["seq"] > cursor_seq for m in msgs)
|
|
||||||
assert len(msgs) == 4
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_messages_by_run_before_seq(base_store):
|
|
||||||
store = base_store
|
|
||||||
for i in range(7):
|
|
||||||
await store.put(
|
|
||||||
thread_id="t1", run_id="run-a",
|
|
||||||
event_type="human_message" if i % 2 == 0 else "ai_message",
|
|
||||||
category="message", content=f"msg-a-{i}",
|
|
||||||
)
|
|
||||||
|
|
||||||
all_msgs = await store.list_messages_by_run("t1", "run-a")
|
|
||||||
cursor_seq = all_msgs[4]["seq"]
|
|
||||||
msgs = await store.list_messages_by_run("t1", "run-a", before_seq=cursor_seq, limit=50)
|
|
||||||
assert all(m["seq"] < cursor_seq for m in msgs)
|
|
||||||
assert len(msgs) == 4
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_messages_by_run_does_not_include_other_run(base_store):
|
|
||||||
store = base_store
|
|
||||||
for i in range(7):
|
|
||||||
await store.put(
|
|
||||||
thread_id="t1", run_id="run-a",
|
|
||||||
event_type="human_message", category="message", content=f"msg-a-{i}",
|
|
||||||
)
|
|
||||||
for i in range(3):
|
|
||||||
await store.put(
|
|
||||||
thread_id="t1", run_id="run-b",
|
|
||||||
event_type="human_message", category="message", content=f"msg-b-{i}",
|
|
||||||
)
|
|
||||||
|
|
||||||
msgs = await store.list_messages_by_run("t1", "run-b")
|
|
||||||
assert len(msgs) == 3
|
|
||||||
assert all(m["run_id"] == "run-b" for m in msgs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_messages_by_run_empty_run(base_store):
|
|
||||||
store = base_store
|
|
||||||
msgs = await store.list_messages_by_run("t1", "nonexistent")
|
|
||||||
assert msgs == []
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user