Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0f82f8a3a2 | |||
| a0ab3a3dd4 | |||
| 4fa2c15613 | |||
| 892a06fe98 | |||
| b5e18f5b47 | |||
| e3e00af51d | |||
| 5f2f1941e9 | |||
| 9d0a42c1fb | |||
| 39a575617b | |||
| 274255b1a5 | |||
| 14892e1463 | |||
| 37fd8b0d7a | |||
| 2fe0856e33 | |||
| 38a6ec496f | |||
| 3a99c4e81c | |||
| 7b9d224b3a | |||
| 0572ef44b9 | |||
| 839563f308 | |||
| 62bdfe3abc | |||
| b61ce3527b | |||
| 2d5f6f1b3d | |||
| 69bf3dafd8 | |||
| 6cbec13495 | |||
| 31e5b586a1 | |||
| e75a2ff29a | |||
| 185f5649dd |
@@ -33,5 +33,9 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
|||||||
|
|
||||||
# GitHub API Token
|
# GitHub API Token
|
||||||
# GITHUB_TOKEN=your-github-token
|
# GITHUB_TOKEN=your-github-token
|
||||||
|
|
||||||
|
# Database (only needed when config.yaml has database.backend: postgres)
|
||||||
|
# DATABASE_URL=postgresql://deerflow:password@localhost:5432/deerflow
|
||||||
|
#
|
||||||
# WECOM_BOT_ID=your-wecom-bot-id
|
# WECOM_BOT_ID=your-wecom-bot-id
|
||||||
# WECOM_BOT_SECRET=your-wecom-bot-secret
|
# WECOM_BOT_SECRET=your-wecom-bot-secret
|
||||||
|
|||||||
+20
-8
@@ -158,7 +158,7 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
||||||
|
|
||||||
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
|
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
|
||||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||||
@@ -216,6 +216,9 @@ FastAPI application on port 8001 with health check at `GET /health`.
|
|||||||
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
|
||||||
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
|
||||||
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
|
||||||
|
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
|
||||||
|
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
|
||||||
|
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
|
||||||
|
|
||||||
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
|
||||||
|
|
||||||
@@ -229,7 +232,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
|||||||
|
|
||||||
**Virtual Path System**:
|
**Virtual Path System**:
|
||||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||||
- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
|
||||||
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
|
||||||
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
|
||||||
|
|
||||||
@@ -269,7 +272,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
|||||||
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
|
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
|
||||||
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
|
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
|
||||||
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
|
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
|
||||||
- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
|
||||||
- `image_search/` - Image search via DuckDuckGo
|
- `image_search/` - Image search via DuckDuckGo
|
||||||
|
|
||||||
### MCP System (`packages/harness/deerflow/mcp/`)
|
### MCP System (`packages/harness/deerflow/mcp/`)
|
||||||
@@ -338,18 +341,27 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
|
|||||||
|
|
||||||
**Components**:
|
**Components**:
|
||||||
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
|
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
|
||||||
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
|
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
|
||||||
- `prompt.py` - Prompt templates for memory updates
|
- `prompt.py` - Prompt templates for memory updates
|
||||||
|
- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
|
||||||
|
|
||||||
**Data Structure** (stored in `backend/.deer-flow/memory.json`):
|
**Per-User Isolation**:
|
||||||
|
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
|
||||||
|
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
|
||||||
|
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
|
||||||
|
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
|
||||||
|
- Absolute `storage_path` in config opts out of per-user isolation
|
||||||
|
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
|
||||||
|
|
||||||
|
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
|
||||||
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
|
||||||
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
|
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
|
||||||
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
|
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
|
||||||
|
|
||||||
**Workflow**:
|
**Workflow**:
|
||||||
1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
|
1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
|
||||||
2. Queue debounces (30s default), batches updates, deduplicates per-thread
|
2. Queue debounces (30s default), batches updates, deduplicates per-thread
|
||||||
3. Background thread invokes LLM to extract context updates and facts
|
3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
|
||||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||||
|
|
||||||
@@ -357,7 +369,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
|||||||
|
|
||||||
**Configuration** (`config.yaml` → `memory`):
|
**Configuration** (`config.yaml` → `memory`):
|
||||||
- `enabled` / `injection_enabled` - Master switches
|
- `enabled` / `injection_enabled` - Master switches
|
||||||
- `storage_path` - Path to memory.json
|
- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
|
||||||
- `debounce_seconds` - Wait time before processing (default: 30)
|
- `debounce_seconds` - Wait time before processing (default: 30)
|
||||||
- `model_name` - LLM for updates (null = default model)
|
- `model_name` - LLM for updates (null = default model)
|
||||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||||
|
|||||||
+5
-1
@@ -13,6 +13,9 @@ FROM python:3.12-slim-bookworm AS builder
|
|||||||
ARG NODE_MAJOR=22
|
ARG NODE_MAJOR=22
|
||||||
ARG APT_MIRROR
|
ARG APT_MIRROR
|
||||||
ARG UV_INDEX_URL
|
ARG UV_INDEX_URL
|
||||||
|
# Optional extras to install (e.g. "postgres" for PostgreSQL support)
|
||||||
|
# Usage: docker build --build-arg UV_EXTRAS=postgres ...
|
||||||
|
ARG UV_EXTRAS
|
||||||
|
|
||||||
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
|
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
|
||||||
RUN if [ -n "${APT_MIRROR}" ]; then \
|
RUN if [ -n "${APT_MIRROR}" ]; then \
|
||||||
@@ -43,8 +46,9 @@ WORKDIR /app
|
|||||||
COPY backend ./backend
|
COPY backend ./backend
|
||||||
|
|
||||||
# Install dependencies with cache mount
|
# Install dependencies with cache mount
|
||||||
|
# When UV_EXTRAS is set (e.g. "postgres"), installs optional dependencies.
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
|
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
|
||||||
|
|
||||||
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
|
||||||
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
# Retains compiler toolchain from builder so startup-time `uv sync` can build
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ gateway:
|
|||||||
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||||
|
|
||||||
test:
|
test:
|
||||||
PYTHONPATH=. uv run pytest tests/ -v
|
PYTHONPATH=. uv run pytest tests/unittest -v
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
uvx ruff check .
|
uvx ruff check .
|
||||||
|
|||||||
@@ -9,10 +9,12 @@ import re
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_user_actor_context
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -297,15 +299,35 @@ class FeishuChannel(Channel):
|
|||||||
text = msg.text
|
text = msg.text
|
||||||
for file in files:
|
for file in files:
|
||||||
if file.get("image_key"):
|
if file.get("image_key"):
|
||||||
virtual_path = await self._receive_single_file(msg.thread_ts, file["image_key"], "image", thread_id)
|
virtual_path = await self._receive_single_file(
|
||||||
|
msg.thread_ts,
|
||||||
|
file["image_key"],
|
||||||
|
"image",
|
||||||
|
thread_id,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
text = text.replace("[image]", virtual_path, 1)
|
text = text.replace("[image]", virtual_path, 1)
|
||||||
elif file.get("file_key"):
|
elif file.get("file_key"):
|
||||||
virtual_path = await self._receive_single_file(msg.thread_ts, file["file_key"], "file", thread_id)
|
virtual_path = await self._receive_single_file(
|
||||||
|
msg.thread_ts,
|
||||||
|
file["file_key"],
|
||||||
|
"file",
|
||||||
|
thread_id,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
text = text.replace("[file]", virtual_path, 1)
|
text = text.replace("[file]", virtual_path, 1)
|
||||||
msg.text = text
|
msg.text = text
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
async def _receive_single_file(self, message_id: str, file_key: str, type: Literal["image", "file"], thread_id: str) -> str:
|
async def _receive_single_file(
|
||||||
|
self,
|
||||||
|
message_id: str,
|
||||||
|
file_key: str,
|
||||||
|
type: Literal["image", "file"],
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
request = self._GetMessageResourceRequest.builder().message_id(message_id).file_key(file_key).type(type).build()
|
request = self._GetMessageResourceRequest.builder().message_id(message_id).file_key(file_key).type(type).build()
|
||||||
|
|
||||||
def inner():
|
def inner():
|
||||||
@@ -344,49 +366,51 @@ class FeishuChannel(Channel):
|
|||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
paths.ensure_thread_dirs(thread_id)
|
with bind_user_actor_context(user_id):
|
||||||
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
|
effective_user_id = get_effective_user_id()
|
||||||
|
paths.ensure_thread_dirs(thread_id, user_id=effective_user_id)
|
||||||
|
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=effective_user_id).resolve()
|
||||||
|
|
||||||
ext = "png" if type == "image" else "bin"
|
ext = "png" if type == "image" else "bin"
|
||||||
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
|
||||||
|
|
||||||
# Sanitize filename: preserve extension, replace path chars in name part
|
# Sanitize filename: preserve extension, replace path chars in name part
|
||||||
if "." in raw_filename:
|
if "." in raw_filename:
|
||||||
name_part, ext = raw_filename.rsplit(".", 1)
|
name_part, ext = raw_filename.rsplit(".", 1)
|
||||||
name_part = re.sub(r"[./\\]", "_", name_part)
|
name_part = re.sub(r"[./\\]", "_", name_part)
|
||||||
filename = f"{name_part}.{ext}"
|
filename = f"{name_part}.{ext}"
|
||||||
else:
|
else:
|
||||||
filename = re.sub(r"[./\\]", "_", raw_filename)
|
filename = re.sub(r"[./\\]", "_", raw_filename)
|
||||||
resolved_target = uploads_dir / filename
|
resolved_target = uploads_dir / filename
|
||||||
|
|
||||||
def down_load():
|
def down_load():
|
||||||
# use thread_lock to avoid filename conflicts when writing
|
# use thread_lock to avoid filename conflicts when writing
|
||||||
with self._thread_lock:
|
with self._thread_lock:
|
||||||
resolved_target.write_bytes(content)
|
resolved_target.write_bytes(content)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(down_load)
|
await asyncio.to_thread(down_load)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
|
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
|
||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
|
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sandbox_provider = get_sandbox_provider()
|
sandbox_provider = get_sandbox_provider()
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
if sandbox_id != "local":
|
if sandbox_id != "local":
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
if sandbox is None:
|
if sandbox is None:
|
||||||
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
|
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
|
||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
sandbox.update_file(virtual_path, content)
|
sandbox.update_file(virtual_path, content)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
|
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
|
||||||
return f"Failed to obtain the [{type}]"
|
return f"Failed to obtain the [{type}]"
|
||||||
|
|
||||||
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
|
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
|
||||||
return virtual_path
|
return virtual_path
|
||||||
|
|
||||||
# -- message formatting ------------------------------------------------
|
# -- message formatting ------------------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,11 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
from langgraph_sdk.errors import ConflictError
|
from langgraph_sdk.errors import ConflictError
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_user_actor_context
|
||||||
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -327,7 +329,7 @@ def _format_artifact_text(artifacts: list[str]) -> str:
|
|||||||
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
||||||
|
|
||||||
|
|
||||||
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
|
def _resolve_attachments(thread_id: str, artifacts: list[str], *, user_id: str | None = None) -> list[ResolvedAttachment]:
|
||||||
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
||||||
|
|
||||||
Only paths under ``/mnt/user-data/outputs/`` are accepted; any other
|
Only paths under ``/mnt/user-data/outputs/`` are accepted; any other
|
||||||
@@ -341,38 +343,40 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
|
|||||||
|
|
||||||
attachments: list[ResolvedAttachment] = []
|
attachments: list[ResolvedAttachment] = []
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
|
with bind_user_actor_context(user_id):
|
||||||
for virtual_path in artifacts:
|
effective_user_id = get_effective_user_id()
|
||||||
# Security: only allow files from the agent outputs directory
|
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=effective_user_id).resolve()
|
||||||
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
for virtual_path in artifacts:
|
||||||
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
# Security: only allow files from the agent outputs directory
|
||||||
continue
|
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
|
||||||
try:
|
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
|
||||||
actual = paths.resolve_virtual_path(thread_id, virtual_path)
|
continue
|
||||||
# Verify the resolved path is actually under the outputs directory
|
|
||||||
# (guards against path-traversal even after prefix check)
|
|
||||||
try:
|
try:
|
||||||
actual.resolve().relative_to(outputs_dir)
|
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=effective_user_id)
|
||||||
except ValueError:
|
# Verify the resolved path is actually under the outputs directory
|
||||||
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
|
# (guards against path-traversal even after prefix check)
|
||||||
continue
|
try:
|
||||||
if not actual.is_file():
|
actual.resolve().relative_to(outputs_dir)
|
||||||
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
|
except ValueError:
|
||||||
continue
|
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
|
||||||
mime, _ = mimetypes.guess_type(str(actual))
|
continue
|
||||||
mime = mime or "application/octet-stream"
|
if not actual.is_file():
|
||||||
attachments.append(
|
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
|
||||||
ResolvedAttachment(
|
continue
|
||||||
virtual_path=virtual_path,
|
mime, _ = mimetypes.guess_type(str(actual))
|
||||||
actual_path=actual,
|
mime = mime or "application/octet-stream"
|
||||||
filename=actual.name,
|
attachments.append(
|
||||||
mime_type=mime,
|
ResolvedAttachment(
|
||||||
size=actual.stat().st_size,
|
virtual_path=virtual_path,
|
||||||
is_image=mime.startswith("image/"),
|
actual_path=actual,
|
||||||
|
filename=actual.name,
|
||||||
|
mime_type=mime,
|
||||||
|
size=actual.stat().st_size,
|
||||||
|
is_image=mime.startswith("image/"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
except (ValueError, OSError) as exc:
|
||||||
except (ValueError, OSError) as exc:
|
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
|
||||||
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
|
|
||||||
return attachments
|
return attachments
|
||||||
|
|
||||||
|
|
||||||
@@ -380,13 +384,15 @@ def _prepare_artifact_delivery(
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
response_text: str,
|
response_text: str,
|
||||||
artifacts: list[str],
|
artifacts: list[str],
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> tuple[str, list[ResolvedAttachment]]:
|
) -> tuple[str, list[ResolvedAttachment]]:
|
||||||
"""Resolve attachments and append filename fallbacks to the text response."""
|
"""Resolve attachments and append filename fallbacks to the text response."""
|
||||||
attachments: list[ResolvedAttachment] = []
|
attachments: list[ResolvedAttachment] = []
|
||||||
if not artifacts:
|
if not artifacts:
|
||||||
return response_text, attachments
|
return response_text, attachments
|
||||||
|
|
||||||
attachments = _resolve_attachments(thread_id, artifacts)
|
attachments = _resolve_attachments(thread_id, artifacts, user_id=user_id)
|
||||||
resolved_virtuals = {attachment.virtual_path for attachment in attachments}
|
resolved_virtuals = {attachment.virtual_path for attachment in attachments}
|
||||||
unresolved = [path for path in artifacts if path not in resolved_virtuals]
|
unresolved = [path for path in artifacts if path not in resolved_virtuals]
|
||||||
|
|
||||||
@@ -409,7 +415,8 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
|
|
||||||
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
|
||||||
|
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
with bind_user_actor_context(msg.user_id):
|
||||||
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||||
|
|
||||||
created: list[dict[str, Any]] = []
|
created: list[dict[str, Any]] = []
|
||||||
@@ -743,7 +750,12 @@ class ChannelManager:
|
|||||||
len(artifacts),
|
len(artifacts),
|
||||||
)
|
)
|
||||||
|
|
||||||
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
response_text, attachments = _prepare_artifact_delivery(
|
||||||
|
thread_id,
|
||||||
|
response_text,
|
||||||
|
artifacts,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not response_text:
|
if not response_text:
|
||||||
if attachments:
|
if attachments:
|
||||||
@@ -834,7 +846,12 @@ class ChannelManager:
|
|||||||
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
|
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
|
||||||
response_text = _extract_response_text(result)
|
response_text = _extract_response_text(result)
|
||||||
artifacts = _extract_artifacts(result)
|
artifacts = _extract_artifacts(result)
|
||||||
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
|
response_text, attachments = _prepare_artifact_delivery(
|
||||||
|
thread_id,
|
||||||
|
response_text,
|
||||||
|
artifacts,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not response_text:
|
if not response_text:
|
||||||
if attachments:
|
if attachments:
|
||||||
|
|||||||
@@ -1,4 +1,23 @@
|
|||||||
from .app import app, create_app
|
from __future__ import annotations
|
||||||
from .config import GatewayConfig, get_gateway_config
|
|
||||||
|
|
||||||
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]
|
__all__ = ["GatewayConfig", "app", "get_gateway_config", "register_app"]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
if name == "app":
|
||||||
|
from .app import app
|
||||||
|
|
||||||
|
return app
|
||||||
|
if name == "GatewayConfig":
|
||||||
|
from .config import GatewayConfig
|
||||||
|
|
||||||
|
return GatewayConfig
|
||||||
|
if name == "get_gateway_config":
|
||||||
|
from .config import get_gateway_config
|
||||||
|
|
||||||
|
return get_gateway_config
|
||||||
|
if name == "register_app":
|
||||||
|
from .registrar import register_app
|
||||||
|
|
||||||
|
return register_app
|
||||||
|
raise AttributeError(name)
|
||||||
|
|||||||
+4
-217
@@ -1,221 +1,8 @@
|
|||||||
import logging
|
from app.gateway.registrar import register_app
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
from app.gateway.config import get_gateway_config
|
|
||||||
from app.gateway.deps import langgraph_runtime
|
|
||||||
from app.gateway.routers import (
|
|
||||||
agents,
|
|
||||||
artifacts,
|
|
||||||
assistants_compat,
|
|
||||||
channels,
|
|
||||||
mcp,
|
|
||||||
memory,
|
|
||||||
models,
|
|
||||||
runs,
|
|
||||||
skills,
|
|
||||||
suggestions,
|
|
||||||
thread_runs,
|
|
||||||
threads,
|
|
||||||
uploads,
|
|
||||||
)
|
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
def create_app():
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
return register_app()
|
||||||
"""Application lifespan handler."""
|
|
||||||
|
|
||||||
# Load config and check necessary environment variables at startup
|
|
||||||
try:
|
|
||||||
get_app_config()
|
|
||||||
logger.info("Configuration loaded successfully")
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
|
||||||
logger.exception(error_msg)
|
|
||||||
raise RuntimeError(error_msg) from e
|
|
||||||
config = get_gateway_config()
|
|
||||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
|
||||||
|
|
||||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
|
||||||
async with langgraph_runtime(app):
|
|
||||||
logger.info("LangGraph runtime initialised")
|
|
||||||
|
|
||||||
# Start IM channel service if any channels are configured
|
|
||||||
try:
|
|
||||||
from app.channels.service import start_channel_service
|
|
||||||
|
|
||||||
channel_service = await start_channel_service()
|
|
||||||
logger.info("Channel service started: %s", channel_service.get_status())
|
|
||||||
except Exception:
|
|
||||||
logger.exception("No IM channels configured or channel service failed to start")
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
# Stop channel service on shutdown
|
|
||||||
try:
|
|
||||||
from app.channels.service import stop_channel_service
|
|
||||||
|
|
||||||
await stop_channel_service()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to stop channel service")
|
|
||||||
|
|
||||||
logger.info("Shutting down API Gateway")
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
app = register_app()
|
||||||
"""Create and configure the FastAPI application.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured FastAPI application instance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
app = FastAPI(
|
|
||||||
title="DeerFlow API Gateway",
|
|
||||||
description="""
|
|
||||||
## DeerFlow API Gateway
|
|
||||||
|
|
||||||
API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.
|
|
||||||
|
|
||||||
### Features
|
|
||||||
|
|
||||||
- **Models Management**: Query and retrieve available AI models
|
|
||||||
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
|
|
||||||
- **Memory Management**: Access and manage global memory data for personalized conversations
|
|
||||||
- **Skills Management**: Query and manage skills and their enabled status
|
|
||||||
- **Artifacts**: Access thread artifacts and generated files
|
|
||||||
- **Health Monitoring**: System health check endpoints
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
LangGraph requests are handled by nginx reverse proxy.
|
|
||||||
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
|
|
||||||
""",
|
|
||||||
version="0.1.0",
|
|
||||||
lifespan=lifespan,
|
|
||||||
docs_url="/docs",
|
|
||||||
redoc_url="/redoc",
|
|
||||||
openapi_url="/openapi.json",
|
|
||||||
openapi_tags=[
|
|
||||||
{
|
|
||||||
"name": "models",
|
|
||||||
"description": "Operations for querying available AI models and their configurations",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "mcp",
|
|
||||||
"description": "Manage Model Context Protocol (MCP) server configurations",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "memory",
|
|
||||||
"description": "Access and manage global memory data for personalized conversations",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "skills",
|
|
||||||
"description": "Manage skills and their configurations",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "artifacts",
|
|
||||||
"description": "Access and download thread artifacts and generated files",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "uploads",
|
|
||||||
"description": "Upload and manage user files for threads",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "threads",
|
|
||||||
"description": "Manage DeerFlow thread-local filesystem data",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "agents",
|
|
||||||
"description": "Create and manage custom agents with per-agent config and prompts",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "suggestions",
|
|
||||||
"description": "Generate follow-up question suggestions for conversations",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "channels",
|
|
||||||
"description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "assistants-compat",
|
|
||||||
"description": "LangGraph Platform-compatible assistants API (stub)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "runs",
|
|
||||||
"description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "health",
|
|
||||||
"description": "Health check and system status endpoints",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# CORS is handled by nginx - no need for FastAPI middleware
|
|
||||||
|
|
||||||
# Include routers
|
|
||||||
# Models API is mounted at /api/models
|
|
||||||
app.include_router(models.router)
|
|
||||||
|
|
||||||
# MCP API is mounted at /api/mcp
|
|
||||||
app.include_router(mcp.router)
|
|
||||||
|
|
||||||
# Memory API is mounted at /api/memory
|
|
||||||
app.include_router(memory.router)
|
|
||||||
|
|
||||||
# Skills API is mounted at /api/skills
|
|
||||||
app.include_router(skills.router)
|
|
||||||
|
|
||||||
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
|
|
||||||
app.include_router(artifacts.router)
|
|
||||||
|
|
||||||
# Uploads API is mounted at /api/threads/{thread_id}/uploads
|
|
||||||
app.include_router(uploads.router)
|
|
||||||
|
|
||||||
# Thread cleanup API is mounted at /api/threads/{thread_id}
|
|
||||||
app.include_router(threads.router)
|
|
||||||
|
|
||||||
# Agents API is mounted at /api/agents
|
|
||||||
app.include_router(agents.router)
|
|
||||||
|
|
||||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
|
||||||
app.include_router(suggestions.router)
|
|
||||||
|
|
||||||
# Channels API is mounted at /api/channels
|
|
||||||
app.include_router(channels.router)
|
|
||||||
|
|
||||||
# Assistants compatibility API (LangGraph Platform stub)
|
|
||||||
app.include_router(assistants_compat.router)
|
|
||||||
|
|
||||||
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
|
|
||||||
app.include_router(thread_runs.router)
|
|
||||||
|
|
||||||
# Stateless Runs API (stream/wait without a pre-existing thread)
|
|
||||||
app.include_router(runs.router)
|
|
||||||
|
|
||||||
@app.get("/health", tags=["health"])
|
|
||||||
async def health_check() -> dict:
|
|
||||||
"""Health check endpoint.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Service health status information.
|
|
||||||
"""
|
|
||||||
return {"status": "healthy", "service": "deer-flow-gateway"}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
# Create app instance for uvicorn
|
|
||||||
app = create_app()
|
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .lifespan import lifespan_manager
|
||||||
|
|
||||||
|
__all__ = ["lifespan_manager"]
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
LifespanFunc = Callable[[FastAPI], AbstractAsyncContextManager[dict[str, Any] | None]]
|
||||||
|
|
||||||
|
|
||||||
|
class LifespanManager:
|
||||||
|
"""FastAPI lifespan manager"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._lifespans: list[LifespanFunc] = []
|
||||||
|
|
||||||
|
def register(self, func: LifespanFunc) -> LifespanFunc:
|
||||||
|
"""
|
||||||
|
Register a lifespan hook.
|
||||||
|
|
||||||
|
:param func: lifespan hook
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if func not in self._lifespans:
|
||||||
|
self._lifespans.append(func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
def build(self) -> LifespanFunc:
|
||||||
|
"""
|
||||||
|
Build the combined lifespan hook.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def combined_lifespan(app: FastAPI): # noqa: ANN202
|
||||||
|
state: dict[str, Any] = {}
|
||||||
|
async with AsyncExitStack() as exit_stack:
|
||||||
|
for lifespan_fn in self._lifespans:
|
||||||
|
result = await exit_stack.enter_async_context(lifespan_fn(app))
|
||||||
|
if isinstance(result, dict):
|
||||||
|
state.update(result)
|
||||||
|
|
||||||
|
for key, value in state.items():
|
||||||
|
setattr(app.state, key, value)
|
||||||
|
|
||||||
|
yield state or None
|
||||||
|
|
||||||
|
return combined_lifespan
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton lifespan_manager instance
|
||||||
|
lifespan_manager = LifespanManager()
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
from app.gateway.dependencies.checkpointer import (
|
||||||
|
CurrentCheckpointer,
|
||||||
|
get_checkpointer,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.security.dependencies import (
|
||||||
|
CurrentAuthService,
|
||||||
|
CurrentUserRepository,
|
||||||
|
get_auth_service,
|
||||||
|
get_current_user_from_request,
|
||||||
|
get_current_user_id,
|
||||||
|
get_optional_user_from_request,
|
||||||
|
get_user_repository,
|
||||||
|
)
|
||||||
|
from app.gateway.dependencies.db import (
|
||||||
|
CurrentSession,
|
||||||
|
CurrentSessionTransaction,
|
||||||
|
get_db_session,
|
||||||
|
get_db_session_transaction,
|
||||||
|
)
|
||||||
|
from app.gateway.dependencies.repositories import (
|
||||||
|
CurrentFeedbackRepository,
|
||||||
|
CurrentRunRepository,
|
||||||
|
CurrentThreadMetaRepository,
|
||||||
|
CurrentThreadMetaStorage,
|
||||||
|
get_feedback_repository,
|
||||||
|
get_run_repository,
|
||||||
|
get_thread_meta_repository,
|
||||||
|
get_thread_meta_storage,
|
||||||
|
)
|
||||||
|
from app.gateway.dependencies.stream_bridge import (
|
||||||
|
CurrentStreamBridge,
|
||||||
|
get_stream_bridge,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CurrentCheckpointer",
|
||||||
|
"CurrentAuthService",
|
||||||
|
"CurrentFeedbackRepository",
|
||||||
|
"CurrentRunRepository",
|
||||||
|
"CurrentSession",
|
||||||
|
"CurrentSessionTransaction",
|
||||||
|
"CurrentStreamBridge",
|
||||||
|
"CurrentThreadMetaRepository",
|
||||||
|
"CurrentThreadMetaStorage",
|
||||||
|
"CurrentUserRepository",
|
||||||
|
"get_auth_service",
|
||||||
|
"get_checkpointer",
|
||||||
|
"get_current_user_from_request",
|
||||||
|
"get_current_user_id",
|
||||||
|
"get_db_session",
|
||||||
|
"get_db_session_transaction",
|
||||||
|
"get_feedback_repository",
|
||||||
|
"get_optional_user_from_request",
|
||||||
|
"get_run_repository",
|
||||||
|
"get_stream_bridge",
|
||||||
|
"get_thread_meta_repository",
|
||||||
|
"get_thread_meta_storage",
|
||||||
|
"get_user_repository",
|
||||||
|
]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpointer(request: Request) -> Checkpointer:
|
||||||
|
"""Get checkpointer from app.state.persistence."""
|
||||||
|
persistence = getattr(request.app.state, "persistence", None)
|
||||||
|
if persistence is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Persistence not available")
|
||||||
|
checkpointer = getattr(persistence, "checkpointer", None)
|
||||||
|
if checkpointer is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Checkpointer not available")
|
||||||
|
return checkpointer
|
||||||
|
|
||||||
|
|
||||||
|
CurrentCheckpointer = Annotated[Checkpointer, Depends(get_checkpointer)]
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_factory(request: Request) -> async_sessionmaker[AsyncSession]:
|
||||||
|
factory = getattr(request.app.state.persistence, "session_factory", None)
|
||||||
|
if factory is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Database session factory not available")
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||||
|
"""Open a session without auto-commit. Use for read-only endpoints."""
|
||||||
|
session_factory = _get_session_factory(request)
|
||||||
|
async with session_factory() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_session_transaction(request: Request) -> AsyncIterator[AsyncSession]:
|
||||||
|
"""Open a session and commit on success, rollback on error."""
|
||||||
|
session_factory = _get_session_factory(request)
|
||||||
|
async with session_factory() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
CurrentSession = Annotated[AsyncSession, Depends(get_db_session)]
|
||||||
|
CurrentSessionTransaction = Annotated[AsyncSession, Depends(get_db_session_transaction)]
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
|
||||||
|
from app.infra.storage import ThreadMetaStorage
|
||||||
|
from store.repositories.contracts import (
|
||||||
|
FeedbackRepositoryProtocol,
|
||||||
|
RunRepositoryProtocol,
|
||||||
|
ThreadMetaRepositoryProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _require_state(request: Request, attr: str, label: str):
|
||||||
|
value = getattr(request.app.state, attr, None)
|
||||||
|
if value is None:
|
||||||
|
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_run_repository(request: Request) -> RunRepositoryProtocol:
|
||||||
|
return _require_state(request, "run_store", "Run store")
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_meta_repository(request: Request) -> ThreadMetaRepositoryProtocol:
|
||||||
|
return _require_state(request, "thread_meta_repo", "Thread metadata store")
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_meta_storage(request: Request) -> ThreadMetaStorage:
|
||||||
|
return _require_state(request, "thread_meta_storage", "Thread metadata storage")
|
||||||
|
|
||||||
|
|
||||||
|
def get_feedback_repository(request: Request) -> FeedbackRepositoryProtocol:
|
||||||
|
return _require_state(request, "feedback_repo", "Feedback")
|
||||||
|
|
||||||
|
|
||||||
|
CurrentRunRepository = Annotated[RunRepositoryProtocol, Depends(get_run_repository)]
|
||||||
|
CurrentThreadMetaRepository = Annotated[ThreadMetaRepositoryProtocol, Depends(get_thread_meta_repository)]
|
||||||
|
CurrentThreadMetaStorage = Annotated[ThreadMetaStorage, Depends(get_thread_meta_storage)]
|
||||||
|
CurrentFeedbackRepository = Annotated[FeedbackRepositoryProtocol, Depends(get_feedback_repository)]
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
|
||||||
|
from deerflow.runtime import StreamBridge
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_bridge(request: Request) -> StreamBridge:
|
||||||
|
"""Get stream bridge from app.state."""
|
||||||
|
bridge = getattr(request.app.state, "stream_bridge", None)
|
||||||
|
if bridge is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Stream bridge not available")
|
||||||
|
return bridge
|
||||||
|
|
||||||
|
|
||||||
|
CurrentStreamBridge = Annotated[StreamBridge, Depends(get_stream_bridge)]
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
|
||||||
|
|
||||||
**Getters** (used by routers): raise 503 when a required dependency is
|
|
||||||
missing, except ``get_store`` which returns ``None``.
|
|
||||||
|
|
||||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
|
||||||
|
|
||||||
from deerflow.runtime import RunManager, StreamBridge
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
||||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
|
||||||
|
|
||||||
Usage in ``app.py``::
|
|
||||||
|
|
||||||
async with langgraph_runtime(app):
|
|
||||||
yield
|
|
||||||
"""
|
|
||||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
|
||||||
from deerflow.runtime import make_store, make_stream_bridge
|
|
||||||
|
|
||||||
async with AsyncExitStack() as stack:
|
|
||||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
|
||||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
|
||||||
app.state.store = await stack.enter_async_context(make_store())
|
|
||||||
app.state.run_manager = RunManager()
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Getters – called by routers per-request
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def get_stream_bridge(request: Request) -> StreamBridge:
|
|
||||||
"""Return the global :class:`StreamBridge`, or 503."""
|
|
||||||
bridge = getattr(request.app.state, "stream_bridge", None)
|
|
||||||
if bridge is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Stream bridge not available")
|
|
||||||
return bridge
|
|
||||||
|
|
||||||
|
|
||||||
def get_run_manager(request: Request) -> RunManager:
|
|
||||||
"""Return the global :class:`RunManager`, or 503."""
|
|
||||||
mgr = getattr(request.app.state, "run_manager", None)
|
|
||||||
if mgr is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Run manager not available")
|
|
||||||
return mgr
|
|
||||||
|
|
||||||
|
|
||||||
def get_checkpointer(request: Request):
|
|
||||||
"""Return the global checkpointer, or 503."""
|
|
||||||
cp = getattr(request.app.state, "checkpointer", None)
|
|
||||||
if cp is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Checkpointer not available")
|
|
||||||
return cp
|
|
||||||
|
|
||||||
|
|
||||||
def get_store(request: Request):
|
|
||||||
"""Return the global store (may be ``None`` if not configured)."""
|
|
||||||
return getattr(request.app.state, "store", None)
|
|
||||||
@@ -5,15 +5,17 @@ from pathlib import Path
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
|
|
||||||
|
|
||||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
def resolve_thread_virtual_path(thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
|
||||||
"""Resolve a virtual path to the actual filesystem path under thread user-data.
|
"""Resolve a virtual path to the actual filesystem path under thread user-data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The thread ID.
|
thread_id: The thread ID.
|
||||||
virtual_path: The virtual path as seen inside the sandbox
|
virtual_path: The virtual path as seen inside the sandbox
|
||||||
(e.g., /mnt/user-data/outputs/file.txt).
|
(e.g., /mnt/user-data/outputs/file.txt).
|
||||||
|
user_id: Explicit user id override. Falls back to the current actor context.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The resolved filesystem path.
|
The resolved filesystem path.
|
||||||
@@ -22,7 +24,8 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
|||||||
HTTPException: If the path is invalid or outside allowed directories.
|
HTTPException: If the path is invalid or outside allowed directories.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return get_paths().resolve_virtual_path(thread_id, virtual_path)
|
resolved_user_id = get_effective_user_id() if user_id is None else user_id
|
||||||
|
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=resolved_user_id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
status = 403 if "traversal" in str(e) else 400
|
status = 403 if "traversal" in str(e) else 400
|
||||||
raise HTTPException(status_code=status, detail=str(e))
|
raise HTTPException(status_code=status, detail=str(e))
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from scalar_fastapi import AgentScalarConfig, get_scalar_api_reference
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from store.persistence import create_persistence
|
||||||
|
|
||||||
|
from app.gateway.common import lifespan_manager
|
||||||
|
from app.gateway.router import router as gateway_router
|
||||||
|
from app.infra.run_events import build_run_event_store
|
||||||
|
from app.infra.storage import FeedbackStoreAdapter, RunStoreAdapter, ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||||
|
from app.plugins.auth.authorization.hooks import build_authz_hooks
|
||||||
|
from app.plugins.auth.injection import install_route_guards, load_route_policy_registry, validate_route_policy_registry
|
||||||
|
from app.plugins.auth.security import AuthMiddleware, CSRFMiddleware
|
||||||
|
|
||||||
|
STATIC_DIR = Path(__file__).resolve().parents[1] / "static"
|
||||||
|
STATIC_MOUNT = "/api/static"
|
||||||
|
SCALAR_JS_URL = f"{STATIC_MOUNT}/scalar.js"
|
||||||
|
|
||||||
|
|
||||||
|
@lifespan_manager.register
|
||||||
|
@asynccontextmanager
|
||||||
|
async def init_persistence(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
"""Initialize persistence layer (DB, checkpointer, store)."""
|
||||||
|
app_persistence = await create_persistence()
|
||||||
|
|
||||||
|
await app_persistence.setup()
|
||||||
|
run_store = RunStoreAdapter(app_persistence.session_factory)
|
||||||
|
thread_meta_store = ThreadMetaStoreAdapter(app_persistence.session_factory)
|
||||||
|
feedback_store = FeedbackStoreAdapter(app_persistence.session_factory)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield {
|
||||||
|
"persistence": app_persistence,
|
||||||
|
"checkpointer": app_persistence.checkpointer,
|
||||||
|
"store": None,
|
||||||
|
"session_factory": app_persistence.session_factory,
|
||||||
|
"run_store": run_store,
|
||||||
|
"run_read_repo": run_store,
|
||||||
|
"run_write_repo": run_store,
|
||||||
|
"run_delete_repo": run_store,
|
||||||
|
"feedback_repo": feedback_store,
|
||||||
|
"thread_meta_repo": thread_meta_store,
|
||||||
|
"thread_meta_storage": ThreadMetaStorage(thread_meta_store),
|
||||||
|
"run_event_store": build_run_event_store(app_persistence.session_factory),
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
await app_persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@lifespan_manager.register
|
||||||
|
@asynccontextmanager
|
||||||
|
async def init_runtime(app: FastAPI) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
"""Initialize StreamBridge for LangGraph-compatible runtime endpoints."""
|
||||||
|
from app.infra.stream_bridge import build_stream_bridge
|
||||||
|
|
||||||
|
async with build_stream_bridge() as stream_bridge:
|
||||||
|
yield {
|
||||||
|
"stream_bridge": stream_bridge,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="DeerFlow API Gateway",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url=None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan_manager.build(),
|
||||||
|
openapi_tags=[
|
||||||
|
{
|
||||||
|
"name": "threads",
|
||||||
|
"description": "Endpoints for managing threads, which are conversations between a human and an assistant. A thread can have multiple runs as the conversation progresses."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
app.state.authz_hooks = build_authz_hooks()
|
||||||
|
|
||||||
|
_register_static(app)
|
||||||
|
_register_routes(app)
|
||||||
|
_register_scalar(app)
|
||||||
|
_register_auth_route_policies(app)
|
||||||
|
_register_middlewares(app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def _register_static(app: FastAPI) -> None:
|
||||||
|
app.mount(STATIC_MOUNT, StaticFiles(directory=STATIC_DIR), name="static")
|
||||||
|
|
||||||
|
|
||||||
|
def _register_routes(app: FastAPI) -> None:
|
||||||
|
app.include_router(gateway_router)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_auth_route_policies(app: FastAPI) -> None:
|
||||||
|
registry = load_route_policy_registry()
|
||||||
|
validate_route_policy_registry(app, registry)
|
||||||
|
app.state.auth_route_policy_registry = registry
|
||||||
|
install_route_guards(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_middlewares(app: FastAPI) -> None:
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
expose_headers=["*"],
|
||||||
|
)
|
||||||
|
app.add_middleware(CSRFMiddleware)
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_scalar(app: FastAPI) -> None:
|
||||||
|
@app.get("/docs", include_in_schema=False)
|
||||||
|
def scalar_docs() -> HTMLResponse:
|
||||||
|
return get_scalar_api_reference(
|
||||||
|
openapi_url=app.openapi_url,
|
||||||
|
title=app.title,
|
||||||
|
scalar_js_url=SCALAR_JS_URL,
|
||||||
|
agent=AgentScalarConfig(disabled=True),
|
||||||
|
hide_client_button=True,
|
||||||
|
overrides={"mcp": {"disabled": True}},
|
||||||
|
)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.plugins.auth.api.router import router as auth_router
|
||||||
|
|
||||||
|
from .routers import artifacts, channels, mcp, models, skills, uploads
|
||||||
|
from .routers.agents import router as agents_router
|
||||||
|
from .routers.langgraph import feedback_router, runs_router, suggestion_router, threads_router
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
router.include_router(auth_router)
|
||||||
|
router.include_router(threads_router, prefix="/api/threads")
|
||||||
|
router.include_router(runs_router, prefix="/api/threads")
|
||||||
|
router.include_router(feedback_router, prefix="/api/threads")
|
||||||
|
router.include_router(suggestion_router)
|
||||||
|
router.include_router(agents_router)
|
||||||
|
router.include_router(channels.router)
|
||||||
|
router.include_router(artifacts.router)
|
||||||
|
router.include_router(mcp.router)
|
||||||
|
router.include_router(models.router)
|
||||||
|
router.include_router(skills.router)
|
||||||
|
router.include_router(uploads.router)
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
from . import artifacts, mcp, models, skills, suggestions, uploads
|
||||||
|
|
||||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]
|
||||||
|
|||||||
@@ -1,149 +0,0 @@
|
|||||||
"""Assistants compatibility endpoints.
|
|
||||||
|
|
||||||
Provides LangGraph Platform-compatible assistants API backed by the
|
|
||||||
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
|
|
||||||
|
|
||||||
This is a minimal stub that satisfies the ``useStream`` React hook's
|
|
||||||
initialization requirements (``assistants.search()`` and ``assistants.get()``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
|
|
||||||
|
|
||||||
|
|
||||||
class AssistantResponse(BaseModel):
|
|
||||||
assistant_id: str
|
|
||||||
graph_id: str
|
|
||||||
name: str
|
|
||||||
config: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
description: str | None = None
|
|
||||||
created_at: str = ""
|
|
||||||
updated_at: str = ""
|
|
||||||
version: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
class AssistantSearchRequest(BaseModel):
|
|
||||||
graph_id: str | None = None
|
|
||||||
name: str | None = None
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
limit: int = 10
|
|
||||||
offset: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_assistant() -> AssistantResponse:
|
|
||||||
"""Return the default lead_agent assistant."""
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
return AssistantResponse(
|
|
||||||
assistant_id="lead_agent",
|
|
||||||
graph_id="lead_agent",
|
|
||||||
name="lead_agent",
|
|
||||||
config={},
|
|
||||||
metadata={"created_by": "system"},
|
|
||||||
description="DeerFlow lead agent",
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
version=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _list_assistants() -> list[AssistantResponse]:
|
|
||||||
"""List all available assistants from config."""
|
|
||||||
assistants = [_get_default_assistant()]
|
|
||||||
|
|
||||||
# Also include custom agents from config.yaml agents directory
|
|
||||||
try:
|
|
||||||
from deerflow.config.agents_config import list_custom_agents
|
|
||||||
|
|
||||||
for agent_cfg in list_custom_agents():
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
assistants.append(
|
|
||||||
AssistantResponse(
|
|
||||||
assistant_id=agent_cfg.name,
|
|
||||||
graph_id="lead_agent", # All agents use the same graph
|
|
||||||
name=agent_cfg.name,
|
|
||||||
config={},
|
|
||||||
metadata={"created_by": "user"},
|
|
||||||
description=agent_cfg.description or "",
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
version=1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not load custom agents for assistants list")
|
|
||||||
|
|
||||||
return assistants
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search", response_model=list[AssistantResponse])
|
|
||||||
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
|
|
||||||
"""Search assistants.
|
|
||||||
|
|
||||||
Returns all registered assistants (lead_agent + custom agents from config).
|
|
||||||
"""
|
|
||||||
assistants = _list_assistants()
|
|
||||||
|
|
||||||
if body and body.graph_id:
|
|
||||||
assistants = [a for a in assistants if a.graph_id == body.graph_id]
|
|
||||||
if body and body.name:
|
|
||||||
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
|
|
||||||
|
|
||||||
offset = body.offset if body else 0
|
|
||||||
limit = body.limit if body else 10
|
|
||||||
return assistants[offset : offset + limit]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}", response_model=AssistantResponse)
|
|
||||||
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
|
|
||||||
"""Get an assistant by ID."""
|
|
||||||
for a in _list_assistants():
|
|
||||||
if a.assistant_id == assistant_id:
|
|
||||||
return a
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}/graph")
|
|
||||||
async def get_assistant_graph(assistant_id: str) -> dict:
|
|
||||||
"""Get the graph structure for an assistant.
|
|
||||||
|
|
||||||
Returns a minimal graph description. Full graph introspection is
|
|
||||||
not supported in the Gateway — this stub satisfies SDK validation.
|
|
||||||
"""
|
|
||||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
|
||||||
if not found:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"graph_id": "lead_agent",
|
|
||||||
"nodes": [],
|
|
||||||
"edges": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}/schemas")
|
|
||||||
async def get_assistant_schemas(assistant_id: str) -> dict:
|
|
||||||
"""Get JSON schemas for an assistant's input/output/state.
|
|
||||||
|
|
||||||
Returns empty schemas — full introspection not supported in Gateway.
|
|
||||||
"""
|
|
||||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
|
||||||
if not found:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"graph_id": "lead_agent",
|
|
||||||
"input_schema": {},
|
|
||||||
"output_schema": {},
|
|
||||||
"state_schema": {},
|
|
||||||
"config_schema": {},
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from .feedback import router as feedback_router
|
||||||
|
from .runs import router as runs_router
|
||||||
|
from .suggestions import router as suggestion_router
|
||||||
|
from .threads import router as threads_router
|
||||||
|
|
||||||
|
__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"]
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
"""LangGraph-compatible run feedback endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.gateway.dependencies import get_feedback_repository, get_run_repository
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||||
|
from app.plugins.auth.security.dependencies import get_current_user_id
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(tags=["feedback"])
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackCreateRequest(BaseModel):
|
||||||
|
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
|
||||||
|
comment: str | None = Field(default=None, description="Optional text feedback")
|
||||||
|
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackResponse(BaseModel):
|
||||||
|
feedback_id: str
|
||||||
|
run_id: str
|
||||||
|
thread_id: str
|
||||||
|
owner_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
rating: int
|
||||||
|
comment: str | None = None
|
||||||
|
created_at: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackStatsResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
total: int = 0
|
||||||
|
positive: int = 0
|
||||||
|
negative: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None:
|
||||||
|
run_store = get_run_repository(request)
|
||||||
|
if resolve_request_user_id(request) is None:
|
||||||
|
run = await run_store.get(run_id, user_id=None)
|
||||||
|
else:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
run = await run_store.get(run_id)
|
||||||
|
if run is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
if run.get("thread_id") != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_current_user(request: Request) -> str | None:
|
||||||
|
"""Extract current user id from auth dependencies when available."""
|
||||||
|
return await get_current_user_id(request)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
body: FeedbackCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if body.rating not in (1, -1):
|
||||||
|
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
|
||||||
|
|
||||||
|
await _validate_run_scope(thread_id, run_id, request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
return await feedback_repo.create(
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
rating=body.rating,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=body.message_id,
|
||||||
|
comment=body.comment,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||||
|
async def upsert_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
body: FeedbackCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create or replace the run-level feedback record."""
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
|
if user_id is not None:
|
||||||
|
return await feedback_repo.upsert(
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
rating=body.rating,
|
||||||
|
user_id=user_id,
|
||||||
|
comment=body.comment,
|
||||||
|
)
|
||||||
|
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||||
|
for item in existing:
|
||||||
|
feedback_id = item.get("feedback_id")
|
||||||
|
if isinstance(feedback_id, str):
|
||||||
|
await feedback_repo.delete(feedback_id)
|
||||||
|
return await _create_feedback(thread_id, run_id, body, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
|
||||||
|
async def create_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
body: FeedbackCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Submit feedback for a run."""
|
||||||
|
return await _create_feedback(thread_id, run_id, body, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
|
||||||
|
async def list_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""List all feedback for a run."""
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
|
return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
|
||||||
|
async def feedback_stats(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Get aggregated feedback stats for a run."""
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
return await feedback_repo.aggregate_by_run(thread_id, run_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}/feedback")
|
||||||
|
async def delete_run_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete all feedback records for a run."""
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
user_id = await _get_current_user(request)
|
||||||
|
if user_id is not None:
|
||||||
|
return {"success": await feedback_repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)}
|
||||||
|
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
|
||||||
|
for item in existing:
|
||||||
|
feedback_id = item.get("feedback_id")
|
||||||
|
if isinstance(feedback_id, str):
|
||||||
|
await feedback_repo.delete(feedback_id)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
|
||||||
|
async def delete_feedback(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
feedback_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a single feedback record."""
|
||||||
|
feedback_repo = get_feedback_repository(request)
|
||||||
|
existing = await feedback_repo.get(feedback_id)
|
||||||
|
if existing is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||||
|
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
|
||||||
|
deleted = await feedback_repo.delete(feedback_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||||
|
return {"success": True}
|
||||||
@@ -0,0 +1,501 @@
|
|||||||
|
"""LangGraph-compatible runs endpoints backed by RunsFacade."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import Response, StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
|
from app.gateway.services.runs.facade_factory import build_runs_facade_from_request
|
||||||
|
from app.gateway.services.runs.input import (
|
||||||
|
AdaptedRunRequest,
|
||||||
|
RunSpecBuilder,
|
||||||
|
UnsupportedRunFeatureError,
|
||||||
|
adapt_create_run_request,
|
||||||
|
adapt_create_stream_request,
|
||||||
|
adapt_create_wait_request,
|
||||||
|
adapt_join_stream_request,
|
||||||
|
adapt_join_wait_request,
|
||||||
|
)
|
||||||
|
from deerflow.runtime.runs.types import RunRecord, RunSpec
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue, StreamEvent
|
||||||
|
|
||||||
|
router = APIRouter(tags=["runs"])
|
||||||
|
|
||||||
|
|
||||||
|
class RunCreateRequest(BaseModel):
|
||||||
|
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
||||||
|
follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run")
|
||||||
|
input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
||||||
|
command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command")
|
||||||
|
metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata")
|
||||||
|
config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides")
|
||||||
|
context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||||
|
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||||
|
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||||
|
checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object")
|
||||||
|
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
||||||
|
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
||||||
|
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
||||||
|
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
||||||
|
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
||||||
|
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
||||||
|
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
||||||
|
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
||||||
|
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||||
|
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||||
|
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||||
|
|
||||||
|
|
||||||
|
class RunResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
thread_id: str
|
||||||
|
assistant_id: str | None = None
|
||||||
|
status: str
|
||||||
|
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||||
|
multitask_strategy: str = "reject"
|
||||||
|
created_at: str = ""
|
||||||
|
updated_at: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RunDeleteResponse(BaseModel):
|
||||||
|
deleted: bool
|
||||||
|
|
||||||
|
|
||||||
|
class RunMessageResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
content: JSONValue
|
||||||
|
metadata: dict[str, JSONValue] = Field(default_factory=dict)
|
||||||
|
created_at: str
|
||||||
|
seq: int
|
||||||
|
|
||||||
|
|
||||||
|
class RunMessagesResponse(BaseModel):
|
||||||
|
data: list[RunMessageResponse]
|
||||||
|
hasMore: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str:
|
||||||
|
"""Format a single SSE frame."""
|
||||||
|
payload = json.dumps(data, default=str, ensure_ascii=False)
|
||||||
|
parts = [f"event: {event}", f"data: {payload}"]
|
||||||
|
if event_id:
|
||||||
|
parts.append(f"id: {event_id}")
|
||||||
|
parts.append("")
|
||||||
|
parts.append("")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _record_to_response(record: RunRecord) -> RunResponse:
|
||||||
|
return RunResponse(
|
||||||
|
run_id=record.run_id,
|
||||||
|
thread_id=record.thread_id,
|
||||||
|
assistant_id=record.assistant_id,
|
||||||
|
status=record.status,
|
||||||
|
metadata=record.metadata,
|
||||||
|
multitask_strategy=record.multitask_strategy,
|
||||||
|
created_at=record.created_at,
|
||||||
|
updated_at=record.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _trim_paginated_rows(
|
||||||
|
rows: list[dict],
|
||||||
|
*,
|
||||||
|
limit: int,
|
||||||
|
after_seq: int | None,
|
||||||
|
) -> tuple[list[dict], bool]:
|
||||||
|
has_more = len(rows) > limit
|
||||||
|
if not has_more:
|
||||||
|
return rows, False
|
||||||
|
if after_seq is not None:
|
||||||
|
return rows[:limit], True
|
||||||
|
return rows[-limit:], True
|
||||||
|
|
||||||
|
|
||||||
|
def _event_to_run_message(event: dict) -> RunMessageResponse:
|
||||||
|
return RunMessageResponse(
|
||||||
|
run_id=str(event["run_id"]),
|
||||||
|
content=event.get("content"),
|
||||||
|
metadata=dict(event.get("metadata") or {}),
|
||||||
|
created_at=str(event.get("created_at") or ""),
|
||||||
|
seq=int(event["seq"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _sse_consumer(
|
||||||
|
stream: AsyncIterator[StreamEvent],
|
||||||
|
request: Request,
|
||||||
|
*,
|
||||||
|
cancel_on_disconnect: bool,
|
||||||
|
cancel_run,
|
||||||
|
run_id: str,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
try:
|
||||||
|
async for event in stream:
|
||||||
|
if await request.is_disconnected():
|
||||||
|
break
|
||||||
|
|
||||||
|
if event.event == "__heartbeat__":
|
||||||
|
yield ": heartbeat\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.event == "__end__":
|
||||||
|
yield format_sse("end", None, event_id=event.id or None)
|
||||||
|
return
|
||||||
|
|
||||||
|
if event.event == "__cancelled__":
|
||||||
|
yield format_sse("cancel", None, event_id=event.id or None)
|
||||||
|
return
|
||||||
|
|
||||||
|
yield format_sse(event.event, event.data, event_id=event.id or None)
|
||||||
|
finally:
|
||||||
|
if cancel_on_disconnect:
|
||||||
|
await cancel_run(run_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_run_event_store(request: Request):
|
||||||
|
event_store = getattr(request.app.state, "run_event_store", None)
|
||||||
|
if event_store is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Run event store not available")
|
||||||
|
return event_store
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
||||||
|
async def list_runs(
|
||||||
|
thread_id: str,
|
||||||
|
request: Request,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> list[RunResponse]:
|
||||||
|
# Accepted for API compatibility; field projection is not implemented yet.
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
records = await facade.list_runs(thread_id)
|
||||||
|
if status is not None:
|
||||||
|
records = [record for record in records if record.status == status]
|
||||||
|
records = records[offset : offset + limit]
|
||||||
|
return [_record_to_response(record) for record in records]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
||||||
|
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
return _record_to_response(record)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse)
|
||||||
|
async def run_messages(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> RunMessagesResponse:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
event_store = _get_run_event_store(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
rows = await event_store.list_messages_by_run(
|
||||||
|
thread_id,
|
||||||
|
run_id,
|
||||||
|
limit=limit + 1,
|
||||||
|
before_seq=before_seq,
|
||||||
|
after_seq=after_seq,
|
||||||
|
)
|
||||||
|
page, has_more = _trim_paginated_rows(rows, limit=limit, after_seq=after_seq)
|
||||||
|
return RunMessagesResponse(data=[_event_to_run_message(row) for row in page], hasMore=has_more)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_spec(
|
||||||
|
*,
|
||||||
|
adapted: AdaptedRunRequest,
|
||||||
|
) -> RunSpec:
|
||||||
|
try:
|
||||||
|
return RunSpecBuilder().build(adapted)
|
||||||
|
except UnsupportedRunFeatureError as exc:
|
||||||
|
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
||||||
|
async def create_run(
|
||||||
|
thread_id: str,
|
||||||
|
body: RunCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> Response:
|
||||||
|
adapted = adapt_create_run_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.create_background(spec)
|
||||||
|
return Response(
|
||||||
|
content=_record_to_response(record).model_dump_json(),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/stream")
|
||||||
|
async def stream_run(
|
||||||
|
thread_id: str,
|
||||||
|
body: RunCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> StreamingResponse:
|
||||||
|
adapted = adapt_create_stream_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, stream = await facade.create_and_stream(spec)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_sse_consumer(
|
||||||
|
stream,
|
||||||
|
request,
|
||||||
|
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||||
|
cancel_run=facade.cancel,
|
||||||
|
run_id=record.run_id,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/wait")
|
||||||
|
async def wait_run(
|
||||||
|
thread_id: str,
|
||||||
|
body: RunCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> Response:
|
||||||
|
adapted = adapt_create_wait_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, result = await facade.create_and_wait(spec)
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs", response_model=RunResponse)
|
||||||
|
async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||||
|
adapted = adapt_create_run_request(
|
||||||
|
thread_id=None,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.create_background(spec)
|
||||||
|
return Response(
|
||||||
|
content=_record_to_response(record).model_dump_json(),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs/stream")
|
||||||
|
async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||||
|
adapted = adapt_create_stream_request(
|
||||||
|
thread_id=None,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, stream = await facade.create_and_stream(spec)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_sse_consumer(
|
||||||
|
stream,
|
||||||
|
request,
|
||||||
|
cancel_on_disconnect=spec.on_disconnect == "cancel",
|
||||||
|
cancel_run=facade.cancel,
|
||||||
|
run_id=record.run_id,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/runs/wait")
|
||||||
|
async def wait_stateless_run(body: RunCreateRequest, request: Request) -> Response:
|
||||||
|
adapted = adapt_create_wait_request(
|
||||||
|
thread_id=None,
|
||||||
|
body=body.model_dump(),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
spec = _build_spec(adapted=adapted)
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record, result = await facade.create_and_wait(spec)
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(result, default=str, ensure_ascii=False),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
||||||
|
async def stream_existing_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
action: Literal["interrupt", "rollback"] | None = None,
|
||||||
|
wait: bool = False,
|
||||||
|
cancel_on_disconnect: bool = False,
|
||||||
|
stream_mode: str | None = None,
|
||||||
|
) -> StreamingResponse | Response:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
if action is not None:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
cancelled = await facade.cancel(run_id, action=action)
|
||||||
|
if not cancelled:
|
||||||
|
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||||
|
if wait:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
await facade.join_wait(run_id)
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
adapted = adapt_join_stream_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
_sse_consumer(
|
||||||
|
stream,
|
||||||
|
request,
|
||||||
|
cancel_on_disconnect=cancel_on_disconnect,
|
||||||
|
cancel_run=facade.cancel,
|
||||||
|
run_id=run_id,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/runs/{run_id}/join")
|
||||||
|
async def join_existing_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
cancel_on_disconnect: bool = False,
|
||||||
|
) -> JSONValue:
|
||||||
|
# Accepted for API compatibility; current join_wait path does not change
|
||||||
|
# behavior based on client disconnect.
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
adapted = adapt_join_wait_request(
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
headers=dict(request.headers),
|
||||||
|
query=dict(request.query_params),
|
||||||
|
)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
return await facade.join_wait(run_id, last_event_id=adapted.last_event_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
||||||
|
async def cancel_existing_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
wait: bool = False,
|
||||||
|
action: Literal["interrupt", "rollback"] = "interrupt",
|
||||||
|
) -> JSONValue:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
cancelled = await facade.cancel(run_id, action=action)
|
||||||
|
if not cancelled:
|
||||||
|
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
|
||||||
|
if wait:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
return await facade.join_wait(run_id)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse)
|
||||||
|
async def delete_run(
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
request: Request,
|
||||||
|
) -> RunDeleteResponse:
|
||||||
|
facade = build_runs_facade_from_request(request)
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
record = await facade.get_run(run_id)
|
||||||
|
if record is None or record.thread_id != thread_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
deleted = await facade.delete_run(run_id)
|
||||||
|
return RunDeleteResponse(deleted=deleted)
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from deerflow.models import create_chat_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["suggestions"])
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionMessage(BaseModel):
|
||||||
|
role: str = Field(..., description="Message role: user|assistant")
|
||||||
|
content: str = Field(..., description="Message content as plain text")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionsRequest(BaseModel):
|
||||||
|
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
|
||||||
|
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
|
||||||
|
model_name: str | None = Field(default=None, description="Optional model override")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionsResponse(BaseModel):
|
||||||
|
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_markdown_code_fence(text: str) -> str:
|
||||||
|
stripped = text.strip()
|
||||||
|
if not stripped.startswith("```"):
|
||||||
|
return stripped
|
||||||
|
lines = stripped.splitlines()
|
||||||
|
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
|
||||||
|
return "\n".join(lines[1:-1]).strip()
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_string_list(text: str) -> list[str] | None:
|
||||||
|
candidate = _strip_markdown_code_fence(text)
|
||||||
|
start = candidate.find("[")
|
||||||
|
end = candidate.rfind("]")
|
||||||
|
if start == -1 or end == -1 or end <= start:
|
||||||
|
return None
|
||||||
|
candidate = candidate[start : end + 1]
|
||||||
|
try:
|
||||||
|
data = json.loads(candidate)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if not isinstance(data, list):
|
||||||
|
return None
|
||||||
|
out: list[str] = []
|
||||||
|
for item in data:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
continue
|
||||||
|
s = item.strip()
|
||||||
|
if not s:
|
||||||
|
continue
|
||||||
|
out.append(s)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_response_text(content: object) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
parts.append(block)
|
||||||
|
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
|
||||||
|
text = block.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "\n".join(parts) if parts else ""
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
for m in messages:
|
||||||
|
role = m.role.strip().lower()
|
||||||
|
if role in ("user", "human"):
|
||||||
|
parts.append(f"User: {m.content.strip()}")
|
||||||
|
elif role in ("assistant", "ai"):
|
||||||
|
parts.append(f"Assistant: {m.content.strip()}")
|
||||||
|
else:
|
||||||
|
parts.append(f"{m.role}: {m.content.strip()}")
|
||||||
|
return "\n".join(parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/threads/{thread_id}/suggestions",
|
||||||
|
response_model=SuggestionsResponse,
|
||||||
|
summary="Generate Follow-up Questions",
|
||||||
|
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||||
|
)
|
||||||
|
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
|
||||||
|
if not request.messages:
|
||||||
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
|
n = request.n
|
||||||
|
conversation = _format_conversation(request.messages)
|
||||||
|
if not conversation:
|
||||||
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
|
system_instruction = (
|
||||||
|
"You are generating follow-up questions to help the user continue the conversation.\n"
|
||||||
|
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
|
||||||
|
"Requirements:\n"
|
||||||
|
"- Questions must be relevant to the preceding conversation.\n"
|
||||||
|
"- Questions must be written in the same language as the user.\n"
|
||||||
|
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
|
||||||
|
"- Do NOT include numbering, markdown, or any extra text.\n"
|
||||||
|
"- Output MUST be a JSON array of strings only.\n"
|
||||||
|
)
|
||||||
|
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||||
|
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||||
|
raw = _extract_response_text(response.content)
|
||||||
|
suggestions = _parse_json_string_list(raw) or []
|
||||||
|
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||||
|
cleaned = cleaned[:n]
|
||||||
|
return SuggestionsResponse(suggestions=cleaned)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
|
||||||
|
return SuggestionsResponse(suggestions=[])
|
||||||
@@ -0,0 +1,455 @@
|
|||||||
|
"""Thread management endpoints.
|
||||||
|
|
||||||
|
Provides CRUD operations for threads and checkpoint state management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage
|
||||||
|
from app.infra.storage import ThreadMetaStorage
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
|
||||||
|
from deerflow.config.paths import Paths, get_paths
|
||||||
|
from deerflow.runtime import serialize_channel_values
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(tags=["threads"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request / Response Models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadCreateRequest(BaseModel):
|
||||||
|
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||||
|
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSearchRequest(BaseModel):
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
||||||
|
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
||||||
|
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
||||||
|
status: str | None = Field(default=None, description="Filter by thread status")
|
||||||
|
user_id: str | None = Field(default=None, description="Filter by user ID")
|
||||||
|
assistant_id: str | None = Field(default=None, description="Filter by assistant ID")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadResponse(BaseModel):
|
||||||
|
thread_id: str = Field(description="Unique thread identifier")
|
||||||
|
status: str = Field(default="idle", description="Thread status")
|
||||||
|
created_at: str = Field(default="", description="ISO timestamp")
|
||||||
|
updated_at: str = Field(default="", description="ISO timestamp")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
||||||
|
values: dict[str, Any] = Field(default_factory=dict, description="Current state values")
|
||||||
|
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDeleteResponse(BaseModel):
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadStateUpdateRequest(BaseModel):
|
||||||
|
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
||||||
|
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
||||||
|
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||||
|
as_node: str | None = Field(default=None, description="Node identity for the update")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadStateResponse(BaseModel):
|
||||||
|
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
||||||
|
next: list[str] = Field(default_factory=list, description="Next nodes to execute")
|
||||||
|
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
||||||
|
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
||||||
|
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
||||||
|
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
||||||
|
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadHistoryRequest(BaseModel):
|
||||||
|
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
||||||
|
before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)")
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryEntry(BaseModel):
|
||||||
|
checkpoint_id: str
|
||||||
|
parent_checkpoint_id: str | None = None
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
values: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
created_at: str | None = None
|
||||||
|
next: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_log_param(value: str) -> str:
|
||||||
|
"""Strip control characters to prevent log injection."""
|
||||||
|
|
||||||
|
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||||
|
"""Delete local filesystem data for a thread."""
|
||||||
|
path_manager = paths or get_paths()
|
||||||
|
try:
|
||||||
|
path_manager.delete_thread_dir(thread_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
|
||||||
|
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
||||||
|
|
||||||
|
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
|
||||||
|
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _thread_or_run_exists(
|
||||||
|
*,
|
||||||
|
request: Request,
|
||||||
|
thread_id: str,
|
||||||
|
thread_meta_storage: ThreadMetaStorage,
|
||||||
|
run_repo,
|
||||||
|
) -> bool:
|
||||||
|
request_user_id = resolve_request_user_id(request)
|
||||||
|
|
||||||
|
if request_user_id is None:
|
||||||
|
thread = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||||
|
if thread is not None:
|
||||||
|
return True
|
||||||
|
runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None)
|
||||||
|
return bool(runs)
|
||||||
|
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
thread = await thread_meta_storage.get_thread(thread_id)
|
||||||
|
if thread is not None:
|
||||||
|
return True
|
||||||
|
runs = await run_repo.list_by_thread(thread_id, limit=1)
|
||||||
|
return bool(runs)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ThreadResponse)
|
||||||
|
async def create_thread(
|
||||||
|
body: ThreadCreateRequest,
|
||||||
|
request: Request,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> ThreadResponse:
|
||||||
|
"""Create a new thread."""
|
||||||
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
|
|
||||||
|
request_user_id = resolve_request_user_id(request)
|
||||||
|
if request_user_id is None:
|
||||||
|
existing = await thread_meta_storage.get_thread(thread_id, user_id=None)
|
||||||
|
else:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
existing = await thread_meta_storage.get_thread(thread_id)
|
||||||
|
if existing is not None:
|
||||||
|
return ThreadResponse(
|
||||||
|
thread_id=thread_id,
|
||||||
|
status=existing.status,
|
||||||
|
created_at=existing.created_time.isoformat() if existing.created_time else "",
|
||||||
|
updated_at=existing.updated_time.isoformat() if existing.updated_time else "",
|
||||||
|
metadata=existing.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if request_user_id is None:
|
||||||
|
created = await thread_meta_storage.ensure_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
metadata=body.metadata,
|
||||||
|
user_id=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
created = await thread_meta_storage.ensure_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
metadata=body.metadata,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to create thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||||
|
|
||||||
|
logger.info("Thread created: %s", sanitize_log_param(thread_id))
|
||||||
|
return ThreadResponse(
|
||||||
|
thread_id=thread_id,
|
||||||
|
status=created.status,
|
||||||
|
created_at=created.created_time.isoformat() if created.created_time else "",
|
||||||
|
updated_at=created.updated_time.isoformat() if created.updated_time else "",
|
||||||
|
metadata=created.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/search", response_model=list[ThreadResponse])
|
||||||
|
async def search_threads(
|
||||||
|
body: ThreadSearchRequest,
|
||||||
|
request: Request,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> list[ThreadResponse]:
|
||||||
|
"""Search threads with filters."""
|
||||||
|
try:
|
||||||
|
request_user_id = resolve_request_user_id(request)
|
||||||
|
if request_user_id is None:
|
||||||
|
threads = await thread_meta_storage.search_threads(
|
||||||
|
metadata=body.metadata or None,
|
||||||
|
status=body.status,
|
||||||
|
user_id=body.user_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
limit=body.limit,
|
||||||
|
offset=body.offset,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with bind_request_actor_context(request):
|
||||||
|
threads = await thread_meta_storage.search_threads(
|
||||||
|
metadata=body.metadata or None,
|
||||||
|
status=body.status,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
limit=body.limit,
|
||||||
|
offset=body.offset,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to search threads")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to search threads")
|
||||||
|
|
||||||
|
return [
|
||||||
|
ThreadResponse(
|
||||||
|
thread_id=t.thread_id,
|
||||||
|
status=t.status,
|
||||||
|
created_at=t.created_time.isoformat() if t.created_time else "",
|
||||||
|
updated_at=t.updated_time.isoformat() if t.updated_time else "",
|
||||||
|
metadata=t.metadata,
|
||||||
|
values={"title": t.display_name} if t.display_name else {},
|
||||||
|
interrupts={},
|
||||||
|
)
|
||||||
|
for t in threads
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||||
|
async def delete_thread(
|
||||||
|
thread_id: str,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> ThreadDeleteResponse:
|
||||||
|
"""Delete a thread and all associated data."""
|
||||||
|
response = _delete_thread_data(thread_id)
|
||||||
|
|
||||||
|
# Remove checkpoints (best-effort)
|
||||||
|
try:
|
||||||
|
if hasattr(checkpointer, "adelete_thread"):
|
||||||
|
await checkpointer.adelete_thread(thread_id)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not delete checkpoints for thread %s", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
# Remove thread_meta (best-effort)
|
||||||
|
try:
|
||||||
|
await thread_meta_storage.delete_thread(thread_id)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not delete thread_meta for %s", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||||
|
async def get_thread_state(
|
||||||
|
thread_id: str,
|
||||||
|
request: Request,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
run_repo: CurrentRunRepository,
|
||||||
|
) -> ThreadStateResponse:
|
||||||
|
"""Get the latest state snapshot for a thread."""
|
||||||
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||||
|
|
||||||
|
if checkpoint_tuple is None:
|
||||||
|
if await _thread_or_run_exists(
|
||||||
|
request=request,
|
||||||
|
thread_id=thread_id,
|
||||||
|
thread_meta_storage=thread_meta_storage,
|
||||||
|
run_repo=run_repo,
|
||||||
|
):
|
||||||
|
return ThreadStateResponse()
|
||||||
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
|
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||||
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
|
|
||||||
|
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||||
|
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
||||||
|
|
||||||
|
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||||
|
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||||
|
|
||||||
|
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||||
|
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
|
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||||
|
|
||||||
|
return ThreadStateResponse(
|
||||||
|
values=serialize_channel_values(channel_values),
|
||||||
|
next=next_nodes,
|
||||||
|
tasks=tasks,
|
||||||
|
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||||
|
checkpoint_id=checkpoint_id,
|
||||||
|
parent_checkpoint_id=parent_checkpoint_id,
|
||||||
|
metadata=metadata,
|
||||||
|
created_at=str(metadata.get("created_at", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||||
|
async def update_thread_state(
|
||||||
|
thread_id: str,
|
||||||
|
body: ThreadStateUpdateRequest,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
) -> ThreadStateResponse:
|
||||||
|
"""Update thread state (human-in-the-loop or title rename)."""
|
||||||
|
read_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
if body.checkpoint_id:
|
||||||
|
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||||
|
|
||||||
|
if checkpoint_tuple is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
|
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
||||||
|
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
||||||
|
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
||||||
|
|
||||||
|
if body.values:
|
||||||
|
channel_values.update(body.values)
|
||||||
|
|
||||||
|
checkpoint["channel_values"] = channel_values
|
||||||
|
metadata["updated_at"] = time.time()
|
||||||
|
|
||||||
|
if body.as_node:
|
||||||
|
metadata["source"] = "update"
|
||||||
|
metadata["step"] = metadata.get("step", 0) + 1
|
||||||
|
metadata["writes"] = {body.as_node: body.values}
|
||||||
|
|
||||||
|
write_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
try:
|
||||||
|
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
||||||
|
|
||||||
|
new_checkpoint_id: str | None = None
|
||||||
|
if isinstance(new_config, dict):
|
||||||
|
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
||||||
|
|
||||||
|
# Sync title to thread_meta
|
||||||
|
if body.values and "title" in body.values:
|
||||||
|
new_title = body.values["title"]
|
||||||
|
if new_title:
|
||||||
|
try:
|
||||||
|
await thread_meta_storage.sync_thread_title(
|
||||||
|
thread_id=thread_id,
|
||||||
|
title=new_title,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
return ThreadStateResponse(
|
||||||
|
values=serialize_channel_values(channel_values),
|
||||||
|
next=[],
|
||||||
|
metadata=metadata,
|
||||||
|
checkpoint_id=new_checkpoint_id,
|
||||||
|
created_at=str(metadata.get("created_at", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||||
|
async def get_thread_history(
|
||||||
|
thread_id: str,
|
||||||
|
body: ThreadHistoryRequest,
|
||||||
|
request: Request,
|
||||||
|
checkpointer: CurrentCheckpointer,
|
||||||
|
thread_meta_storage: CurrentThreadMetaStorage,
|
||||||
|
run_repo: CurrentRunRepository,
|
||||||
|
) -> list[HistoryEntry]:
|
||||||
|
"""Get checkpoint history for a thread."""
|
||||||
|
config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
if body.before:
|
||||||
|
config["configurable"]["checkpoint_id"] = body.before
|
||||||
|
|
||||||
|
entries: list[HistoryEntry] = []
|
||||||
|
is_first = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
||||||
|
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
|
||||||
|
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||||
|
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||||
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
|
|
||||||
|
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
||||||
|
parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
|
||||||
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
|
|
||||||
|
values: dict[str, Any] = {}
|
||||||
|
if title := channel_values.get("title"):
|
||||||
|
values["title"] = title
|
||||||
|
if is_first and (messages := channel_values.get("messages")):
|
||||||
|
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||||
|
is_first = False
|
||||||
|
|
||||||
|
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||||
|
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
|
|
||||||
|
entries.append(
|
||||||
|
HistoryEntry(
|
||||||
|
checkpoint_id=checkpoint_id,
|
||||||
|
parent_checkpoint_id=parent_id,
|
||||||
|
metadata=metadata,
|
||||||
|
values=values,
|
||||||
|
created_at=str(metadata.get("created_at", "")),
|
||||||
|
next=next_nodes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
||||||
|
|
||||||
|
if not entries and await _thread_or_run_exists(
|
||||||
|
request=request,
|
||||||
|
thread_id=thread_id,
|
||||||
|
thread_meta_storage=thread_meta_storage,
|
||||||
|
run_repo=run_repo,
|
||||||
|
):
|
||||||
|
return []
|
||||||
|
|
||||||
|
return entries
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Memory API router for retrieving and managing global memory data."""
|
"""Memory API router for retrieving and managing global memory data."""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
from deerflow.agents.memory.updater import (
|
from deerflow.agents.memory.updater import (
|
||||||
clear_memory_data,
|
clear_memory_data,
|
||||||
create_memory_fact,
|
create_memory_fact,
|
||||||
@@ -13,6 +14,7 @@ from deerflow.agents.memory.updater import (
|
|||||||
update_memory_fact,
|
update_memory_fact,
|
||||||
)
|
)
|
||||||
from deerflow.config.memory_config import get_memory_config
|
from deerflow.config.memory_config import get_memory_config
|
||||||
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["memory"])
|
router = APIRouter(prefix="/api", tags=["memory"])
|
||||||
|
|
||||||
@@ -113,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
|
|||||||
summary="Get Memory Data",
|
summary="Get Memory Data",
|
||||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||||
)
|
)
|
||||||
async def get_memory() -> MemoryResponse:
|
async def get_memory(request: Request) -> MemoryResponse:
|
||||||
"""Get the current global memory data.
|
"""Get the current global memory data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -147,8 +149,9 @@ async def get_memory() -> MemoryResponse:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
memory_data = get_memory_data()
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -158,7 +161,7 @@ async def get_memory() -> MemoryResponse:
|
|||||||
summary="Reload Memory Data",
|
summary="Reload Memory Data",
|
||||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||||
)
|
)
|
||||||
async def reload_memory() -> MemoryResponse:
|
async def reload_memory(request: Request) -> MemoryResponse:
|
||||||
"""Reload memory data from file.
|
"""Reload memory data from file.
|
||||||
|
|
||||||
This forces a reload of the memory data from the storage file,
|
This forces a reload of the memory data from the storage file,
|
||||||
@@ -167,8 +170,9 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
The reloaded memory data.
|
The reloaded memory data.
|
||||||
"""
|
"""
|
||||||
memory_data = reload_memory_data()
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = reload_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -178,14 +182,15 @@ async def reload_memory() -> MemoryResponse:
|
|||||||
summary="Clear All Memory Data",
|
summary="Clear All Memory Data",
|
||||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||||
)
|
)
|
||||||
async def clear_memory() -> MemoryResponse:
|
async def clear_memory(request: Request) -> MemoryResponse:
|
||||||
"""Clear all persisted memory data."""
|
"""Clear all persisted memory data."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = clear_memory_data()
|
try:
|
||||||
except OSError as exc:
|
memory_data = clear_memory_data(user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -195,20 +200,22 @@ async def clear_memory() -> MemoryResponse:
|
|||||||
summary="Create Memory Fact",
|
summary="Create Memory Fact",
|
||||||
description="Create a single saved memory fact manually.",
|
description="Create a single saved memory fact manually.",
|
||||||
)
|
)
|
||||||
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
|
async def create_memory_fact_endpoint(request: Request, payload: FactCreateRequest) -> MemoryResponse:
|
||||||
"""Create a single fact manually."""
|
"""Create a single fact manually."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = create_memory_fact(
|
try:
|
||||||
content=request.content,
|
memory_data = create_memory_fact(
|
||||||
category=request.category,
|
content=payload.content,
|
||||||
confidence=request.confidence,
|
category=payload.category,
|
||||||
)
|
confidence=payload.confidence,
|
||||||
except ValueError as exc:
|
user_id=get_effective_user_id(),
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
)
|
||||||
except OSError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -218,16 +225,17 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
|
|||||||
summary="Delete Memory Fact",
|
summary="Delete Memory Fact",
|
||||||
description="Delete a single saved memory fact by its fact id.",
|
description="Delete a single saved memory fact by its fact id.",
|
||||||
)
|
)
|
||||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryResponse:
|
||||||
"""Delete a single fact from memory by fact id."""
|
"""Delete a single fact from memory by fact id."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = delete_memory_fact(fact_id)
|
try:
|
||||||
except KeyError as exc:
|
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
except KeyError as exc:
|
||||||
except OSError as exc:
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
@@ -237,23 +245,25 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
|||||||
summary="Patch Memory Fact",
|
summary="Patch Memory Fact",
|
||||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||||
)
|
)
|
||||||
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
|
async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: FactPatchRequest) -> MemoryResponse:
|
||||||
"""Partially update a single fact manually."""
|
"""Partially update a single fact manually."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = update_memory_fact(
|
try:
|
||||||
fact_id=fact_id,
|
memory_data = update_memory_fact(
|
||||||
content=request.content,
|
fact_id=fact_id,
|
||||||
category=request.category,
|
content=payload.content,
|
||||||
confidence=request.confidence,
|
category=payload.category,
|
||||||
)
|
confidence=payload.confidence,
|
||||||
except ValueError as exc:
|
user_id=get_effective_user_id(),
|
||||||
raise _map_memory_fact_value_error(exc) from exc
|
)
|
||||||
except KeyError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
raise _map_memory_fact_value_error(exc) from exc
|
||||||
except OSError as exc:
|
except KeyError as exc:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||||
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -263,10 +273,11 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
|
|||||||
summary="Export Memory Data",
|
summary="Export Memory Data",
|
||||||
description="Export the current global memory data as JSON for backup or transfer.",
|
description="Export the current global memory data as JSON for backup or transfer.",
|
||||||
)
|
)
|
||||||
async def export_memory() -> MemoryResponse:
|
async def export_memory(request: Request) -> MemoryResponse:
|
||||||
"""Export the current memory data."""
|
"""Export the current memory data."""
|
||||||
memory_data = get_memory_data()
|
with bind_request_actor_context(request):
|
||||||
return MemoryResponse(**memory_data)
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -276,14 +287,15 @@ async def export_memory() -> MemoryResponse:
|
|||||||
summary="Import Memory Data",
|
summary="Import Memory Data",
|
||||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||||
)
|
)
|
||||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResponse:
|
||||||
"""Import and persist memory data."""
|
"""Import and persist memory data."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
memory_data = import_memory_data(request.model_dump())
|
try:
|
||||||
except OSError as exc:
|
memory_data = import_memory_data(payload.model_dump(), user_id=get_effective_user_id())
|
||||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
except OSError as exc:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||||
|
|
||||||
return MemoryResponse(**memory_data)
|
return MemoryResponse(**memory_data)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -330,24 +342,25 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
summary="Get Memory Status",
|
summary="Get Memory Status",
|
||||||
description="Retrieve both memory configuration and current data in a single request.",
|
description="Retrieve both memory configuration and current data in a single request.",
|
||||||
)
|
)
|
||||||
async def get_memory_status() -> MemoryStatusResponse:
|
async def get_memory_status(request: Request) -> MemoryStatusResponse:
|
||||||
"""Get the memory system status including configuration and data.
|
"""Get the memory system status including configuration and data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Combined memory configuration and current data.
|
Combined memory configuration and current data.
|
||||||
"""
|
"""
|
||||||
config = get_memory_config()
|
with bind_request_actor_context(request):
|
||||||
memory_data = get_memory_data()
|
config = get_memory_config()
|
||||||
|
memory_data = get_memory_data(user_id=get_effective_user_id())
|
||||||
|
|
||||||
return MemoryStatusResponse(
|
return MemoryStatusResponse(
|
||||||
config=MemoryConfigResponse(
|
config=MemoryConfigResponse(
|
||||||
enabled=config.enabled,
|
enabled=config.enabled,
|
||||||
storage_path=config.storage_path,
|
storage_path=config.storage_path,
|
||||||
debounce_seconds=config.debounce_seconds,
|
debounce_seconds=config.debounce_seconds,
|
||||||
max_facts=config.max_facts,
|
max_facts=config.max_facts,
|
||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
),
|
),
|
||||||
data=MemoryResponse(**memory_data),
|
data=MemoryResponse(**memory_data),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
|
|
||||||
|
|
||||||
These endpoints auto-create a temporary thread when no ``thread_id`` is
|
|
||||||
supplied in the request body. When a ``thread_id`` **is** provided, it
|
|
||||||
is reused so that conversation history is preserved across calls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
|
||||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
|
||||||
from app.gateway.services import sse_consumer, start_run
|
|
||||||
from deerflow.runtime import serialize_channel_values
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_thread_id(body: RunCreateRequest) -> str:
|
|
||||||
"""Return the thread_id from the request body, or generate a new one."""
|
|
||||||
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
|
|
||||||
if thread_id:
|
|
||||||
return str(thread_id)
|
|
||||||
return str(uuid.uuid4())
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stream")
|
|
||||||
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
|
||||||
"""Create a run and stream events via SSE.
|
|
||||||
|
|
||||||
If ``config.configurable.thread_id`` is provided, the run is created
|
|
||||||
on the given thread so that conversation history is preserved.
|
|
||||||
Otherwise a new temporary thread is created.
|
|
||||||
"""
|
|
||||||
thread_id = _resolve_thread_id(body)
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/wait", response_model=dict)
|
|
||||||
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
|
||||||
"""Create a run and block until completion.
|
|
||||||
|
|
||||||
If ``config.configurable.thread_id`` is provided, the run is created
|
|
||||||
on the given thread so that conversation history is preserved.
|
|
||||||
Otherwise a new temporary thread is created.
|
|
||||||
"""
|
|
||||||
thread_id = _resolve_thread_id(body)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
if record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
if checkpoint_tuple is not None:
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
return serialize_channel_values(channel_values)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, Request
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -98,12 +98,12 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
|||||||
summary="Generate Follow-up Questions",
|
summary="Generate Follow-up Questions",
|
||||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||||
)
|
)
|
||||||
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
|
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
|
||||||
if not request.messages:
|
if not body.messages:
|
||||||
return SuggestionsResponse(suggestions=[])
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
n = request.n
|
n = body.n
|
||||||
conversation = _format_conversation(request.messages)
|
conversation = _format_conversation(body.messages)
|
||||||
if not conversation:
|
if not conversation:
|
||||||
return SuggestionsResponse(suggestions=[])
|
return SuggestionsResponse(suggestions=[])
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
|
|||||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
model = create_chat_model(name=body.model_name, thinking_enabled=False)
|
||||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||||
raw = _extract_response_text(response.content)
|
raw = _extract_response_text(response.content)
|
||||||
suggestions = _parse_json_string_list(raw) or []
|
suggestions = _parse_json_string_list(raw) or []
|
||||||
|
|||||||
@@ -1,267 +0,0 @@
|
|||||||
"""Runs endpoints — create, stream, wait, cancel.
|
|
||||||
|
|
||||||
Implements the LangGraph Platform runs API on top of
|
|
||||||
:class:`deerflow.agents.runs.RunManager` and
|
|
||||||
:class:`deerflow.agents.stream_bridge.StreamBridge`.
|
|
||||||
|
|
||||||
SSE format is aligned with the LangGraph Platform protocol so that
|
|
||||||
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
|
|
||||||
works without modification.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
|
||||||
from fastapi.responses import Response, StreamingResponse
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
|
||||||
from app.gateway.services import sse_consumer, start_run
|
|
||||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Request / response models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class RunCreateRequest(BaseModel):
|
|
||||||
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
|
||||||
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
|
||||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
|
||||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
|
||||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
|
||||||
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
|
||||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
|
||||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
|
||||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
|
||||||
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
|
||||||
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
|
||||||
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
|
||||||
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
|
||||||
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
|
||||||
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
|
||||||
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
|
||||||
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
|
||||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
|
||||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
|
||||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
|
||||||
|
|
||||||
|
|
||||||
class RunResponse(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
thread_id: str
|
|
||||||
assistant_id: str | None = None
|
|
||||||
status: str
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
kwargs: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
multitask_strategy: str = "reject"
|
|
||||||
created_at: str = ""
|
|
||||||
updated_at: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
|
||||||
return RunResponse(
|
|
||||||
run_id=record.run_id,
|
|
||||||
thread_id=record.thread_id,
|
|
||||||
assistant_id=record.assistant_id,
|
|
||||||
status=record.status.value,
|
|
||||||
metadata=record.metadata,
|
|
||||||
kwargs=record.kwargs,
|
|
||||||
multitask_strategy=record.multitask_strategy,
|
|
||||||
created_at=record.created_at,
|
|
||||||
updated_at=record.updated_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
|
||||||
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
|
|
||||||
"""Create a background run (returns immediately)."""
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
return _record_to_response(record)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/stream")
|
|
||||||
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
|
|
||||||
"""Create a run and stream events via SSE.
|
|
||||||
|
|
||||||
The response includes a ``Content-Location`` header with the run's
|
|
||||||
resource URL, matching the LangGraph Platform protocol. The
|
|
||||||
``useStream`` React hook uses this to extract run metadata.
|
|
||||||
"""
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
# LangGraph Platform includes run metadata in this header.
|
|
||||||
# The SDK uses a greedy regex to extract the run id from this path,
|
|
||||||
# so it must point at the canonical run resource without extra suffixes.
|
|
||||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/wait", response_model=dict)
|
|
||||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
|
||||||
"""Create a run and block until it completes, returning the final state."""
|
|
||||||
record = await start_run(body, thread_id, request)
|
|
||||||
|
|
||||||
if record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
config = {"configurable": {"thread_id": thread_id}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
if checkpoint_tuple is not None:
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
return serialize_channel_values(channel_values)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
|
||||||
|
|
||||||
return {"status": record.status.value, "error": record.error}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
|
||||||
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
|
||||||
"""List all runs for a thread."""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
records = await run_mgr.list_by_thread(thread_id)
|
|
||||||
return [_record_to_response(r) for r in records]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
|
||||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
|
||||||
"""Get details of a specific run."""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
return _record_to_response(record)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
|
||||||
async def cancel_run(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
wait: bool = Query(default=False, description="Block until run completes after cancel"),
|
|
||||||
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
|
|
||||||
) -> Response:
|
|
||||||
"""Cancel a running or pending run.
|
|
||||||
|
|
||||||
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
|
|
||||||
- action=rollback: Stop execution, revert to pre-run checkpoint state
|
|
||||||
- wait=true: Block until the run fully stops, return 204
|
|
||||||
- wait=false: Return immediately with 202
|
|
||||||
"""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
|
|
||||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
|
||||||
if not cancelled:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=409,
|
|
||||||
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
|
|
||||||
)
|
|
||||||
|
|
||||||
if wait and record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
return Response(status_code=204)
|
|
||||||
|
|
||||||
return Response(status_code=202)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/runs/{run_id}/join")
|
|
||||||
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
|
|
||||||
"""Join an existing run's SSE stream."""
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
|
||||||
async def stream_existing_run(
|
|
||||||
thread_id: str,
|
|
||||||
run_id: str,
|
|
||||||
request: Request,
|
|
||||||
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
|
|
||||||
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
|
|
||||||
):
|
|
||||||
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
|
|
||||||
|
|
||||||
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
|
|
||||||
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
|
|
||||||
is present the run is cancelled first; the response then streams any
|
|
||||||
remaining buffered events so the client observes a clean shutdown.
|
|
||||||
"""
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
record = run_mgr.get(run_id)
|
|
||||||
if record is None or record.thread_id != thread_id:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
|
||||||
|
|
||||||
# Cancel if an action was requested (stop-button / interrupt flow)
|
|
||||||
if action is not None:
|
|
||||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
|
||||||
if cancelled and wait and record.task is not None:
|
|
||||||
try:
|
|
||||||
await record.task
|
|
||||||
except (asyncio.CancelledError, Exception):
|
|
||||||
pass
|
|
||||||
return Response(status_code=204)
|
|
||||||
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_consumer(bridge, record, request, run_mgr),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@@ -1,682 +0,0 @@
|
|||||||
"""Thread CRUD, state, and history endpoints.
|
|
||||||
|
|
||||||
Combines the existing thread-local filesystem cleanup with LangGraph
|
|
||||||
Platform-compatible thread management backed by the checkpointer.
|
|
||||||
|
|
||||||
Channel values returned in state responses are serialized through
|
|
||||||
:func:`deerflow.runtime.serialization.serialize_channel_values` to
|
|
||||||
ensure LangChain message objects are converted to JSON-safe dicts
|
|
||||||
matching the LangGraph Platform wire format expected by the
|
|
||||||
``useStream`` React hook.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from app.gateway.deps import get_checkpointer, get_store
|
|
||||||
from deerflow.config.paths import Paths, get_paths
|
|
||||||
from deerflow.runtime import serialize_channel_values
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Store namespace
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
|
||||||
"""Namespace used by the Store for thread metadata records."""
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Response / request models
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadDeleteResponse(BaseModel):
|
|
||||||
"""Response model for thread cleanup."""
|
|
||||||
|
|
||||||
success: bool
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadResponse(BaseModel):
|
|
||||||
"""Response model for a single thread."""
|
|
||||||
|
|
||||||
thread_id: str = Field(description="Unique thread identifier")
|
|
||||||
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
|
|
||||||
created_at: str = Field(default="", description="ISO timestamp")
|
|
||||||
updated_at: str = Field(default="", description="ISO timestamp")
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
|
||||||
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
|
|
||||||
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadCreateRequest(BaseModel):
|
|
||||||
"""Request body for creating a thread."""
|
|
||||||
|
|
||||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadSearchRequest(BaseModel):
|
|
||||||
"""Request body for searching threads."""
|
|
||||||
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
|
||||||
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
|
||||||
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
|
||||||
status: str | None = Field(default=None, description="Filter by thread status")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadStateResponse(BaseModel):
|
|
||||||
"""Response model for thread state."""
|
|
||||||
|
|
||||||
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
|
||||||
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
|
||||||
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
|
||||||
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
|
||||||
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
|
||||||
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
|
||||||
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadPatchRequest(BaseModel):
|
|
||||||
"""Request body for patching thread metadata."""
|
|
||||||
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadStateUpdateRequest(BaseModel):
|
|
||||||
"""Request body for updating thread state (human-in-the-loop resume)."""
|
|
||||||
|
|
||||||
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
|
||||||
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
|
||||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
|
||||||
as_node: str | None = Field(default=None, description="Node identity for the update")
|
|
||||||
|
|
||||||
|
|
||||||
class HistoryEntry(BaseModel):
|
|
||||||
"""Single checkpoint history entry."""
|
|
||||||
|
|
||||||
checkpoint_id: str
|
|
||||||
parent_checkpoint_id: str | None = None
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
values: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
created_at: str | None = None
|
|
||||||
next: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadHistoryRequest(BaseModel):
|
|
||||||
"""Request body for checkpoint history."""
|
|
||||||
|
|
||||||
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
|
||||||
before: str | None = Field(default=None, description="Cursor for pagination")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
|
||||||
"""Delete local persisted filesystem data for a thread."""
|
|
||||||
path_manager = paths or get_paths()
|
|
||||||
try:
|
|
||||||
path_manager.delete_thread_dir(thread_id)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
|
||||||
except FileNotFoundError:
|
|
||||||
# Not critical — thread data may not exist on disk
|
|
||||||
logger.debug("No local thread data to delete for %s", thread_id)
|
|
||||||
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.exception("Failed to delete thread data for %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
|
||||||
|
|
||||||
logger.info("Deleted local thread data for %s", thread_id)
|
|
||||||
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _store_get(store, thread_id: str) -> dict | None:
|
|
||||||
"""Fetch a thread record from the Store; returns ``None`` if absent."""
|
|
||||||
item = await store.aget(THREADS_NS, thread_id)
|
|
||||||
return item.value if item is not None else None
|
|
||||||
|
|
||||||
|
|
||||||
async def _store_put(store, record: dict) -> None:
|
|
||||||
"""Write a thread record to the Store."""
|
|
||||||
await store.aput(THREADS_NS, record["thread_id"], record)
|
|
||||||
|
|
||||||
|
|
||||||
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
|
|
||||||
"""Create or refresh a thread record in the Store.
|
|
||||||
|
|
||||||
On creation the record is written with ``status="idle"``. On update only
|
|
||||||
``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
|
|
||||||
that existing fields are preserved.
|
|
||||||
|
|
||||||
``values`` carries the agent-state snapshot exposed to the frontend
|
|
||||||
(currently just ``{"title": "..."}``).
|
|
||||||
"""
|
|
||||||
now = time.time()
|
|
||||||
existing = await _store_get(store, thread_id)
|
|
||||||
if existing is None:
|
|
||||||
await _store_put(
|
|
||||||
store,
|
|
||||||
{
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"status": "idle",
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
"metadata": metadata or {},
|
|
||||||
"values": values or {},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
val = dict(existing)
|
|
||||||
val["updated_at"] = now
|
|
||||||
if metadata:
|
|
||||||
val.setdefault("metadata", {}).update(metadata)
|
|
||||||
if values:
|
|
||||||
val.setdefault("values", {}).update(values)
|
|
||||||
await _store_put(store, val)
|
|
||||||
|
|
||||||
|
|
||||||
def _derive_thread_status(checkpoint_tuple) -> str:
|
|
||||||
"""Derive thread status from checkpoint metadata."""
|
|
||||||
if checkpoint_tuple is None:
|
|
||||||
return "idle"
|
|
||||||
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
|
|
||||||
|
|
||||||
# Check for error in pending writes
|
|
||||||
for pw in pending_writes:
|
|
||||||
if len(pw) >= 2 and pw[1] == "__error__":
|
|
||||||
return "error"
|
|
||||||
|
|
||||||
# Check for pending next tasks (indicates interrupt)
|
|
||||||
tasks = getattr(checkpoint_tuple, "tasks", None)
|
|
||||||
if tasks:
|
|
||||||
return "interrupted"
|
|
||||||
|
|
||||||
return "idle"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
|
||||||
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
|
|
||||||
"""Delete local persisted filesystem data for a thread.
|
|
||||||
|
|
||||||
Cleans DeerFlow-managed thread directories, removes checkpoint data,
|
|
||||||
and removes the thread record from the Store.
|
|
||||||
"""
|
|
||||||
# Clean local filesystem
|
|
||||||
response = _delete_thread_data(thread_id)
|
|
||||||
|
|
||||||
# Remove from Store (best-effort)
|
|
||||||
store = get_store(request)
|
|
||||||
if store is not None:
|
|
||||||
try:
|
|
||||||
await store.adelete(THREADS_NS, thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
|
|
||||||
|
|
||||||
# Remove checkpoints (best-effort)
|
|
||||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
|
||||||
if checkpointer is not None:
|
|
||||||
try:
|
|
||||||
if hasattr(checkpointer, "adelete_thread"):
|
|
||||||
await checkpointer.adelete_thread(thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ThreadResponse)
|
|
||||||
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
|
|
||||||
"""Create a new thread.
|
|
||||||
|
|
||||||
The thread record is written to the Store (for fast listing) and an
|
|
||||||
empty checkpoint is written to the checkpointer (for state reads).
|
|
||||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
|
||||||
"""
|
|
||||||
store = get_store(request)
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
# Idempotency: return existing record from Store when already present
|
|
||||||
if store is not None:
|
|
||||||
existing_record = await _store_get(store, thread_id)
|
|
||||||
if existing_record is not None:
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=existing_record.get("status", "idle"),
|
|
||||||
created_at=str(existing_record.get("created_at", "")),
|
|
||||||
updated_at=str(existing_record.get("updated_at", "")),
|
|
||||||
metadata=existing_record.get("metadata", {}),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write thread record to Store
|
|
||||||
if store is not None:
|
|
||||||
try:
|
|
||||||
await _store_put(
|
|
||||||
store,
|
|
||||||
{
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"status": "idle",
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
"metadata": body.metadata,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to write thread %s to store", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
|
||||||
|
|
||||||
# Write an empty checkpoint so state endpoints work immediately
|
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
||||||
try:
|
|
||||||
from langgraph.checkpoint.base import empty_checkpoint
|
|
||||||
|
|
||||||
ckpt_metadata = {
|
|
||||||
"step": -1,
|
|
||||||
"source": "input",
|
|
||||||
"writes": None,
|
|
||||||
"parents": {},
|
|
||||||
**body.metadata,
|
|
||||||
"created_at": now,
|
|
||||||
}
|
|
||||||
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to create checkpoint for thread %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
|
||||||
|
|
||||||
logger.info("Thread created: %s", thread_id)
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status="idle",
|
|
||||||
created_at=str(now),
|
|
||||||
updated_at=str(now),
|
|
||||||
metadata=body.metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search", response_model=list[ThreadResponse])
|
|
||||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
|
||||||
"""Search and list threads.
|
|
||||||
|
|
||||||
Two-phase approach:
|
|
||||||
|
|
||||||
**Phase 1 — Store (fast path, O(threads))**: returns threads that were
|
|
||||||
created or run through this Gateway. Store records are tiny metadata
|
|
||||||
dicts so fetching all of them at once is cheap.
|
|
||||||
|
|
||||||
**Phase 2 — Checkpointer supplement (lazy migration)**: threads that
|
|
||||||
were created directly by LangGraph Server (and therefore absent from the
|
|
||||||
Store) are discovered here by iterating the shared checkpointer. Any
|
|
||||||
newly found thread is immediately written to the Store so that the next
|
|
||||||
search skips Phase 2 for that thread — the Store converges to a full
|
|
||||||
index over time without a one-shot migration job.
|
|
||||||
"""
|
|
||||||
store = get_store(request)
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Phase 1: Store
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
merged: dict[str, ThreadResponse] = {}
|
|
||||||
|
|
||||||
if store is not None:
|
|
||||||
try:
|
|
||||||
items = await store.asearch(THREADS_NS, limit=10_000)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
|
|
||||||
items = []
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
val = item.value
|
|
||||||
merged[val["thread_id"]] = ThreadResponse(
|
|
||||||
thread_id=val["thread_id"],
|
|
||||||
status=val.get("status", "idle"),
|
|
||||||
created_at=str(val.get("created_at", "")),
|
|
||||||
updated_at=str(val.get("updated_at", "")),
|
|
||||||
metadata=val.get("metadata", {}),
|
|
||||||
values=val.get("values", {}),
|
|
||||||
)
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Phase 2: Checkpointer supplement
|
|
||||||
# Discovers threads not yet in the Store (e.g. created by LangGraph
|
|
||||||
# Server) and lazily migrates them so future searches skip this phase.
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
try:
|
|
||||||
async for checkpoint_tuple in checkpointer.alist(None):
|
|
||||||
cfg = getattr(checkpoint_tuple, "config", {})
|
|
||||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
|
||||||
if not thread_id or thread_id in merged:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
|
|
||||||
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
|
|
||||||
continue
|
|
||||||
|
|
||||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
# Strip LangGraph internal keys from the user-visible metadata dict
|
|
||||||
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
|
|
||||||
|
|
||||||
# Extract state values (title) from the checkpoint's channel_values
|
|
||||||
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
channel_values = checkpoint_data.get("channel_values", {})
|
|
||||||
ckpt_values = {}
|
|
||||||
if title := channel_values.get("title"):
|
|
||||||
ckpt_values["title"] = title
|
|
||||||
|
|
||||||
thread_resp = ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=_derive_thread_status(checkpoint_tuple),
|
|
||||||
created_at=str(ckpt_meta.get("created_at", "")),
|
|
||||||
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
|
|
||||||
metadata=user_meta,
|
|
||||||
values=ckpt_values,
|
|
||||||
)
|
|
||||||
merged[thread_id] = thread_resp
|
|
||||||
|
|
||||||
# Lazy migration — write to Store so the next search finds it there
|
|
||||||
if store is not None:
|
|
||||||
try:
|
|
||||||
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Checkpointer scan failed during thread search")
|
|
||||||
# Don't raise — return whatever was collected from Store + partial scan
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
# Phase 3: Filter → sort → paginate
|
|
||||||
# -----------------------------------------------------------------------
|
|
||||||
results = list(merged.values())
|
|
||||||
|
|
||||||
if body.metadata:
|
|
||||||
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
|
|
||||||
|
|
||||||
if body.status:
|
|
||||||
results = [r for r in results if r.status == body.status]
|
|
||||||
|
|
||||||
results.sort(key=lambda r: r.updated_at, reverse=True)
|
|
||||||
return results[body.offset : body.offset + body.limit]
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
|
||||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
|
||||||
"""Merge metadata into a thread record."""
|
|
||||||
store = get_store(request)
|
|
||||||
if store is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Store not available")
|
|
||||||
|
|
||||||
record = await _store_get(store, thread_id)
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
updated = dict(record)
|
|
||||||
updated.setdefault("metadata", {}).update(body.metadata)
|
|
||||||
updated["updated_at"] = now
|
|
||||||
|
|
||||||
try:
|
|
||||||
await _store_put(store, updated)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to patch thread %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
|
||||||
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=updated.get("status", "idle"),
|
|
||||||
created_at=str(updated.get("created_at", "")),
|
|
||||||
updated_at=str(now),
|
|
||||||
metadata=updated.get("metadata", {}),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}", response_model=ThreadResponse)
|
|
||||||
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|
||||||
"""Get thread info.
|
|
||||||
|
|
||||||
Reads metadata from the Store and derives the accurate execution
|
|
||||||
status from the checkpointer. Falls back to the checkpointer alone
|
|
||||||
for threads that pre-date Store adoption (backward compat).
|
|
||||||
"""
|
|
||||||
store = get_store(request)
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
|
|
||||||
record: dict | None = None
|
|
||||||
if store is not None:
|
|
||||||
record = await _store_get(store, thread_id)
|
|
||||||
|
|
||||||
# Derive accurate status from the checkpointer
|
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get checkpoint for thread %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread")
|
|
||||||
|
|
||||||
if record is None and checkpoint_tuple is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
# If the thread exists in the checkpointer but not the store (e.g. legacy
|
|
||||||
# data), synthesize a minimal store record from the checkpoint metadata.
|
|
||||||
if record is None and checkpoint_tuple is not None:
|
|
||||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
record = {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"status": "idle",
|
|
||||||
"created_at": ckpt_meta.get("created_at", ""),
|
|
||||||
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
|
|
||||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
|
||||||
}
|
|
||||||
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
|
|
||||||
return ThreadResponse(
|
|
||||||
thread_id=thread_id,
|
|
||||||
status=status,
|
|
||||||
created_at=str(record.get("created_at", "")),
|
|
||||||
updated_at=str(record.get("updated_at", "")),
|
|
||||||
metadata=record.get("metadata", {}),
|
|
||||||
values=serialize_channel_values(channel_values),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
|
||||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
|
||||||
"""Get the latest state snapshot for a thread.
|
|
||||||
|
|
||||||
Channel values are serialized to ensure LangChain message objects
|
|
||||||
are converted to JSON-safe dicts.
|
|
||||||
"""
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
|
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get state for thread %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
|
||||||
|
|
||||||
if checkpoint_tuple is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
checkpoint_id = None
|
|
||||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
|
||||||
if ckpt_config:
|
|
||||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
|
|
||||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
|
||||||
parent_checkpoint_id = None
|
|
||||||
if parent_config:
|
|
||||||
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
|
||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
|
||||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
|
||||||
|
|
||||||
return ThreadStateResponse(
|
|
||||||
values=serialize_channel_values(channel_values),
|
|
||||||
next=next_tasks,
|
|
||||||
metadata=metadata,
|
|
||||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
|
||||||
checkpoint_id=checkpoint_id,
|
|
||||||
parent_checkpoint_id=parent_checkpoint_id,
|
|
||||||
created_at=str(metadata.get("created_at", "")),
|
|
||||||
tasks=tasks,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
|
||||||
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
|
||||||
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
|
|
||||||
|
|
||||||
Writes a new checkpoint that merges *body.values* into the latest
|
|
||||||
channel values, then syncs any updated ``title`` field back to the Store
|
|
||||||
so that ``/threads/search`` reflects the change immediately.
|
|
||||||
"""
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
store = get_store(request)
|
|
||||||
|
|
||||||
# checkpoint_ns must be present in the config for aput — default to ""
|
|
||||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
|
||||||
# fetches the latest checkpoint for the thread.
|
|
||||||
read_config: dict[str, Any] = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"checkpoint_ns": "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if body.checkpoint_id:
|
|
||||||
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get state for thread %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
|
||||||
|
|
||||||
if checkpoint_tuple is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
# Work on mutable copies so we don't accidentally mutate cached objects.
|
|
||||||
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
|
||||||
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
|
||||||
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
|
||||||
|
|
||||||
if body.values:
|
|
||||||
channel_values.update(body.values)
|
|
||||||
|
|
||||||
checkpoint["channel_values"] = channel_values
|
|
||||||
metadata["updated_at"] = time.time()
|
|
||||||
|
|
||||||
if body.as_node:
|
|
||||||
metadata["source"] = "update"
|
|
||||||
metadata["step"] = metadata.get("step", 0) + 1
|
|
||||||
metadata["writes"] = {body.as_node: body.values}
|
|
||||||
|
|
||||||
# aput requires checkpoint_ns in the config — use the same config used for the
|
|
||||||
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
|
|
||||||
# so that aput generates a fresh checkpoint ID for the new snapshot.
|
|
||||||
write_config: dict[str, Any] = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": thread_id,
|
|
||||||
"checkpoint_ns": "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to update state for thread %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
|
||||||
|
|
||||||
new_checkpoint_id: str | None = None
|
|
||||||
if isinstance(new_config, dict):
|
|
||||||
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
# Sync title changes to the Store so /threads/search reflects them immediately.
|
|
||||||
if store is not None and body.values and "title" in body.values:
|
|
||||||
try:
|
|
||||||
await _store_upsert(store, thread_id, values={"title": body.values["title"]})
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
|
|
||||||
|
|
||||||
return ThreadStateResponse(
|
|
||||||
values=serialize_channel_values(channel_values),
|
|
||||||
next=[],
|
|
||||||
metadata=metadata,
|
|
||||||
checkpoint_id=new_checkpoint_id,
|
|
||||||
created_at=str(metadata.get("created_at", "")),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
|
||||||
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
|
||||||
"""Get checkpoint history for a thread."""
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
|
|
||||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
|
|
||||||
if body.before:
|
|
||||||
config["configurable"]["checkpoint_id"] = body.before
|
|
||||||
|
|
||||||
entries: list[HistoryEntry] = []
|
|
||||||
try:
|
|
||||||
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
|
||||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
|
||||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
|
||||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
|
||||||
|
|
||||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
|
||||||
parent_id = None
|
|
||||||
if parent_config:
|
|
||||||
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
|
||||||
|
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
|
||||||
|
|
||||||
# Derive next tasks
|
|
||||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
|
||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
|
||||||
|
|
||||||
entries.append(
|
|
||||||
HistoryEntry(
|
|
||||||
checkpoint_id=checkpoint_id,
|
|
||||||
parent_checkpoint_id=parent_id,
|
|
||||||
metadata=metadata,
|
|
||||||
values=serialize_channel_values(channel_values),
|
|
||||||
created_at=str(metadata.get("created_at", "")),
|
|
||||||
next=next_tasks,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get history for thread %s", thread_id)
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
|
||||||
|
|
||||||
return entries
|
|
||||||
@@ -4,11 +4,13 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import stat
|
import stat
|
||||||
|
|
||||||
from fastapi import APIRouter, File, HTTPException, UploadFile
|
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
from app.plugins.auth.security.actor_context import bind_request_actor_context
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
from deerflow.config.paths import get_paths
|
||||||
|
from deerflow.runtime.actor_context import get_effective_user_id
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
PathTraversalError,
|
PathTraversalError,
|
||||||
delete_file_safe,
|
delete_file_safe,
|
||||||
@@ -56,74 +58,76 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
|||||||
@router.post("", response_model=UploadResponse)
|
@router.post("", response_model=UploadResponse)
|
||||||
async def upload_files(
|
async def upload_files(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
|
request: Request,
|
||||||
files: list[UploadFile] = File(...),
|
files: list[UploadFile] = File(...),
|
||||||
) -> UploadResponse:
|
) -> UploadResponse:
|
||||||
"""Upload multiple files to a thread's uploads directory."""
|
"""Upload multiple files to a thread's uploads directory."""
|
||||||
if not files:
|
if not files:
|
||||||
raise HTTPException(status_code=400, detail="No files provided")
|
raise HTTPException(status_code=400, detail="No files provided")
|
||||||
|
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
|
||||||
uploaded_files = []
|
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if not file.filename:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
safe_filename = normalize_filename(file.filename)
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
except ValueError:
|
except ValueError as e:
|
||||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
continue
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
|
uploaded_files = []
|
||||||
|
|
||||||
try:
|
sandbox_provider = get_sandbox_provider()
|
||||||
content = await file.read()
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
file_path = uploads_dir / safe_filename
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
file_path.write_bytes(content)
|
|
||||||
|
|
||||||
virtual_path = upload_virtual_path(safe_filename)
|
for file in files:
|
||||||
|
if not file.filename:
|
||||||
|
continue
|
||||||
|
|
||||||
if sandbox_id != "local":
|
try:
|
||||||
_make_file_sandbox_writable(file_path)
|
safe_filename = normalize_filename(file.filename)
|
||||||
sandbox.update_file(virtual_path, content)
|
except ValueError:
|
||||||
|
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||||
|
continue
|
||||||
|
|
||||||
file_info = {
|
try:
|
||||||
"filename": safe_filename,
|
content = await file.read()
|
||||||
"size": str(len(content)),
|
file_path = uploads_dir / safe_filename
|
||||||
"path": str(sandbox_uploads / safe_filename),
|
file_path.write_bytes(content)
|
||||||
"virtual_path": virtual_path,
|
|
||||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
virtual_path = upload_virtual_path(safe_filename)
|
||||||
|
|
||||||
file_ext = file_path.suffix.lower()
|
if sandbox_id != "local":
|
||||||
if file_ext in CONVERTIBLE_EXTENSIONS:
|
_make_file_sandbox_writable(file_path)
|
||||||
md_path = await convert_file_to_markdown(file_path)
|
sandbox.update_file(virtual_path, content)
|
||||||
if md_path:
|
|
||||||
md_virtual_path = upload_virtual_path(md_path.name)
|
|
||||||
|
|
||||||
if sandbox_id != "local":
|
file_info = {
|
||||||
_make_file_sandbox_writable(md_path)
|
"filename": safe_filename,
|
||||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
"size": str(len(content)),
|
||||||
|
"path": str(sandbox_uploads / safe_filename),
|
||||||
|
"virtual_path": virtual_path,
|
||||||
|
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||||
|
}
|
||||||
|
|
||||||
file_info["markdown_file"] = md_path.name
|
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
||||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
|
||||||
file_info["markdown_virtual_path"] = md_virtual_path
|
|
||||||
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
|
||||||
|
|
||||||
uploaded_files.append(file_info)
|
file_ext = file_path.suffix.lower()
|
||||||
|
if file_ext in CONVERTIBLE_EXTENSIONS:
|
||||||
|
md_path = await convert_file_to_markdown(file_path)
|
||||||
|
if md_path:
|
||||||
|
md_virtual_path = upload_virtual_path(md_path.name)
|
||||||
|
|
||||||
except Exception as e:
|
if sandbox_id != "local":
|
||||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
_make_file_sandbox_writable(md_path)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||||
|
|
||||||
|
file_info["markdown_file"] = md_path.name
|
||||||
|
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||||
|
file_info["markdown_virtual_path"] = md_virtual_path
|
||||||
|
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
||||||
|
|
||||||
|
uploaded_files.append(file_info)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||||
|
|
||||||
return UploadResponse(
|
return UploadResponse(
|
||||||
success=True,
|
success=True,
|
||||||
@@ -133,25 +137,26 @@ async def upload_files(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=dict)
|
@router.get("/list", response_model=dict)
|
||||||
async def list_uploaded_files(thread_id: str) -> dict:
|
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||||
"""List all files in a thread's uploads directory."""
|
"""List all files in a thread's uploads directory."""
|
||||||
try:
|
with bind_request_actor_context(request):
|
||||||
uploads_dir = get_uploads_dir(thread_id)
|
try:
|
||||||
except ValueError as e:
|
uploads_dir = get_uploads_dir(thread_id)
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
except ValueError as e:
|
||||||
result = list_files_in_dir(uploads_dir)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
enrich_file_listing(result, thread_id)
|
result = list_files_in_dir(uploads_dir)
|
||||||
|
enrich_file_listing(result, thread_id)
|
||||||
|
|
||||||
# Gateway additionally includes the sandbox-relative path.
|
# Gateway additionally includes the sandbox-relative path.
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
for f in result["files"]:
|
for f in result["files"]:
|
||||||
f["path"] = str(sandbox_uploads / f["filename"])
|
f["path"] = str(sandbox_uploads / f["filename"])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{filename}")
|
@router.delete("/{filename}")
|
||||||
async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
|
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
|
||||||
"""Delete a file from a thread's uploads directory."""
|
"""Delete a file from a thread's uploads directory."""
|
||||||
try:
|
try:
|
||||||
uploads_dir = get_uploads_dir(thread_id)
|
uploads_dir = get_uploads_dir(thread_id)
|
||||||
|
|||||||
@@ -1,367 +0,0 @@
|
|||||||
"""Run lifecycle service layer.
|
|
||||||
|
|
||||||
Centralizes the business logic for creating runs, formatting SSE
|
|
||||||
frames, and consuming stream bridge events. Router modules
|
|
||||||
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
|
|
||||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
|
|
||||||
from deerflow.runtime import (
|
|
||||||
END_SENTINEL,
|
|
||||||
HEARTBEAT_SENTINEL,
|
|
||||||
ConflictError,
|
|
||||||
DisconnectMode,
|
|
||||||
RunManager,
|
|
||||||
RunRecord,
|
|
||||||
RunStatus,
|
|
||||||
StreamBridge,
|
|
||||||
UnsupportedStrategyError,
|
|
||||||
run_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# SSE formatting
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
|
|
||||||
"""Format a single SSE frame.
|
|
||||||
|
|
||||||
Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
|
|
||||||
This matches the LangGraph Platform wire format consumed by the
|
|
||||||
``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
|
|
||||||
"""
|
|
||||||
payload = json.dumps(data, default=str, ensure_ascii=False)
|
|
||||||
parts = [f"event: {event}", f"data: {payload}"]
|
|
||||||
if event_id:
|
|
||||||
parts.append(f"id: {event_id}")
|
|
||||||
parts.append("")
|
|
||||||
parts.append("")
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Input / config helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
|
|
||||||
"""Normalize the stream_mode parameter to a list.
|
|
||||||
|
|
||||||
Default matches what ``useStream`` expects: values + messages-tuple.
|
|
||||||
"""
|
|
||||||
if raw is None:
|
|
||||||
return ["values"]
|
|
||||||
if isinstance(raw, str):
|
|
||||||
return [raw]
|
|
||||||
return raw if raw else ["values"]
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
|
||||||
"""Convert LangGraph Platform input format to LangChain state dict."""
|
|
||||||
if raw_input is None:
|
|
||||||
return {}
|
|
||||||
messages = raw_input.get("messages")
|
|
||||||
if messages and isinstance(messages, list):
|
|
||||||
converted = []
|
|
||||||
for msg in messages:
|
|
||||||
if isinstance(msg, dict):
|
|
||||||
role = msg.get("role", msg.get("type", "user"))
|
|
||||||
content = msg.get("content", "")
|
|
||||||
if role in ("user", "human"):
|
|
||||||
converted.append(HumanMessage(content=content))
|
|
||||||
else:
|
|
||||||
# TODO: handle other message types (system, ai, tool)
|
|
||||||
converted.append(HumanMessage(content=content))
|
|
||||||
else:
|
|
||||||
converted.append(msg)
|
|
||||||
return {**raw_input, "messages": converted}
|
|
||||||
return raw_input
|
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_agent_factory(assistant_id: str | None):
|
|
||||||
"""Resolve the agent factory callable from config.
|
|
||||||
|
|
||||||
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
|
|
||||||
injected into ``configurable`` — see :func:`build_run_config`. All
|
|
||||||
``assistant_id`` values therefore map to the same factory; the routing
|
|
||||||
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
|
|
||||||
"""
|
|
||||||
from deerflow.agents.lead_agent.agent import make_lead_agent
|
|
||||||
|
|
||||||
return make_lead_agent
|
|
||||||
|
|
||||||
|
|
||||||
def build_run_config(
|
|
||||||
thread_id: str,
|
|
||||||
request_config: dict[str, Any] | None,
|
|
||||||
metadata: dict[str, Any] | None,
|
|
||||||
*,
|
|
||||||
assistant_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Build a RunnableConfig dict for the agent.
|
|
||||||
|
|
||||||
When *assistant_id* refers to a custom agent (anything other than
|
|
||||||
``"lead_agent"`` / ``None``), the name is forwarded as
|
|
||||||
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
|
|
||||||
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
|
||||||
without it the agent silently runs as the default lead agent.
|
|
||||||
|
|
||||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
|
||||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
|
||||||
identically.
|
|
||||||
"""
|
|
||||||
config: dict[str, Any] = {"recursion_limit": 100}
|
|
||||||
if request_config:
|
|
||||||
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
|
||||||
# pass thread-level data and rejects requests that include both
|
|
||||||
# ``configurable`` and ``context``. If the caller already sends
|
|
||||||
# ``context``, honour it and skip our own ``configurable`` dict.
|
|
||||||
if "context" in request_config:
|
|
||||||
if "configurable" in request_config:
|
|
||||||
logger.warning(
|
|
||||||
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
|
||||||
thread_id,
|
|
||||||
list(request_config.get("configurable", {}).keys()),
|
|
||||||
)
|
|
||||||
config["context"] = request_config["context"]
|
|
||||||
else:
|
|
||||||
configurable = {"thread_id": thread_id}
|
|
||||||
configurable.update(request_config.get("configurable", {}))
|
|
||||||
config["configurable"] = configurable
|
|
||||||
for k, v in request_config.items():
|
|
||||||
if k not in ("configurable", "context"):
|
|
||||||
config[k] = v
|
|
||||||
else:
|
|
||||||
config["configurable"] = {"thread_id": thread_id}
|
|
||||||
|
|
||||||
# Inject custom agent name when the caller specified a non-default assistant.
|
|
||||||
# Honour an explicit configurable["agent_name"] in the request if already set.
|
|
||||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
|
|
||||||
if "agent_name" not in config["configurable"]:
|
|
||||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
|
||||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
|
||||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
|
||||||
config["configurable"]["agent_name"] = normalized
|
|
||||||
if metadata:
|
|
||||||
config.setdefault("metadata", {}).update(metadata)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Run lifecycle
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
|
|
||||||
"""Create or refresh the thread record in the Store.
|
|
||||||
|
|
||||||
Called from :func:`start_run` so that threads created via the stateless
|
|
||||||
``/runs/stream`` endpoint (which never calls ``POST /threads``) still
|
|
||||||
appear in ``/threads/search`` results.
|
|
||||||
"""
|
|
||||||
# Deferred import to avoid circular import with the threads router module.
|
|
||||||
from app.gateway.routers.threads import _store_upsert
|
|
||||||
|
|
||||||
try:
|
|
||||||
await _store_upsert(store, thread_id, metadata=metadata)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _sync_thread_title_after_run(
|
|
||||||
run_task: asyncio.Task,
|
|
||||||
thread_id: str,
|
|
||||||
checkpointer: Any,
|
|
||||||
store: Any,
|
|
||||||
) -> None:
|
|
||||||
"""Wait for *run_task* to finish, then persist the generated title to the Store.
|
|
||||||
|
|
||||||
TitleMiddleware writes the generated title to the LangGraph agent state
|
|
||||||
(checkpointer) but the Gateway's Store record is not updated automatically.
|
|
||||||
This coroutine closes that gap by reading the final checkpoint after the
|
|
||||||
run completes and syncing ``values.title`` into the Store record so that
|
|
||||||
subsequent ``/threads/search`` responses include the correct title.
|
|
||||||
|
|
||||||
Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
|
|
||||||
logged at DEBUG level and never propagate.
|
|
||||||
"""
|
|
||||||
# Wait for the background run task to complete (any outcome).
|
|
||||||
# asyncio.wait does not propagate task exceptions — it just returns
|
|
||||||
# when the task is done, cancelled, or failed.
|
|
||||||
await asyncio.wait({run_task})
|
|
||||||
|
|
||||||
# Deferred import to avoid circular import with the threads router module.
|
|
||||||
from app.gateway.routers.threads import _store_get, _store_put
|
|
||||||
|
|
||||||
try:
|
|
||||||
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
|
||||||
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
|
||||||
if ckpt_tuple is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
|
|
||||||
title = channel_values.get("title")
|
|
||||||
if not title:
|
|
||||||
return
|
|
||||||
|
|
||||||
existing = await _store_get(store, thread_id)
|
|
||||||
if existing is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
updated = dict(existing)
|
|
||||||
updated.setdefault("values", {})["title"] = title
|
|
||||||
updated["updated_at"] = time.time()
|
|
||||||
await _store_put(store, updated)
|
|
||||||
logger.debug("Synced title %r for thread %s", title, thread_id)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
async def start_run(
|
|
||||||
body: Any,
|
|
||||||
thread_id: str,
|
|
||||||
request: Request,
|
|
||||||
) -> RunRecord:
|
|
||||||
"""Create a RunRecord and launch the background agent task.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
body : RunCreateRequest
|
|
||||||
The validated request body (typed as Any to avoid circular import
|
|
||||||
with the router module that defines the Pydantic model).
|
|
||||||
thread_id : str
|
|
||||||
Target thread.
|
|
||||||
request : Request
|
|
||||||
FastAPI request — used to retrieve singletons from ``app.state``.
|
|
||||||
"""
|
|
||||||
bridge = get_stream_bridge(request)
|
|
||||||
run_mgr = get_run_manager(request)
|
|
||||||
checkpointer = get_checkpointer(request)
|
|
||||||
store = get_store(request)
|
|
||||||
|
|
||||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
|
||||||
|
|
||||||
try:
|
|
||||||
record = await run_mgr.create_or_reject(
|
|
||||||
thread_id,
|
|
||||||
body.assistant_id,
|
|
||||||
on_disconnect=disconnect,
|
|
||||||
metadata=body.metadata or {},
|
|
||||||
kwargs={"input": body.input, "config": body.config},
|
|
||||||
multitask_strategy=body.multitask_strategy,
|
|
||||||
)
|
|
||||||
except ConflictError as exc:
|
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
|
||||||
except UnsupportedStrategyError as exc:
|
|
||||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
|
||||||
|
|
||||||
# Ensure the thread is visible in /threads/search, even for threads that
|
|
||||||
# were never explicitly created via POST /threads (e.g. stateless runs).
|
|
||||||
store = get_store(request)
|
|
||||||
if store is not None:
|
|
||||||
await _upsert_thread_in_store(store, thread_id, body.metadata)
|
|
||||||
|
|
||||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
|
||||||
graph_input = normalize_input(body.input)
|
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
|
||||||
|
|
||||||
# Merge DeerFlow-specific context overrides into configurable.
|
|
||||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
|
||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
|
||||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
|
||||||
context = getattr(body, "context", None)
|
|
||||||
if context:
|
|
||||||
_CONTEXT_CONFIGURABLE_KEYS = {
|
|
||||||
"model_name",
|
|
||||||
"mode",
|
|
||||||
"thinking_enabled",
|
|
||||||
"reasoning_effort",
|
|
||||||
"is_plan_mode",
|
|
||||||
"subagent_enabled",
|
|
||||||
"max_concurrent_subagents",
|
|
||||||
}
|
|
||||||
configurable = config.setdefault("configurable", {})
|
|
||||||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
|
||||||
if key in context:
|
|
||||||
configurable.setdefault(key, context[key])
|
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
|
||||||
|
|
||||||
task = asyncio.create_task(
|
|
||||||
run_agent(
|
|
||||||
bridge,
|
|
||||||
run_mgr,
|
|
||||||
record,
|
|
||||||
checkpointer=checkpointer,
|
|
||||||
store=store,
|
|
||||||
agent_factory=agent_factory,
|
|
||||||
graph_input=graph_input,
|
|
||||||
config=config,
|
|
||||||
stream_modes=stream_modes,
|
|
||||||
stream_subgraphs=body.stream_subgraphs,
|
|
||||||
interrupt_before=body.interrupt_before,
|
|
||||||
interrupt_after=body.interrupt_after,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record.task = task
|
|
||||||
|
|
||||||
# After the run completes, sync the title generated by TitleMiddleware from
|
|
||||||
# the checkpointer into the Store record so that /threads/search returns the
|
|
||||||
# correct title instead of an empty values dict.
|
|
||||||
if store is not None:
|
|
||||||
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
|
|
||||||
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
async def sse_consumer(
|
|
||||||
bridge: StreamBridge,
|
|
||||||
record: RunRecord,
|
|
||||||
request: Request,
|
|
||||||
run_mgr: RunManager,
|
|
||||||
):
|
|
||||||
"""Async generator that yields SSE frames from the bridge.
|
|
||||||
|
|
||||||
The ``finally`` block implements ``on_disconnect`` semantics:
|
|
||||||
- ``cancel``: abort the background task on client disconnect.
|
|
||||||
- ``continue``: let the task run; events are discarded.
|
|
||||||
"""
|
|
||||||
last_event_id = request.headers.get("Last-Event-ID")
|
|
||||||
try:
|
|
||||||
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
|
|
||||||
if await request.is_disconnected():
|
|
||||||
break
|
|
||||||
|
|
||||||
if entry is HEARTBEAT_SENTINEL:
|
|
||||||
yield ": heartbeat\n\n"
|
|
||||||
continue
|
|
||||||
|
|
||||||
if entry is END_SENTINEL:
|
|
||||||
yield format_sse("end", None, event_id=entry.id or None)
|
|
||||||
return
|
|
||||||
|
|
||||||
yield format_sse(entry.event, entry.data, event_id=entry.id or None)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if record.status in (RunStatus.pending, RunStatus.running):
|
|
||||||
if record.on_disconnect == DisconnectMode.cancel:
|
|
||||||
await run_mgr.cancel(record.run_id)
|
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""Gateway service layer."""
|
||||||
|
|
||||||
|
"""Compatibility package for app service submodules."""
|
||||||
|
|
||||||
|
__all__: list[str] = []
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Runs app layer services."""
|
||||||
|
|
||||||
|
from app.infra.storage import StorageRunObserver
|
||||||
|
from .input import (
|
||||||
|
AdaptedRunRequest,
|
||||||
|
RunSpecBuilder,
|
||||||
|
UnsupportedRunFeatureError,
|
||||||
|
adapt_create_run_request,
|
||||||
|
adapt_create_stream_request,
|
||||||
|
adapt_create_wait_request,
|
||||||
|
adapt_join_stream_request,
|
||||||
|
adapt_join_wait_request,
|
||||||
|
)
|
||||||
|
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AdaptedRunRequest",
|
||||||
|
"AppRunCreateStore",
|
||||||
|
"AppRunDeleteStore",
|
||||||
|
"AppRunQueryStore",
|
||||||
|
"RunSpecBuilder",
|
||||||
|
"StorageRunObserver",
|
||||||
|
"UnsupportedRunFeatureError",
|
||||||
|
"adapt_create_run_request",
|
||||||
|
"adapt_create_stream_request",
|
||||||
|
"adapt_create_wait_request",
|
||||||
|
"adapt_join_stream_request",
|
||||||
|
"adapt_join_wait_request",
|
||||||
|
]
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
"""Facade factory - assembles RunsFacade with dependencies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from app.gateway.dependencies import get_checkpointer, get_stream_bridge
|
||||||
|
from deerflow.runtime.runs.facade import RunsFacade
|
||||||
|
from deerflow.runtime.runs.facade import RunsRuntime
|
||||||
|
from deerflow.runtime.runs.internal.execution.supervisor import RunSupervisor
|
||||||
|
from deerflow.runtime.runs.internal.planner import ExecutionPlanner
|
||||||
|
from deerflow.runtime.runs.internal.registry import RunRegistry
|
||||||
|
from deerflow.runtime.runs.internal.streams import RunStreamService
|
||||||
|
from deerflow.runtime.runs.internal.wait import RunWaitService
|
||||||
|
|
||||||
|
from app.infra.storage import StorageRunObserver, ThreadMetaStorage
|
||||||
|
from app.infra.storage.runs import RunDeleteRepository, RunReadRepository, RunWriteRepository
|
||||||
|
from .store import AppRunCreateStore, AppRunDeleteStore, AppRunQueryStore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from deerflow.runtime.stream_bridge import StreamBridge
|
||||||
|
|
||||||
|
|
||||||
|
type AgentFactory = Callable[..., object]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton registry (shared across requests)
|
||||||
|
_registry: RunRegistry | None = None
|
||||||
|
_supervisor: RunSupervisor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_state(request: Request, attr: str, label: str):
|
||||||
|
value = getattr(request.app.state, attr, None)
|
||||||
|
if value is None:
|
||||||
|
raise HTTPException(status_code=503, detail=f"{label} not available")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_registry() -> RunRegistry:
|
||||||
|
"""Get or create singleton registry."""
|
||||||
|
global _registry
|
||||||
|
if _registry is None:
|
||||||
|
_registry = RunRegistry()
|
||||||
|
return _registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_supervisor() -> RunSupervisor:
|
||||||
|
"""Get or create singleton run supervisor."""
|
||||||
|
global _supervisor
|
||||||
|
if _supervisor is None:
|
||||||
|
_supervisor = RunSupervisor()
|
||||||
|
return _supervisor
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_agent_factory(assistant_id: str | None) -> AgentFactory:
|
||||||
|
"""Resolve the agent factory callable from config."""
|
||||||
|
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||||
|
|
||||||
|
return make_lead_agent
|
||||||
|
|
||||||
|
|
||||||
|
def build_runs_facade(
|
||||||
|
*,
|
||||||
|
stream_bridge: "StreamBridge",
|
||||||
|
checkpointer: object,
|
||||||
|
store: object | None = None,
|
||||||
|
run_read_repo: RunReadRepository | None = None,
|
||||||
|
run_write_repo: RunWriteRepository | None = None,
|
||||||
|
run_delete_repo: RunDeleteRepository | None = None,
|
||||||
|
thread_meta_storage: ThreadMetaStorage | None = None,
|
||||||
|
run_event_store: object | None = None,
|
||||||
|
) -> RunsFacade:
|
||||||
|
"""
|
||||||
|
Build RunsFacade with all dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_bridge: StreamBridge instance
|
||||||
|
checkpointer: LangGraph checkpointer
|
||||||
|
store: Optional LangGraph runtime store
|
||||||
|
run_read_repo: Optional run repository for durable reads
|
||||||
|
run_write_repo: Optional run repository for durable writes
|
||||||
|
run_delete_repo: Optional run repository for durable deletes
|
||||||
|
thread_meta_storage: Optional thread metadata storage adapter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured RunsFacade instance
|
||||||
|
"""
|
||||||
|
registry = get_registry()
|
||||||
|
planner = ExecutionPlanner()
|
||||||
|
supervisor = get_supervisor()
|
||||||
|
|
||||||
|
stream_service = RunStreamService(stream_bridge)
|
||||||
|
wait_service = RunWaitService(stream_service)
|
||||||
|
query_store = AppRunQueryStore(run_read_repo) if run_read_repo else None
|
||||||
|
create_store = (
|
||||||
|
AppRunCreateStore(run_write_repo, thread_meta_storage=thread_meta_storage)
|
||||||
|
if run_write_repo
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
delete_store = AppRunDeleteStore(run_delete_repo) if run_delete_repo else None
|
||||||
|
|
||||||
|
# Build storage observer if repositories provided
|
||||||
|
storage_observer = None
|
||||||
|
if run_write_repo or thread_meta_storage:
|
||||||
|
storage_observer = StorageRunObserver(
|
||||||
|
run_write_repo=run_write_repo,
|
||||||
|
thread_meta_storage=thread_meta_storage,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RunsFacade(
|
||||||
|
registry=registry,
|
||||||
|
planner=planner,
|
||||||
|
supervisor=supervisor,
|
||||||
|
stream_service=stream_service,
|
||||||
|
wait_service=wait_service,
|
||||||
|
runtime=RunsRuntime(
|
||||||
|
bridge=stream_bridge,
|
||||||
|
checkpointer=checkpointer,
|
||||||
|
store=store,
|
||||||
|
event_store=run_event_store,
|
||||||
|
agent_factory_resolver=resolve_agent_factory,
|
||||||
|
),
|
||||||
|
observer=storage_observer,
|
||||||
|
query_store=query_store,
|
||||||
|
create_store=create_store,
|
||||||
|
delete_store=delete_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_runs_facade_from_request(request: "Request") -> RunsFacade:
|
||||||
|
"""
|
||||||
|
Build RunsFacade from FastAPI request context.
|
||||||
|
|
||||||
|
Extracts dependencies from request.app.state.
|
||||||
|
"""
|
||||||
|
app_state = request.app.state
|
||||||
|
|
||||||
|
return build_runs_facade(
|
||||||
|
stream_bridge=get_stream_bridge(request),
|
||||||
|
checkpointer=get_checkpointer(request),
|
||||||
|
store=getattr(request.app.state, "store", None),
|
||||||
|
run_read_repo=getattr(app_state, "run_read_repo", None),
|
||||||
|
run_write_repo=getattr(app_state, "run_write_repo", None),
|
||||||
|
run_delete_repo=getattr(app_state, "run_delete_repo", None),
|
||||||
|
thread_meta_storage=getattr(app_state, "thread_meta_storage", None),
|
||||||
|
run_event_store=getattr(app_state, "run_event_store", None),
|
||||||
|
)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
"""Input adapters for app-owned runs entrypoints."""
|
||||||
|
|
||||||
|
from .request_adapter import (
|
||||||
|
AdaptedRunRequest,
|
||||||
|
adapt_create_run_request,
|
||||||
|
adapt_create_stream_request,
|
||||||
|
adapt_create_wait_request,
|
||||||
|
adapt_join_stream_request,
|
||||||
|
adapt_join_wait_request,
|
||||||
|
)
|
||||||
|
from .spec_builder import RunSpecBuilder, UnsupportedRunFeatureError
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AdaptedRunRequest",
|
||||||
|
"RunSpecBuilder",
|
||||||
|
"UnsupportedRunFeatureError",
|
||||||
|
"adapt_create_run_request",
|
||||||
|
"adapt_create_stream_request",
|
||||||
|
"adapt_create_wait_request",
|
||||||
|
"adapt_join_stream_request",
|
||||||
|
"adapt_join_wait_request",
|
||||||
|
]
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
"""App-owned request adapter for runs entrypoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue
|
||||||
|
from deerflow.runtime.runs.types import RunIntent
|
||||||
|
|
||||||
|
type RequestBody = dict[str, JSONValue]
|
||||||
|
type RequestQuery = dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AdaptedRunRequest:
|
||||||
|
"""
|
||||||
|
统一的内部请求 DTO.
|
||||||
|
|
||||||
|
路由层只负责提取 path/query/body,适配器负责转成稳定内部结构。
|
||||||
|
"""
|
||||||
|
|
||||||
|
intent: RunIntent
|
||||||
|
thread_id: str | None
|
||||||
|
run_id: str | None
|
||||||
|
body: RequestBody
|
||||||
|
headers: dict[str, str]
|
||||||
|
query: RequestQuery
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_event_id(self) -> str | None:
|
||||||
|
"""Extract Last-Event-ID from headers."""
|
||||||
|
return self.headers.get("last-event-id") or self.headers.get("Last-Event-ID")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_stateless(self) -> bool:
|
||||||
|
"""Check if this is a stateless request."""
|
||||||
|
return self.thread_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_create_run_request(
|
||||||
|
*,
|
||||||
|
thread_id: str | None,
|
||||||
|
body: RequestBody,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
query: RequestQuery | None = None,
|
||||||
|
) -> AdaptedRunRequest:
|
||||||
|
"""Adapt POST /threads/{thread_id}/runs or POST /runs."""
|
||||||
|
return AdaptedRunRequest(
|
||||||
|
intent="create_background",
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=None,
|
||||||
|
body=body,
|
||||||
|
headers=headers or {},
|
||||||
|
query=query or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_create_stream_request(
|
||||||
|
*,
|
||||||
|
thread_id: str | None,
|
||||||
|
body: RequestBody,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
query: RequestQuery | None = None,
|
||||||
|
) -> AdaptedRunRequest:
|
||||||
|
"""Adapt POST /threads/{thread_id}/runs/stream or POST /runs/stream."""
|
||||||
|
return AdaptedRunRequest(
|
||||||
|
intent="create_and_stream",
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=None,
|
||||||
|
body=body,
|
||||||
|
headers=headers or {},
|
||||||
|
query=query or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_create_wait_request(
|
||||||
|
*,
|
||||||
|
thread_id: str | None,
|
||||||
|
body: RequestBody,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
query: RequestQuery | None = None,
|
||||||
|
) -> AdaptedRunRequest:
|
||||||
|
"""Adapt POST /threads/{thread_id}/runs/wait or POST /runs/wait."""
|
||||||
|
return AdaptedRunRequest(
|
||||||
|
intent="create_and_wait",
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=None,
|
||||||
|
body=body,
|
||||||
|
headers=headers or {},
|
||||||
|
query=query or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_join_stream_request(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
query: RequestQuery | None = None,
|
||||||
|
) -> AdaptedRunRequest:
|
||||||
|
"""Adapt GET /threads/{thread_id}/runs/{run_id}/stream."""
|
||||||
|
return AdaptedRunRequest(
|
||||||
|
intent="join_stream",
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
body={},
|
||||||
|
headers=headers or {},
|
||||||
|
query=query or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_join_wait_request(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
query: RequestQuery | None = None,
|
||||||
|
) -> AdaptedRunRequest:
|
||||||
|
"""Adapt GET /threads/{thread_id}/runs/{run_id}/join."""
|
||||||
|
return AdaptedRunRequest(
|
||||||
|
intent="join_wait",
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
body={},
|
||||||
|
headers=headers or {},
|
||||||
|
query=query or {},
|
||||||
|
)
|
||||||
@@ -0,0 +1,254 @@
|
|||||||
|
"""App-owned RunSpec builder."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.types import CheckpointRequest, RunScope, RunSpec
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue
|
||||||
|
|
||||||
|
from .request_adapter import AdaptedRunRequest
|
||||||
|
|
||||||
|
type JSONMapping = dict[str, JSONValue]
|
||||||
|
type GraphInput = dict[str, object]
|
||||||
|
type RunnableConfigDict = dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedRunFeatureError(ValueError):
|
||||||
|
"""Raised when a phase1-unsupported feature is requested."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RunSpecBuilder:
|
||||||
|
"""
|
||||||
|
Build RunSpec from AdaptedRunRequest.
|
||||||
|
|
||||||
|
Phase 1 rules:
|
||||||
|
1. messages-tuple normalized to messages
|
||||||
|
2. enqueue not supported
|
||||||
|
3. rollback not supported
|
||||||
|
4. after_seconds not supported
|
||||||
|
5. stream_resumable accepted
|
||||||
|
6. stateless auto-generates temporary thread
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Phase 1 unsupported features
|
||||||
|
UNSUPPORTED_MULTITASK_STRATEGIES = {"enqueue"}
|
||||||
|
UNSUPPORTED_ACTIONS = {"rollback"}
|
||||||
|
|
||||||
|
# Default stream modes
|
||||||
|
DEFAULT_STREAM_MODES = ["values", "messages"]
|
||||||
|
CONTEXT_CONFIGURABLE_KEYS = frozenset({
|
||||||
|
"model_name",
|
||||||
|
"mode",
|
||||||
|
"thinking_enabled",
|
||||||
|
"reasoning_effort",
|
||||||
|
"is_plan_mode",
|
||||||
|
"subagent_enabled",
|
||||||
|
"max_concurrent_subagents",
|
||||||
|
})
|
||||||
|
DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _as_json_mapping(value: JSONValue | None) -> JSONMapping | None:
|
||||||
|
return value if isinstance(value, dict) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _as_string_list(value: JSONValue | None) -> list[str] | None:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return None
|
||||||
|
return [item for item in value if isinstance(item, str)]
|
||||||
|
|
||||||
|
def build(self, request: AdaptedRunRequest) -> RunSpec:
|
||||||
|
"""Build RunSpec from adapted request."""
|
||||||
|
body = request.body
|
||||||
|
|
||||||
|
# Validate phase1 constraints
|
||||||
|
self._validate_constraints(body)
|
||||||
|
|
||||||
|
# Build scope
|
||||||
|
scope = self._build_scope(request)
|
||||||
|
|
||||||
|
# Normalize stream modes
|
||||||
|
stream_modes = self._normalize_stream_modes(body.get("stream_mode"))
|
||||||
|
|
||||||
|
# Build checkpoint request
|
||||||
|
checkpoint_request = self._build_checkpoint_request(body)
|
||||||
|
|
||||||
|
config = self._build_runnable_config(
|
||||||
|
thread_id=scope.thread_id,
|
||||||
|
request_config=self._as_json_mapping(body.get("config")),
|
||||||
|
metadata=self._as_json_mapping(body.get("metadata")),
|
||||||
|
assistant_id=body.get("assistant_id"),
|
||||||
|
context=self._as_json_mapping(body.get("context")),
|
||||||
|
)
|
||||||
|
|
||||||
|
return RunSpec(
|
||||||
|
intent=request.intent,
|
||||||
|
scope=scope,
|
||||||
|
assistant_id=body.get("assistant_id") if isinstance(body.get("assistant_id"), str) else None,
|
||||||
|
input=self._normalize_input(self._as_json_mapping(body.get("input"))),
|
||||||
|
command=self._as_json_mapping(body.get("command")),
|
||||||
|
runnable_config=config,
|
||||||
|
context=self._as_json_mapping(body.get("context")),
|
||||||
|
metadata=self._as_json_mapping(body.get("metadata")) or {},
|
||||||
|
stream_modes=stream_modes,
|
||||||
|
stream_subgraphs=bool(body.get("stream_subgraphs", False)),
|
||||||
|
stream_resumable=bool(body.get("stream_resumable", False)),
|
||||||
|
on_disconnect=body.get("on_disconnect", "cancel") if body.get("on_disconnect") in {"cancel", "continue"} else "cancel",
|
||||||
|
on_completion=body.get("on_completion", "keep") if body.get("on_completion") in {"delete", "keep"} else "keep",
|
||||||
|
multitask_strategy=body.get("multitask_strategy", "reject") if body.get("multitask_strategy") in {"reject", "interrupt"} else "reject",
|
||||||
|
interrupt_before="*" if body.get("interrupt_before") == "*" else self._as_string_list(body.get("interrupt_before")),
|
||||||
|
interrupt_after="*" if body.get("interrupt_after") == "*" else self._as_string_list(body.get("interrupt_after")),
|
||||||
|
checkpoint_request=checkpoint_request,
|
||||||
|
follow_up_to_run_id=body.get("follow_up_to_run_id") if isinstance(body.get("follow_up_to_run_id"), str) else None,
|
||||||
|
webhook=body.get("webhook") if isinstance(body.get("webhook"), str) else None,
|
||||||
|
feedback_keys=self._as_string_list(body.get("feedback_keys")),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_constraints(self, body: JSONMapping) -> None:
|
||||||
|
"""Validate phase1 constraints, raise UnsupportedRunFeatureError if violated."""
|
||||||
|
# Check multitask_strategy
|
||||||
|
strategy = body.get("multitask_strategy", "reject")
|
||||||
|
if strategy in self.UNSUPPORTED_MULTITASK_STRATEGIES:
|
||||||
|
raise UnsupportedRunFeatureError(
|
||||||
|
f"multitask_strategy '{strategy}' is not supported in phase1. "
|
||||||
|
f"Supported: reject, interrupt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for rollback action
|
||||||
|
command = self._as_json_mapping(body.get("command")) or {}
|
||||||
|
if command.get("action") in self.UNSUPPORTED_ACTIONS:
|
||||||
|
raise UnsupportedRunFeatureError(
|
||||||
|
f"action '{command.get('action')}' is not supported in phase1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for after_seconds
|
||||||
|
if body.get("after_seconds") is not None:
|
||||||
|
raise UnsupportedRunFeatureError("after_seconds is not supported in phase1")
|
||||||
|
|
||||||
|
def _build_scope(self, request: AdaptedRunRequest) -> RunScope:
|
||||||
|
"""Build RunScope from request."""
|
||||||
|
if request.is_stateless:
|
||||||
|
# Stateless: generate temporary thread
|
||||||
|
return RunScope(
|
||||||
|
kind="stateless",
|
||||||
|
thread_id=str(uuid.uuid4()),
|
||||||
|
temporary=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert request.thread_id is not None
|
||||||
|
return RunScope(
|
||||||
|
kind="stateful",
|
||||||
|
thread_id=request.thread_id,
|
||||||
|
temporary=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _normalize_stream_modes(self, stream_mode: JSONValue | None) -> list[str]:
|
||||||
|
"""Normalize stream_mode to list, convert messages-tuple to messages."""
|
||||||
|
if stream_mode is None:
|
||||||
|
return self.DEFAULT_STREAM_MODES.copy()
|
||||||
|
|
||||||
|
if isinstance(stream_mode, str):
|
||||||
|
modes = [stream_mode]
|
||||||
|
elif isinstance(stream_mode, list):
|
||||||
|
modes = [mode for mode in stream_mode if isinstance(mode, str)]
|
||||||
|
else:
|
||||||
|
return self.DEFAULT_STREAM_MODES.copy()
|
||||||
|
|
||||||
|
return ["messages" if m == "messages-tuple" else m for m in modes]
|
||||||
|
|
||||||
|
def _build_checkpoint_request(self, body: JSONMapping) -> CheckpointRequest | None:
|
||||||
|
"""Build CheckpointRequest if checkpoint data is provided."""
|
||||||
|
checkpoint_id = body.get("checkpoint_id")
|
||||||
|
checkpoint = self._as_json_mapping(body.get("checkpoint"))
|
||||||
|
|
||||||
|
if not isinstance(checkpoint_id, str) and checkpoint is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return CheckpointRequest(
|
||||||
|
checkpoint_id=checkpoint_id if isinstance(checkpoint_id, str) else None,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _normalize_input(self, raw_input: JSONMapping | None) -> GraphInput | None:
|
||||||
|
"""Convert HTTP-friendly message dicts into LangChain message objects."""
|
||||||
|
if raw_input is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
messages = raw_input.get("messages")
|
||||||
|
if not messages or not isinstance(messages, list):
|
||||||
|
return raw_input
|
||||||
|
|
||||||
|
converted: list[object] = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg, dict):
|
||||||
|
role = msg.get("role", msg.get("type", "user"))
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if role in ("user", "human"):
|
||||||
|
converted.append(HumanMessage(content=content))
|
||||||
|
else:
|
||||||
|
converted.append(HumanMessage(content=content))
|
||||||
|
else:
|
||||||
|
converted.append(msg)
|
||||||
|
return {**raw_input, "messages": converted}
|
||||||
|
|
||||||
|
def _build_runnable_config(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
request_config: JSONMapping | None,
|
||||||
|
metadata: JSONMapping | None,
|
||||||
|
assistant_id: str | None,
|
||||||
|
context: JSONMapping | None,
|
||||||
|
) -> RunnableConfigDict:
|
||||||
|
"""Build RunnableConfig from request payload and app-side rules."""
|
||||||
|
config: RunnableConfigDict = {"recursion_limit": 100}
|
||||||
|
|
||||||
|
if request_config:
|
||||||
|
if "context" in request_config:
|
||||||
|
config["context"] = request_config["context"]
|
||||||
|
else:
|
||||||
|
configurable = {"thread_id": thread_id}
|
||||||
|
raw_configurable = request_config.get("configurable")
|
||||||
|
if isinstance(raw_configurable, dict):
|
||||||
|
configurable.update(raw_configurable)
|
||||||
|
config["configurable"] = configurable
|
||||||
|
|
||||||
|
for key, value in request_config.items():
|
||||||
|
if key not in ("configurable", "context"):
|
||||||
|
config[key] = value
|
||||||
|
else:
|
||||||
|
config["configurable"] = {"thread_id": thread_id}
|
||||||
|
|
||||||
|
configurable = config.get("configurable")
|
||||||
|
if (
|
||||||
|
assistant_id
|
||||||
|
and assistant_id != self.DEFAULT_ASSISTANT_ID
|
||||||
|
and isinstance(configurable, dict)
|
||||||
|
and "agent_name" not in configurable
|
||||||
|
):
|
||||||
|
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||||
|
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization."
|
||||||
|
)
|
||||||
|
configurable["agent_name"] = normalized
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
existing_metadata = config.get("metadata")
|
||||||
|
if isinstance(existing_metadata, dict):
|
||||||
|
existing_metadata.update(metadata)
|
||||||
|
else:
|
||||||
|
config["metadata"] = dict(metadata)
|
||||||
|
|
||||||
|
if context and isinstance(configurable, dict):
|
||||||
|
for key in self.CONTEXT_CONFIGURABLE_KEYS:
|
||||||
|
if key in context:
|
||||||
|
configurable.setdefault(key, context[key])
|
||||||
|
|
||||||
|
return config
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""Compatibility wrapper for the app-owned storage observer."""
|
||||||
|
|
||||||
|
from app.infra.storage.runs import StorageRunObserver
|
||||||
|
|
||||||
|
__all__ = ["StorageRunObserver"]
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
"""App-owned runs store adapters."""
|
||||||
|
|
||||||
|
from .create_store import AppRunCreateStore
|
||||||
|
from .delete_store import AppRunDeleteStore
|
||||||
|
from .query_store import AppRunQueryStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppRunCreateStore",
|
||||||
|
"AppRunDeleteStore",
|
||||||
|
"AppRunQueryStore",
|
||||||
|
]
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
"""App-owned durable run creation adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.store import RunCreateStore
|
||||||
|
from deerflow.runtime.runs.types import RunRecord
|
||||||
|
|
||||||
|
from app.infra.storage import ThreadMetaStorage
|
||||||
|
from app.infra.storage.runs import RunWriteRepository
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunCreateStore(RunCreateStore):
|
||||||
|
"""Write the initial durable row for a newly created run."""
|
||||||
|
|
||||||
|
def __init__(self, repo: RunWriteRepository, thread_meta_storage: ThreadMetaStorage | None = None) -> None:
|
||||||
|
self._repo = repo
|
||||||
|
self._thread_meta_storage = thread_meta_storage
|
||||||
|
|
||||||
|
async def create_run(self, record: RunRecord) -> None:
|
||||||
|
await self._repo.create(
|
||||||
|
run_id=record.run_id,
|
||||||
|
thread_id=record.thread_id,
|
||||||
|
assistant_id=record.assistant_id,
|
||||||
|
status=str(record.status),
|
||||||
|
metadata=record.metadata,
|
||||||
|
follow_up_to_run_id=record.follow_up_to_run_id,
|
||||||
|
created_at=record.created_at,
|
||||||
|
)
|
||||||
|
if self._thread_meta_storage is not None and record.assistant_id:
|
||||||
|
thread = await self._thread_meta_storage.ensure_thread(
|
||||||
|
thread_id=record.thread_id,
|
||||||
|
assistant_id=record.assistant_id,
|
||||||
|
)
|
||||||
|
if thread.assistant_id != record.assistant_id:
|
||||||
|
await self._thread_meta_storage.sync_thread_assistant_id(
|
||||||
|
thread_id=record.thread_id,
|
||||||
|
assistant_id=record.assistant_id,
|
||||||
|
)
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""App-owned durable run deletion adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.store import RunDeleteStore
|
||||||
|
|
||||||
|
from app.infra.storage.runs import RunDeleteRepository
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunDeleteStore(RunDeleteStore):
|
||||||
|
"""Delete durable run rows via the app storage adapter."""
|
||||||
|
|
||||||
|
def __init__(self, repo: RunDeleteRepository) -> None:
|
||||||
|
self._repo = repo
|
||||||
|
|
||||||
|
async def delete_run(self, run_id: str) -> bool:
|
||||||
|
return await self._repo.delete(run_id)
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""App-owned durable run query adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.store import RunQueryStore
|
||||||
|
from deerflow.runtime.runs.types import RunRecord, RunStatus
|
||||||
|
|
||||||
|
from app.infra.storage.runs import RunReadRepository, RunRow
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunQueryStore(RunQueryStore):
|
||||||
|
"""Map app-side durable run rows into harness RunRecord DTOs."""
|
||||||
|
|
||||||
|
def __init__(self, repo: RunReadRepository) -> None:
|
||||||
|
self._repo = repo
|
||||||
|
|
||||||
|
async def get_run(self, run_id: str) -> RunRecord | None:
|
||||||
|
row = await self._repo.get(run_id)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return self._to_run_record(row)
|
||||||
|
|
||||||
|
async def list_runs(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> list[RunRecord]:
|
||||||
|
rows = await self._repo.list_by_thread(thread_id, limit=limit)
|
||||||
|
return [self._to_run_record(row) for row in rows]
|
||||||
|
|
||||||
|
def _to_run_record(self, row: RunRow) -> RunRecord:
|
||||||
|
return RunRecord(
|
||||||
|
run_id=row["run_id"],
|
||||||
|
thread_id=row["thread_id"],
|
||||||
|
assistant_id=row.get("assistant_id"),
|
||||||
|
status=RunStatus(row.get("status", "pending")),
|
||||||
|
temporary=False,
|
||||||
|
multitask_strategy=row.get("multitask_strategy", "reject"),
|
||||||
|
metadata=row.get("metadata", {}),
|
||||||
|
follow_up_to_run_id=row.get("follow_up_to_run_id"),
|
||||||
|
created_at=row.get("created_at", ""),
|
||||||
|
updated_at=row.get("updated_at", ""),
|
||||||
|
started_at=row.get("started_at"),
|
||||||
|
ended_at=row.get("ended_at"),
|
||||||
|
error=row.get("error"),
|
||||||
|
)
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Application-owned infrastructure adapters and wiring."""
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Run event store backends owned by app infrastructure."""
|
||||||
|
|
||||||
|
from .factory import build_run_event_store
|
||||||
|
from .jsonl_store import JsonlRunEventStore
|
||||||
|
|
||||||
|
__all__ = ["JsonlRunEventStore", "build_run_event_store"]
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""Factory for app-owned run event store backends."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from app.infra.storage import AppRunEventStore
|
||||||
|
from deerflow.config import get_app_config
|
||||||
|
|
||||||
|
from .jsonl_store import JsonlRunEventStore
|
||||||
|
|
||||||
|
|
||||||
|
def build_run_event_store(session_factory: async_sessionmaker[AsyncSession]) -> AppRunEventStore | JsonlRunEventStore:
|
||||||
|
"""Build the run event store selected by app configuration."""
|
||||||
|
|
||||||
|
config = get_app_config().run_events
|
||||||
|
if config.backend == "db":
|
||||||
|
return AppRunEventStore(session_factory)
|
||||||
|
if config.backend == "jsonl":
|
||||||
|
return JsonlRunEventStore(
|
||||||
|
base_dir=Path(config.jsonl_base_dir),
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unsupported run event backend: {config.backend}")
|
||||||
@@ -0,0 +1,210 @@
|
|||||||
|
"""JSONL run event store backend owned by app infrastructure."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class JsonlRunEventStore:
|
||||||
|
"""Append-only JSONL implementation of the runs RunEventStore protocol."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_dir: Path | str = ".deer-flow/run-events",
|
||||||
|
) -> None:
|
||||||
|
self._base_dir = Path(base_dir)
|
||||||
|
self._locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._locks_guard = asyncio.Lock()
|
||||||
|
|
||||||
|
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
if not events:
|
||||||
|
return []
|
||||||
|
|
||||||
|
grouped: dict[str, list[dict[str, Any]]] = {}
|
||||||
|
for event in events:
|
||||||
|
grouped.setdefault(str(event["thread_id"]), []).append(event)
|
||||||
|
|
||||||
|
records_by_thread: dict[str, list[dict[str, Any]]] = {}
|
||||||
|
for thread_id, thread_events in grouped.items():
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
records_by_thread[thread_id] = self._append_thread_events(thread_id, thread_events)
|
||||||
|
|
||||||
|
indexes = {thread_id: 0 for thread_id in records_by_thread}
|
||||||
|
ordered: list[dict[str, Any]] = []
|
||||||
|
for event in events:
|
||||||
|
thread_id = str(event["thread_id"])
|
||||||
|
index = indexes[thread_id]
|
||||||
|
ordered.append(records_by_thread[thread_id][index])
|
||||||
|
indexes[thread_id] = index + 1
|
||||||
|
return ordered
|
||||||
|
|
||||||
|
async def list_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
events = [event for event in await self._read_thread_events(thread_id) if event.get("category") == "message"]
|
||||||
|
if before_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||||
|
return events[-limit:]
|
||||||
|
if after_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||||
|
return events[:limit]
|
||||||
|
return events[-limit:]
|
||||||
|
|
||||||
|
async def list_events(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
event_types: list[str] | None = None,
|
||||||
|
limit: int = 500,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
event_type_set = set(event_types or [])
|
||||||
|
events = [
|
||||||
|
event
|
||||||
|
for event in await self._read_thread_events(thread_id)
|
||||||
|
if event.get("run_id") == run_id and (not event_type_set or event.get("event_type") in event_type_set)
|
||||||
|
]
|
||||||
|
return events[:limit]
|
||||||
|
|
||||||
|
async def list_messages_by_run(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
events = [
|
||||||
|
event
|
||||||
|
for event in await self._read_thread_events(thread_id)
|
||||||
|
if event.get("run_id") == run_id and event.get("category") == "message"
|
||||||
|
]
|
||||||
|
if before_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||||
|
return events[-limit:]
|
||||||
|
if after_seq is not None:
|
||||||
|
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||||
|
return events[:limit]
|
||||||
|
return events[-limit:]
|
||||||
|
|
||||||
|
async def count_messages(self, thread_id: str) -> int:
|
||||||
|
return len(await self.list_messages(thread_id, limit=10**9))
|
||||||
|
|
||||||
|
async def delete_by_thread(self, thread_id: str) -> int:
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
count = len(self._read_thread_events_sync(thread_id))
|
||||||
|
shutil.rmtree(self._thread_dir(thread_id), ignore_errors=True)
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
events = self._read_thread_events_sync(thread_id)
|
||||||
|
kept = [event for event in events if event.get("run_id") != run_id]
|
||||||
|
deleted = len(events) - len(kept)
|
||||||
|
if deleted:
|
||||||
|
self._write_thread_events(thread_id, kept)
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
async def _thread_lock(self, thread_id: str) -> asyncio.Lock:
|
||||||
|
async with self._locks_guard:
|
||||||
|
lock = self._locks.get(thread_id)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
self._locks[thread_id] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
|
def _append_thread_events(self, thread_id: str, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
thread_dir = self._thread_dir(thread_id)
|
||||||
|
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
seq = self._read_seq(thread_id)
|
||||||
|
records: list[dict[str, Any]] = []
|
||||||
|
with self._events_path(thread_id).open("a", encoding="utf-8") as file:
|
||||||
|
for event in events:
|
||||||
|
seq += 1
|
||||||
|
record = self._normalize_event(event, seq=seq)
|
||||||
|
file.write(json.dumps(record, ensure_ascii=False, default=str))
|
||||||
|
file.write("\n")
|
||||||
|
records.append(record)
|
||||||
|
self._write_seq(thread_id, seq)
|
||||||
|
return records
|
||||||
|
|
||||||
|
def _normalize_event(self, event: dict[str, Any], *, seq: int) -> dict[str, Any]:
|
||||||
|
created_at = event.get("created_at")
|
||||||
|
if isinstance(created_at, datetime):
|
||||||
|
created_at_value = created_at.isoformat()
|
||||||
|
elif created_at:
|
||||||
|
created_at_value = str(created_at)
|
||||||
|
else:
|
||||||
|
created_at_value = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"thread_id": str(event["thread_id"]),
|
||||||
|
"run_id": str(event["run_id"]),
|
||||||
|
"seq": seq,
|
||||||
|
"event_type": str(event["event_type"]),
|
||||||
|
"category": str(event["category"]),
|
||||||
|
"content": event.get("content", ""),
|
||||||
|
"metadata": dict(event.get("metadata") or {}),
|
||||||
|
"created_at": created_at_value,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _read_thread_events(self, thread_id: str) -> list[dict[str, Any]]:
|
||||||
|
async with await self._thread_lock(thread_id):
|
||||||
|
return self._read_thread_events_sync(thread_id)
|
||||||
|
|
||||||
|
def _read_thread_events_sync(self, thread_id: str) -> list[dict[str, Any]]:
|
||||||
|
path = self._events_path(thread_id)
|
||||||
|
if not path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
events: list[dict[str, Any]] = []
|
||||||
|
with path.open(encoding="utf-8") as file:
|
||||||
|
for line in file:
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped:
|
||||||
|
events.append(json.loads(stripped))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def _write_thread_events(self, thread_id: str, events: Iterable[dict[str, Any]]) -> None:
|
||||||
|
thread_dir = self._thread_dir(thread_id)
|
||||||
|
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
temp_path = self._events_path(thread_id).with_suffix(".jsonl.tmp")
|
||||||
|
with temp_path.open("w", encoding="utf-8") as file:
|
||||||
|
for event in events:
|
||||||
|
file.write(json.dumps(event, ensure_ascii=False, default=str))
|
||||||
|
file.write("\n")
|
||||||
|
temp_path.replace(self._events_path(thread_id))
|
||||||
|
|
||||||
|
def _read_seq(self, thread_id: str) -> int:
|
||||||
|
path = self._seq_path(thread_id)
|
||||||
|
if not path.exists():
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
return int(path.read_text(encoding="utf-8").strip() or "0")
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _write_seq(self, thread_id: str, seq: int) -> None:
|
||||||
|
self._seq_path(thread_id).write_text(str(seq), encoding="utf-8")
|
||||||
|
|
||||||
|
def _thread_dir(self, thread_id: str) -> Path:
|
||||||
|
return self._base_dir / "threads" / thread_id
|
||||||
|
|
||||||
|
def _events_path(self, thread_id: str) -> Path:
|
||||||
|
return self._thread_dir(thread_id) / "events.jsonl"
|
||||||
|
|
||||||
|
def _seq_path(self, thread_id: str) -> Path:
|
||||||
|
return self._thread_dir(thread_id) / "seq"
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""Storage-facing adapters owned by the app layer."""
|
||||||
|
|
||||||
|
from .run_events import AppRunEventStore
|
||||||
|
from .runs import FeedbackStoreAdapter, RunStoreAdapter, StorageRunObserver
|
||||||
|
from .thread_meta import ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppRunEventStore",
|
||||||
|
"FeedbackStoreAdapter",
|
||||||
|
"RunStoreAdapter",
|
||||||
|
"StorageRunObserver",
|
||||||
|
"ThreadMetaStorage",
|
||||||
|
"ThreadMetaStoreAdapter",
|
||||||
|
]
|
||||||
@@ -0,0 +1,166 @@
|
|||||||
|
"""App-owned adapter from runs callbacks to storage run event repository."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from store.repositories import RunEvent, RunEventCreate, build_run_event_repository, build_thread_meta_repository
|
||||||
|
|
||||||
|
from deerflow.runtime.actor_context import get_actor_context
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunEventStore:
|
||||||
|
"""Implements the harness RunEventStore protocol using storage repositories."""
|
||||||
|
|
||||||
|
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
|
||||||
|
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
if not events:
|
||||||
|
return []
|
||||||
|
|
||||||
|
denied = {str(event["thread_id"]) for event in events if not await self._thread_visible(str(event["thread_id"]))}
|
||||||
|
if denied:
|
||||||
|
raise PermissionError(f"actor is not allowed to append events for thread(s): {', '.join(sorted(denied))}")
|
||||||
|
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
try:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
rows = await repo.append_batch([_event_create_from_dict(event) for event in events])
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
return [_event_to_dict(row) for row in rows]
|
||||||
|
|
||||||
|
async def list_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
if not await self._thread_visible(thread_id):
|
||||||
|
return []
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
rows = await repo.list_messages(
|
||||||
|
thread_id,
|
||||||
|
limit=limit,
|
||||||
|
before_seq=before_seq,
|
||||||
|
after_seq=after_seq,
|
||||||
|
)
|
||||||
|
return [_event_to_dict(row) for row in rows]
|
||||||
|
|
||||||
|
async def list_events(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
event_types: list[str] | None = None,
|
||||||
|
limit: int = 500,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
if not await self._thread_visible(thread_id):
|
||||||
|
return []
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
rows = await repo.list_events(thread_id, run_id, event_types=event_types, limit=limit)
|
||||||
|
return [_event_to_dict(row) for row in rows]
|
||||||
|
|
||||||
|
async def list_messages_by_run(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 50,
|
||||||
|
before_seq: int | None = None,
|
||||||
|
after_seq: int | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
if not await self._thread_visible(thread_id):
|
||||||
|
return []
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
rows = await repo.list_messages_by_run(
|
||||||
|
thread_id,
|
||||||
|
run_id,
|
||||||
|
limit=limit,
|
||||||
|
before_seq=before_seq,
|
||||||
|
after_seq=after_seq,
|
||||||
|
)
|
||||||
|
return [_event_to_dict(row) for row in rows]
|
||||||
|
|
||||||
|
async def count_messages(self, thread_id: str) -> int:
|
||||||
|
if not await self._thread_visible(thread_id):
|
||||||
|
return 0
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
return await repo.count_messages(thread_id)
|
||||||
|
|
||||||
|
async def delete_by_thread(self, thread_id: str) -> int:
|
||||||
|
if not await self._thread_visible(thread_id):
|
||||||
|
return 0
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
try:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
count = await repo.delete_by_thread(thread_id)
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||||
|
if not await self._thread_visible(thread_id):
|
||||||
|
return 0
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
try:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
count = await repo.delete_by_run(thread_id, run_id)
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def _thread_visible(self, thread_id: str) -> bool:
|
||||||
|
actor = get_actor_context()
|
||||||
|
if actor is None or actor.user_id is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
thread_repo = build_thread_meta_repository(session)
|
||||||
|
thread = await thread_repo.get_thread_meta(thread_id)
|
||||||
|
|
||||||
|
if thread is None:
|
||||||
|
return True
|
||||||
|
return thread.user_id is None or thread.user_id == actor.user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _event_create_from_dict(event: dict[str, Any]) -> RunEventCreate:
|
||||||
|
created_at = event.get("created_at")
|
||||||
|
return RunEventCreate(
|
||||||
|
thread_id=str(event["thread_id"]),
|
||||||
|
run_id=str(event["run_id"]),
|
||||||
|
event_type=str(event["event_type"]),
|
||||||
|
category=str(event["category"]),
|
||||||
|
content=event.get("content", ""),
|
||||||
|
metadata=dict(event.get("metadata") or {}),
|
||||||
|
created_at=datetime.fromisoformat(created_at) if isinstance(created_at, str) else created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _event_to_dict(event: RunEvent) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"thread_id": event.thread_id,
|
||||||
|
"run_id": event.run_id,
|
||||||
|
"event_type": event.event_type,
|
||||||
|
"category": event.category,
|
||||||
|
"content": event.content,
|
||||||
|
"metadata": event.metadata,
|
||||||
|
"seq": event.seq,
|
||||||
|
"created_at": event.created_at.isoformat(),
|
||||||
|
}
|
||||||
@@ -0,0 +1,515 @@
|
|||||||
|
"""Run lifecycle persistence adapters owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Protocol, TypedDict, Unpack, cast
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from store.repositories import FeedbackCreate, Run, RunCreate, build_feedback_repository, build_run_repository
|
||||||
|
|
||||||
|
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||||
|
from deerflow.runtime.serialization import serialize_lc_object
|
||||||
|
from deerflow.runtime.runs.observer import LifecycleEventType, RunLifecycleEvent, RunObserver
|
||||||
|
from deerflow.runtime.stream_bridge import JSONValue
|
||||||
|
|
||||||
|
from .thread_meta import ThreadMetaStorage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RunCreateFields(TypedDict, total=False):
|
||||||
|
status: str
|
||||||
|
created_at: str
|
||||||
|
started_at: str
|
||||||
|
ended_at: str
|
||||||
|
assistant_id: str | None
|
||||||
|
user_id: str | None
|
||||||
|
follow_up_to_run_id: str | None
|
||||||
|
metadata: dict[str, JSONValue]
|
||||||
|
kwargs: dict[str, JSONValue]
|
||||||
|
|
||||||
|
|
||||||
|
class RunStatusUpdateFields(TypedDict, total=False):
|
||||||
|
started_at: str
|
||||||
|
ended_at: str
|
||||||
|
metadata: dict[str, JSONValue]
|
||||||
|
|
||||||
|
|
||||||
|
class RunCompletionFields(TypedDict, total=False):
|
||||||
|
total_input_tokens: int
|
||||||
|
total_output_tokens: int
|
||||||
|
total_tokens: int
|
||||||
|
llm_call_count: int
|
||||||
|
lead_agent_tokens: int
|
||||||
|
subagent_tokens: int
|
||||||
|
middleware_tokens: int
|
||||||
|
message_count: int
|
||||||
|
last_ai_message: str | None
|
||||||
|
first_human_message: str | None
|
||||||
|
error: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class RunRow(TypedDict, total=False):
|
||||||
|
run_id: str
|
||||||
|
thread_id: str
|
||||||
|
assistant_id: str | None
|
||||||
|
status: str
|
||||||
|
multitask_strategy: str
|
||||||
|
follow_up_to_run_id: str | None
|
||||||
|
metadata: dict[str, JSONValue]
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
started_at: str | None
|
||||||
|
ended_at: str | None
|
||||||
|
error: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class RunReadRepository(Protocol):
|
||||||
|
"""Protocol for durable run queries."""
|
||||||
|
|
||||||
|
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None: ...
|
||||||
|
|
||||||
|
async def list_by_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 100,
|
||||||
|
user_id: str | None | object = AUTO,
|
||||||
|
) -> list[RunRow]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class RunWriteRepository(Protocol):
|
||||||
|
"""Protocol for durable run writes."""
|
||||||
|
|
||||||
|
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None: ...
|
||||||
|
async def update_status(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
status: str,
|
||||||
|
**kwargs: Unpack[RunStatusUpdateFields],
|
||||||
|
) -> None: ...
|
||||||
|
async def set_error(self, run_id: str, error: str) -> None: ...
|
||||||
|
async def update_run_completion(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
**kwargs: Unpack[RunCompletionFields],
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class RunDeleteRepository(Protocol):
|
||||||
|
"""Protocol for durable run deletion."""
|
||||||
|
|
||||||
|
async def delete(self, run_id: str) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _RepositoryContext:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_factory: async_sessionmaker[AsyncSession],
|
||||||
|
build_repo: Callable[[AsyncSession], object],
|
||||||
|
*,
|
||||||
|
commit: bool,
|
||||||
|
) -> None:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._build_repo = build_repo
|
||||||
|
self._commit = commit
|
||||||
|
self._session: AsyncSession | None = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self._session = self._session_factory()
|
||||||
|
return self._build_repo(self._session)
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||||
|
if self._session is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if self._commit:
|
||||||
|
if exc_type is None:
|
||||||
|
await self._session.commit()
|
||||||
|
else:
|
||||||
|
await self._session.rollback()
|
||||||
|
finally:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _run_to_row(row: Run) -> RunRow:
|
||||||
|
return {
|
||||||
|
"run_id": row.run_id,
|
||||||
|
"thread_id": row.thread_id,
|
||||||
|
"assistant_id": row.assistant_id,
|
||||||
|
"user_id": row.user_id,
|
||||||
|
"status": row.status,
|
||||||
|
"model_name": row.model_name,
|
||||||
|
"multitask_strategy": row.multitask_strategy,
|
||||||
|
"follow_up_to_run_id": row.follow_up_to_run_id,
|
||||||
|
"metadata": cast(dict[str, JSONValue], row.metadata),
|
||||||
|
"kwargs": cast(dict[str, JSONValue], row.kwargs),
|
||||||
|
"created_at": row.created_time.isoformat(),
|
||||||
|
"updated_at": row.updated_time.isoformat() if row.updated_time else "",
|
||||||
|
"total_input_tokens": row.total_input_tokens,
|
||||||
|
"total_output_tokens": row.total_output_tokens,
|
||||||
|
"total_tokens": row.total_tokens,
|
||||||
|
"llm_call_count": row.llm_call_count,
|
||||||
|
"lead_agent_tokens": row.lead_agent_tokens,
|
||||||
|
"subagent_tokens": row.subagent_tokens,
|
||||||
|
"middleware_tokens": row.middleware_tokens,
|
||||||
|
"message_count": row.message_count,
|
||||||
|
"first_human_message": row.first_human_message,
|
||||||
|
"last_ai_message": row.last_ai_message,
|
||||||
|
"error": row.error,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackStoreAdapter:
|
||||||
|
"""Expose feedback route semantics on top of storage package repositories."""
|
||||||
|
|
||||||
|
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
thread_id: str,
|
||||||
|
rating: int,
|
||||||
|
owner_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
if rating not in (1, -1):
|
||||||
|
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||||
|
effective_user_id = user_id if user_id is not None else owner_id
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
row = await repo.create_feedback(
|
||||||
|
FeedbackCreate(
|
||||||
|
feedback_id=str(uuid.uuid4()),
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
rating=rating,
|
||||||
|
user_id=effective_user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
comment=comment,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return _feedback_to_dict(row)
|
||||||
|
|
||||||
|
async def get(self, feedback_id: str) -> dict[str, object] | None:
|
||||||
|
async with self._read() as repo:
|
||||||
|
row = await repo.get_feedback(feedback_id)
|
||||||
|
return _feedback_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
async def list_by_run(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 100,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> list[dict[str, object]]:
|
||||||
|
async with self._read() as repo:
|
||||||
|
rows = await repo.list_feedback_by_run(run_id)
|
||||||
|
filtered = [row for row in rows if row.thread_id == thread_id]
|
||||||
|
if user_id is not None:
|
||||||
|
filtered = [row for row in filtered if row.user_id == user_id]
|
||||||
|
return [_feedback_to_dict(row) for row in filtered][:limit]
|
||||||
|
|
||||||
|
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict[str, object]]:
|
||||||
|
async with self._read() as repo:
|
||||||
|
rows = await repo.list_feedback_by_thread(thread_id)
|
||||||
|
return [_feedback_to_dict(row) for row in rows][:limit]
|
||||||
|
|
||||||
|
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict[str, object]:
|
||||||
|
rows = await self.list_by_run(thread_id, run_id)
|
||||||
|
positive = sum(1 for row in rows if row["rating"] == 1)
|
||||||
|
negative = sum(1 for row in rows if row["rating"] == -1)
|
||||||
|
return {"run_id": run_id, "total": len(rows), "positive": positive, "negative": negative}
|
||||||
|
|
||||||
|
async def delete(self, feedback_id: str) -> bool:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
return await repo.delete_feedback(feedback_id)
|
||||||
|
|
||||||
|
async def upsert(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
thread_id: str,
|
||||||
|
rating: int,
|
||||||
|
user_id: str,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
if rating not in (1, -1):
|
||||||
|
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
rows = await repo.list_feedback_by_run(run_id)
|
||||||
|
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
|
||||||
|
feedback_id = existing.feedback_id if existing is not None else str(uuid.uuid4())
|
||||||
|
if existing is not None:
|
||||||
|
await repo.delete_feedback(existing.feedback_id)
|
||||||
|
row = await repo.create_feedback(
|
||||||
|
FeedbackCreate(
|
||||||
|
feedback_id=feedback_id,
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
rating=rating,
|
||||||
|
user_id=user_id,
|
||||||
|
comment=comment,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return _feedback_to_dict(row)
|
||||||
|
|
||||||
|
async def delete_by_run(self, *, thread_id: str, run_id: str, user_id: str) -> bool:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
rows = await repo.list_feedback_by_run(run_id)
|
||||||
|
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
|
||||||
|
if existing is None:
|
||||||
|
return False
|
||||||
|
return await repo.delete_feedback(existing.feedback_id)
|
||||||
|
|
||||||
|
async def list_by_thread_grouped(self, thread_id: str, *, user_id: str) -> dict[str, dict[str, object]]:
|
||||||
|
rows = await self.list_by_thread(thread_id)
|
||||||
|
return {
|
||||||
|
row["run_id"]: row
|
||||||
|
for row in rows
|
||||||
|
if row["user_id"] == user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def _read(self) -> _RepositoryContext:
|
||||||
|
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=False)
|
||||||
|
|
||||||
|
def _transaction(self) -> _RepositoryContext:
|
||||||
|
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _feedback_to_dict(row) -> dict[str, object]:
|
||||||
|
return {
|
||||||
|
"feedback_id": row.feedback_id,
|
||||||
|
"run_id": row.run_id,
|
||||||
|
"thread_id": row.thread_id,
|
||||||
|
"user_id": row.user_id,
|
||||||
|
"owner_id": row.user_id,
|
||||||
|
"message_id": row.message_id,
|
||||||
|
"rating": row.rating,
|
||||||
|
"comment": row.comment,
|
||||||
|
"created_at": row.created_time.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RunStoreAdapter:
|
||||||
|
"""Expose runs facade storage semantics on top of storage package repositories."""
|
||||||
|
|
||||||
|
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
|
||||||
|
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None:
|
||||||
|
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.get")
|
||||||
|
async with self._read() as repo:
|
||||||
|
row = await repo.get_run(run_id)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
if effective_user_id is not None and row.user_id != effective_user_id:
|
||||||
|
return None
|
||||||
|
return _run_to_row(row)
|
||||||
|
|
||||||
|
async def list_by_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 100,
|
||||||
|
user_id: str | None | object = AUTO,
|
||||||
|
) -> list[RunRow]:
|
||||||
|
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.list_by_thread")
|
||||||
|
async with self._read() as repo:
|
||||||
|
rows = await repo.list_runs_by_thread(thread_id, limit=limit, offset=0)
|
||||||
|
if effective_user_id is not None:
|
||||||
|
rows = [row for row in rows if row.user_id == effective_user_id]
|
||||||
|
return [_run_to_row(row) for row in rows]
|
||||||
|
|
||||||
|
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None:
|
||||||
|
metadata = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("metadata") or {}))
|
||||||
|
run_kwargs = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("kwargs") or {}))
|
||||||
|
effective_user_id = resolve_user_id(kwargs.get("user_id", AUTO), method_name="RunStoreAdapter.create")
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
await repo.create_run(
|
||||||
|
RunCreate(
|
||||||
|
run_id=run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=kwargs.get("assistant_id"),
|
||||||
|
user_id=effective_user_id,
|
||||||
|
status=kwargs.get("status", "pending"),
|
||||||
|
metadata=dict(metadata),
|
||||||
|
kwargs=dict(run_kwargs),
|
||||||
|
follow_up_to_run_id=kwargs.get("follow_up_to_run_id"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete(self, run_id: str, *, user_id: str | None | object = AUTO) -> bool:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
existing = await repo.get_run(run_id)
|
||||||
|
if existing is None:
|
||||||
|
return False
|
||||||
|
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.delete")
|
||||||
|
if effective_user_id is not None and existing.user_id != effective_user_id:
|
||||||
|
return False
|
||||||
|
await repo.delete_run(run_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def update_status(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
status: str,
|
||||||
|
**kwargs: Unpack[RunStatusUpdateFields],
|
||||||
|
) -> None:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
await repo.update_run_status(run_id, status)
|
||||||
|
|
||||||
|
async def set_error(self, run_id: str, error: str) -> None:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
await repo.update_run_status(run_id, "error", error=error)
|
||||||
|
|
||||||
|
async def update_run_completion(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
**kwargs: Unpack[RunCompletionFields],
|
||||||
|
) -> None:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
await repo.update_run_completion(
|
||||||
|
run_id,
|
||||||
|
status=status,
|
||||||
|
total_input_tokens=kwargs.get("total_input_tokens", 0),
|
||||||
|
total_output_tokens=kwargs.get("total_output_tokens", 0),
|
||||||
|
total_tokens=kwargs.get("total_tokens", 0),
|
||||||
|
llm_call_count=kwargs.get("llm_call_count", 0),
|
||||||
|
lead_agent_tokens=kwargs.get("lead_agent_tokens", 0),
|
||||||
|
subagent_tokens=kwargs.get("subagent_tokens", 0),
|
||||||
|
middleware_tokens=kwargs.get("middleware_tokens", 0),
|
||||||
|
message_count=kwargs.get("message_count", 0),
|
||||||
|
last_ai_message=kwargs.get("last_ai_message"),
|
||||||
|
first_human_message=kwargs.get("first_human_message"),
|
||||||
|
error=kwargs.get("error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read(self) -> _RepositoryContext:
|
||||||
|
return _RepositoryContext(self._session_factory, build_run_repository, commit=False)
|
||||||
|
|
||||||
|
def _transaction(self) -> _RepositoryContext:
|
||||||
|
return _RepositoryContext(self._session_factory, build_run_repository, commit=True)
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRunObserver(RunObserver):
|
||||||
|
"""Persist run lifecycle state into app-owned repositories."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
run_write_repo: RunWriteRepository | None = None,
|
||||||
|
thread_meta_storage: ThreadMetaStorage | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._run_write_repo = run_write_repo
|
||||||
|
self._thread_meta_storage = thread_meta_storage
|
||||||
|
|
||||||
|
async def on_event(self, event: RunLifecycleEvent) -> None:
|
||||||
|
try:
|
||||||
|
await self._dispatch(event)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"StorageRunObserver failed to persist event %s for run %s",
|
||||||
|
event.event_type,
|
||||||
|
event.run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _dispatch(self, event: RunLifecycleEvent) -> None:
|
||||||
|
handlers = {
|
||||||
|
LifecycleEventType.RUN_STARTED: self._handle_run_started,
|
||||||
|
LifecycleEventType.RUN_COMPLETED: self._handle_run_completed,
|
||||||
|
LifecycleEventType.RUN_FAILED: self._handle_run_failed,
|
||||||
|
LifecycleEventType.RUN_CANCELLED: self._handle_run_cancelled,
|
||||||
|
LifecycleEventType.THREAD_STATUS_UPDATED: self._handle_thread_status,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler = handlers.get(event.event_type)
|
||||||
|
if handler:
|
||||||
|
await handler(event)
|
||||||
|
|
||||||
|
async def _handle_run_started(self, event: RunLifecycleEvent) -> None:
|
||||||
|
if self._run_write_repo:
|
||||||
|
await self._run_write_repo.update_status(
|
||||||
|
run_id=event.run_id,
|
||||||
|
status="running",
|
||||||
|
started_at=event.occurred_at.isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_run_completed(self, event: RunLifecycleEvent) -> None:
|
||||||
|
payload = dict(event.payload) if event.payload else {}
|
||||||
|
if self._run_write_repo:
|
||||||
|
completion_data = payload.get("completion_data")
|
||||||
|
if isinstance(completion_data, dict):
|
||||||
|
await self._run_write_repo.update_run_completion(
|
||||||
|
run_id=event.run_id,
|
||||||
|
status="success",
|
||||||
|
**cast(RunCompletionFields, completion_data),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self._run_write_repo.update_status(
|
||||||
|
run_id=event.run_id,
|
||||||
|
status="success",
|
||||||
|
ended_at=event.occurred_at.isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._thread_meta_storage and "title" in payload:
|
||||||
|
await self._thread_meta_storage.sync_thread_title(
|
||||||
|
thread_id=event.thread_id,
|
||||||
|
title=payload["title"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_run_failed(self, event: RunLifecycleEvent) -> None:
|
||||||
|
if self._run_write_repo:
|
||||||
|
payload = dict(event.payload) if event.payload else {}
|
||||||
|
error = payload.get("error", "Unknown error")
|
||||||
|
completion_data = payload.get("completion_data")
|
||||||
|
if isinstance(completion_data, dict):
|
||||||
|
await self._run_write_repo.update_run_completion(
|
||||||
|
run_id=event.run_id,
|
||||||
|
status="error",
|
||||||
|
error=str(error),
|
||||||
|
**cast(RunCompletionFields, completion_data),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self._run_write_repo.update_status(
|
||||||
|
run_id=event.run_id,
|
||||||
|
status="error",
|
||||||
|
ended_at=event.occurred_at.isoformat(),
|
||||||
|
)
|
||||||
|
await self._run_write_repo.set_error(run_id=event.run_id, error=str(error))
|
||||||
|
|
||||||
|
async def _handle_run_cancelled(self, event: RunLifecycleEvent) -> None:
|
||||||
|
if self._run_write_repo:
|
||||||
|
payload = dict(event.payload) if event.payload else {}
|
||||||
|
completion_data = payload.get("completion_data")
|
||||||
|
if isinstance(completion_data, dict):
|
||||||
|
await self._run_write_repo.update_run_completion(
|
||||||
|
run_id=event.run_id,
|
||||||
|
status="interrupted",
|
||||||
|
**cast(RunCompletionFields, completion_data),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self._run_write_repo.update_status(
|
||||||
|
run_id=event.run_id,
|
||||||
|
status="interrupted",
|
||||||
|
ended_at=event.occurred_at.isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_thread_status(self, event: RunLifecycleEvent) -> None:
|
||||||
|
if self._thread_meta_storage:
|
||||||
|
payload = dict(event.payload) if event.payload else {}
|
||||||
|
status = payload.get("status", "idle")
|
||||||
|
await self._thread_meta_storage.sync_thread_status(
|
||||||
|
thread_id=event.thread_id,
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
"""Thread metadata storage adapter owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
from store.repositories import build_thread_meta_repository
|
||||||
|
from store.repositories.contracts import (
|
||||||
|
ThreadMeta,
|
||||||
|
ThreadMetaCreate,
|
||||||
|
ThreadMetaRepositoryProtocol,
|
||||||
|
)
|
||||||
|
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMetaStoreAdapter:
|
||||||
|
"""Use storage package thread repositories with per-call sessions."""
|
||||||
|
|
||||||
|
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
|
||||||
|
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
return await repo.create_thread_meta(data)
|
||||||
|
|
||||||
|
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
|
||||||
|
async with self._read() as repo:
|
||||||
|
return await repo.get_thread_meta(thread_id)
|
||||||
|
|
||||||
|
async def update_thread_meta(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
assistant_id: str | None = None,
|
||||||
|
display_name: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
await repo.update_thread_meta(
|
||||||
|
thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
display_name=display_name,
|
||||||
|
status=status,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_thread(self, thread_id: str) -> None:
|
||||||
|
async with self._transaction() as repo:
|
||||||
|
await repo.delete_thread(thread_id)
|
||||||
|
|
||||||
|
async def search_threads(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
assistant_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[ThreadMeta]:
|
||||||
|
async with self._read() as repo:
|
||||||
|
return await repo.search_threads(
|
||||||
|
metadata=metadata,
|
||||||
|
status=status,
|
||||||
|
user_id=user_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read(self):
|
||||||
|
return _ThreadMetaRepositoryContext(self._session_factory, commit=False)
|
||||||
|
|
||||||
|
def _transaction(self):
|
||||||
|
return _ThreadMetaRepositoryContext(self._session_factory, commit=True)
|
||||||
|
|
||||||
|
|
||||||
|
class _ThreadMetaRepositoryContext:
|
||||||
|
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, commit: bool) -> None:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._commit = commit
|
||||||
|
self._session: AsyncSession | None = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self._session = self._session_factory()
|
||||||
|
return build_thread_meta_repository(self._session)
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||||
|
if self._session is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if self._commit:
|
||||||
|
if exc_type is None:
|
||||||
|
await self._session.commit()
|
||||||
|
else:
|
||||||
|
await self._session.rollback()
|
||||||
|
finally:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadMetaStorage:
|
||||||
|
"""App-facing adapter around the storage thread metadata contract."""
|
||||||
|
|
||||||
|
def __init__(self, repo: ThreadMetaRepositoryProtocol) -> None:
|
||||||
|
self._repo = repo
|
||||||
|
|
||||||
|
async def get_thread(self, thread_id: str, *, user_id: str | None | object = AUTO) -> ThreadMeta | None:
|
||||||
|
thread = await self._repo.get_thread_meta(thread_id)
|
||||||
|
if thread is None:
|
||||||
|
return None
|
||||||
|
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.get_thread")
|
||||||
|
if effective_user_id is not None and thread.user_id != effective_user_id:
|
||||||
|
return None
|
||||||
|
return thread
|
||||||
|
|
||||||
|
async def ensure_thread(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
user_id: str | None | object = AUTO,
|
||||||
|
) -> ThreadMeta:
|
||||||
|
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.ensure_thread")
|
||||||
|
existing = await self.get_thread(thread_id, user_id=effective_user_id)
|
||||||
|
if existing is not None:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
return await self._repo.create_thread_meta(
|
||||||
|
ThreadMetaCreate(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
user_id=effective_user_id,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ensure_thread_running(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> ThreadMeta | None:
|
||||||
|
existing = await self._repo.get_thread_meta(thread_id)
|
||||||
|
if existing is None:
|
||||||
|
return await self._repo.create_thread_meta(
|
||||||
|
ThreadMetaCreate(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
status="running",
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._repo.update_thread_meta(thread_id, status="running")
|
||||||
|
return await self._repo.get_thread_meta(thread_id)
|
||||||
|
|
||||||
|
async def sync_thread_title(self, *, thread_id: str, title: str) -> None:
|
||||||
|
await self._repo.update_thread_meta(thread_id, display_name=title)
|
||||||
|
|
||||||
|
async def sync_thread_assistant_id(self, *, thread_id: str, assistant_id: str) -> None:
|
||||||
|
await self._repo.update_thread_meta(thread_id, assistant_id=assistant_id)
|
||||||
|
|
||||||
|
async def sync_thread_status(self, *, thread_id: str, status: str) -> None:
|
||||||
|
await self._repo.update_thread_meta(thread_id, status=status)
|
||||||
|
|
||||||
|
async def sync_thread_metadata(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await self._repo.update_thread_meta(thread_id, metadata=metadata)
|
||||||
|
|
||||||
|
async def delete_thread(self, thread_id: str) -> None:
|
||||||
|
await self._repo.delete_thread(thread_id)
|
||||||
|
|
||||||
|
async def search_threads(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
user_id: str | None | object = AUTO,
|
||||||
|
assistant_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[ThreadMeta]:
|
||||||
|
normalized_status = status.strip() if status is not None else None
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.search_threads")
|
||||||
|
normalized_user_id = resolved_user_id.strip() if resolved_user_id is not None else None
|
||||||
|
normalized_assistant_id = (
|
||||||
|
assistant_id.strip() if assistant_id is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self._repo.search_threads(
|
||||||
|
metadata=metadata,
|
||||||
|
status=normalized_status or None,
|
||||||
|
user_id=normalized_user_id or None,
|
||||||
|
assistant_id=normalized_assistant_id or None,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ThreadMetaStorage", "ThreadMetaStoreAdapter"]
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""App-owned stream bridge adapters and factory."""
|
||||||
|
|
||||||
|
from .factory import build_stream_bridge
|
||||||
|
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||||
|
|
||||||
|
__all__ = ["MemoryStreamBridge", "RedisStreamBridge", "build_stream_bridge"]
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Concrete stream bridge adapters owned by the app layer."""
|
||||||
|
|
||||||
|
from .memory import MemoryStreamBridge
|
||||||
|
from .redis import RedisStreamBridge
|
||||||
|
|
||||||
|
__all__ = ["MemoryStreamBridge", "RedisStreamBridge"]
|
||||||
@@ -0,0 +1,450 @@
|
|||||||
|
"""In-memory stream bridge implementation owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from deerflow.runtime.stream_bridge import (
|
||||||
|
CANCELLED_SENTINEL,
|
||||||
|
END_SENTINEL,
|
||||||
|
HEARTBEAT_SENTINEL,
|
||||||
|
TERMINAL_STATES,
|
||||||
|
ResumeResult,
|
||||||
|
StreamBridge,
|
||||||
|
StreamEvent,
|
||||||
|
StreamStatus,
|
||||||
|
)
|
||||||
|
from deerflow.runtime.stream_bridge.exceptions import (
|
||||||
|
BridgeClosedError,
|
||||||
|
StreamCapacityExceededError,
|
||||||
|
StreamTerminatedError,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _RunStream:
|
||||||
|
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
||||||
|
events: list[StreamEvent] = field(default_factory=list)
|
||||||
|
id_to_offset: dict[str, int] = field(default_factory=dict)
|
||||||
|
start_offset: int = 0
|
||||||
|
current_bytes: int = 0
|
||||||
|
seq: int = 0
|
||||||
|
status: StreamStatus = StreamStatus.ACTIVE
|
||||||
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
last_publish_at: float | None = None
|
||||||
|
ended_at: float | None = None
|
||||||
|
subscriber_count: int = 0
|
||||||
|
last_subscribe_at: float | None = None
|
||||||
|
awaiting_input: bool = False
|
||||||
|
awaiting_since: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryStreamBridge(StreamBridge):
|
||||||
|
"""Per-run in-memory event log implementation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
max_events_per_stream: int = 256,
|
||||||
|
max_bytes_per_stream: int = 10 * 1024 * 1024,
|
||||||
|
max_active_streams: int = 1000,
|
||||||
|
stream_eviction_policy: Literal["reject", "lru"] = "lru",
|
||||||
|
terminal_retention_ttl: float = 300.0,
|
||||||
|
active_no_publish_timeout: float = 600.0,
|
||||||
|
orphan_timeout: float = 60.0,
|
||||||
|
max_stream_age: float = 86400.0,
|
||||||
|
hitl_extended_timeout: float = 7200.0,
|
||||||
|
cleanup_interval: float = 30.0,
|
||||||
|
queue_maxsize: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
if queue_maxsize is not None:
|
||||||
|
max_events_per_stream = queue_maxsize
|
||||||
|
|
||||||
|
self._max_events = max_events_per_stream
|
||||||
|
self._max_bytes = max_bytes_per_stream
|
||||||
|
self._max_streams = max_active_streams
|
||||||
|
self._eviction_policy = stream_eviction_policy
|
||||||
|
self._terminal_ttl = terminal_retention_ttl
|
||||||
|
self._active_timeout = active_no_publish_timeout
|
||||||
|
self._orphan_timeout = orphan_timeout
|
||||||
|
self._max_age = max_stream_age
|
||||||
|
self._hitl_timeout = hitl_extended_timeout
|
||||||
|
self._cleanup_interval = cleanup_interval
|
||||||
|
self._streams: dict[str, _RunStream] = {}
|
||||||
|
self._registry_lock = asyncio.Lock()
|
||||||
|
self._closed = False
|
||||||
|
self._cleanup_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if self._cleanup_task is None:
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
logger.info(
|
||||||
|
"MemoryStreamBridge started (max_events=%d, max_bytes=%d, max_streams=%d)",
|
||||||
|
self._max_events,
|
||||||
|
self._max_bytes,
|
||||||
|
self._max_streams,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
async with self._registry_lock:
|
||||||
|
self._closed = True
|
||||||
|
if self._cleanup_task is not None:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
for stream in self._streams.values():
|
||||||
|
async with stream.condition:
|
||||||
|
stream.status = StreamStatus.CLOSED
|
||||||
|
stream.condition.notify_all()
|
||||||
|
|
||||||
|
self._streams.clear()
|
||||||
|
logger.info("MemoryStreamBridge closed")
|
||||||
|
|
||||||
|
async def _get_or_create_stream(self, run_id: str) -> _RunStream:
|
||||||
|
stream = self._streams.get(run_id)
|
||||||
|
if stream is not None:
|
||||||
|
return stream
|
||||||
|
|
||||||
|
async with self._registry_lock:
|
||||||
|
if self._closed:
|
||||||
|
raise BridgeClosedError("Stream bridge is closed")
|
||||||
|
|
||||||
|
stream = self._streams.get(run_id)
|
||||||
|
if stream is not None:
|
||||||
|
return stream
|
||||||
|
|
||||||
|
if len(self._streams) >= self._max_streams:
|
||||||
|
if self._eviction_policy == "reject":
|
||||||
|
raise StreamCapacityExceededError(
|
||||||
|
f"Max {self._max_streams} active streams reached"
|
||||||
|
)
|
||||||
|
evicted = self._evict_oldest_terminal()
|
||||||
|
if evicted is None:
|
||||||
|
raise StreamCapacityExceededError("All streams active, cannot evict")
|
||||||
|
logger.info("Evicted stream %s to make room", evicted)
|
||||||
|
|
||||||
|
stream = _RunStream()
|
||||||
|
self._streams[run_id] = stream
|
||||||
|
logger.debug("Created stream for run %s", run_id)
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def _evict_oldest_terminal(self) -> str | None:
|
||||||
|
oldest_run_id: str | None = None
|
||||||
|
oldest_ended_at: float = float("inf")
|
||||||
|
for run_id, stream in self._streams.items():
|
||||||
|
if stream.status in TERMINAL_STATES and stream.ended_at is not None:
|
||||||
|
if stream.ended_at < oldest_ended_at:
|
||||||
|
oldest_ended_at = stream.ended_at
|
||||||
|
oldest_run_id = run_id
|
||||||
|
if oldest_run_id is not None:
|
||||||
|
del self._streams[oldest_run_id]
|
||||||
|
return oldest_run_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _next_id(self, stream: _RunStream) -> str:
|
||||||
|
stream.seq += 1
|
||||||
|
return f"{int(time.time() * 1000)}-{stream.seq}"
|
||||||
|
|
||||||
|
def _estimate_size(self, event: StreamEvent) -> int:
|
||||||
|
base = len(event.id) + len(event.event) + 100
|
||||||
|
if event.data is None:
|
||||||
|
return base
|
||||||
|
if isinstance(event.data, str):
|
||||||
|
return base + len(event.data)
|
||||||
|
if isinstance(event.data, (dict, list)):
|
||||||
|
try:
|
||||||
|
return base + len(json.dumps(event.data, default=str))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return base + 200
|
||||||
|
return base + 50
|
||||||
|
|
||||||
|
def _evict_overflow(self, stream: _RunStream) -> None:
|
||||||
|
while len(stream.events) > self._max_events or stream.current_bytes > self._max_bytes:
|
||||||
|
if not stream.events:
|
||||||
|
break
|
||||||
|
evicted = stream.events.pop(0)
|
||||||
|
stream.id_to_offset.pop(evicted.id, None)
|
||||||
|
stream.current_bytes -= self._estimate_size(evicted)
|
||||||
|
stream.start_offset += 1
|
||||||
|
|
||||||
|
async def publish(self, run_id: str, event: str, data: Any) -> str:
|
||||||
|
stream = await self._get_or_create_stream(run_id)
|
||||||
|
async with stream.condition:
|
||||||
|
if stream.status != StreamStatus.ACTIVE:
|
||||||
|
raise StreamTerminatedError(
|
||||||
|
f"Cannot publish to {stream.status.value} stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = StreamEvent(id=self._next_id(stream), event=event, data=data)
|
||||||
|
absolute_offset = stream.start_offset + len(stream.events)
|
||||||
|
stream.events.append(entry)
|
||||||
|
stream.id_to_offset[entry.id] = absolute_offset
|
||||||
|
stream.current_bytes += self._estimate_size(entry)
|
||||||
|
stream.last_publish_at = time.monotonic()
|
||||||
|
self._evict_overflow(stream)
|
||||||
|
stream.condition.notify_all()
|
||||||
|
return entry.id
|
||||||
|
|
||||||
|
async def publish_end(self, run_id: str) -> str:
|
||||||
|
return await self.publish_terminal(run_id, StreamStatus.ENDED)
|
||||||
|
|
||||||
|
async def publish_terminal(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
kind: StreamStatus,
|
||||||
|
data: Any = None,
|
||||||
|
) -> str:
|
||||||
|
if kind not in TERMINAL_STATES:
|
||||||
|
raise ValueError(f"Invalid terminal kind: {kind}")
|
||||||
|
|
||||||
|
stream = await self._get_or_create_stream(run_id)
|
||||||
|
async with stream.condition:
|
||||||
|
if stream.status != StreamStatus.ACTIVE:
|
||||||
|
for evt in reversed(stream.events):
|
||||||
|
if evt.event in ("end", "cancel", "error", "dead_letter"):
|
||||||
|
return evt.id
|
||||||
|
return ""
|
||||||
|
|
||||||
|
event_name = {
|
||||||
|
StreamStatus.ENDED: "end",
|
||||||
|
StreamStatus.CANCELLED: "cancel",
|
||||||
|
StreamStatus.ERRORED: "error",
|
||||||
|
}[kind]
|
||||||
|
entry = StreamEvent(id=self._next_id(stream), event=event_name, data=data)
|
||||||
|
absolute_offset = stream.start_offset + len(stream.events)
|
||||||
|
stream.events.append(entry)
|
||||||
|
stream.id_to_offset[entry.id] = absolute_offset
|
||||||
|
stream.current_bytes += self._estimate_size(entry)
|
||||||
|
stream.status = kind
|
||||||
|
stream.ended_at = time.monotonic()
|
||||||
|
stream.awaiting_input = False
|
||||||
|
stream.condition.notify_all()
|
||||||
|
logger.debug("Stream %s terminal: %s", run_id, kind.value)
|
||||||
|
return entry.id
|
||||||
|
|
||||||
|
async def cancel(self, run_id: str) -> None:
|
||||||
|
await self.publish_terminal(run_id, StreamStatus.CANCELLED)
|
||||||
|
|
||||||
|
async def subscribe(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
last_event_id: str | None = None,
|
||||||
|
heartbeat_interval: float = 15.0,
|
||||||
|
) -> AsyncIterator[StreamEvent]:
|
||||||
|
stream = await self._get_or_create_stream(run_id)
|
||||||
|
resume = self._resolve_resume_point(stream, last_event_id)
|
||||||
|
next_offset = resume.next_offset
|
||||||
|
|
||||||
|
async with stream.condition:
|
||||||
|
stream.subscriber_count += 1
|
||||||
|
stream.last_subscribe_at = time.monotonic()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
entry_to_yield: StreamEvent | None = None
|
||||||
|
sentinel_to_yield: StreamEvent | None = None
|
||||||
|
should_return = False
|
||||||
|
should_wait = False
|
||||||
|
|
||||||
|
async with stream.condition:
|
||||||
|
if self._closed or stream.status == StreamStatus.CLOSED:
|
||||||
|
sentinel_to_yield = CANCELLED_SENTINEL
|
||||||
|
should_return = True
|
||||||
|
elif next_offset < stream.start_offset:
|
||||||
|
next_offset = stream.start_offset
|
||||||
|
else:
|
||||||
|
local_index = next_offset - stream.start_offset
|
||||||
|
if 0 <= local_index < len(stream.events):
|
||||||
|
entry_to_yield = stream.events[local_index]
|
||||||
|
next_offset += 1
|
||||||
|
if entry_to_yield.event in ("end", "cancel", "error", "dead_letter"):
|
||||||
|
should_return = True
|
||||||
|
elif stream.status in TERMINAL_STATES:
|
||||||
|
sentinel_to_yield = END_SENTINEL
|
||||||
|
should_return = True
|
||||||
|
else:
|
||||||
|
should_wait = True
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
stream.condition.wait(),
|
||||||
|
timeout=heartbeat_interval,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if sentinel_to_yield is not None:
|
||||||
|
yield sentinel_to_yield
|
||||||
|
if should_return:
|
||||||
|
return
|
||||||
|
continue
|
||||||
|
|
||||||
|
if entry_to_yield is not None:
|
||||||
|
yield entry_to_yield
|
||||||
|
if should_return:
|
||||||
|
return
|
||||||
|
continue
|
||||||
|
|
||||||
|
if should_wait:
|
||||||
|
async with stream.condition:
|
||||||
|
local_index = next_offset - stream.start_offset
|
||||||
|
has_events = 0 <= local_index < len(stream.events)
|
||||||
|
is_terminal = stream.status in TERMINAL_STATES
|
||||||
|
if not has_events and not is_terminal:
|
||||||
|
yield HEARTBEAT_SENTINEL
|
||||||
|
|
||||||
|
finally:
|
||||||
|
async with stream.condition:
|
||||||
|
stream.subscriber_count = max(0, stream.subscriber_count - 1)
|
||||||
|
|
||||||
|
async def mark_awaiting_input(self, run_id: str) -> None:
|
||||||
|
stream = self._streams.get(run_id)
|
||||||
|
if stream is None:
|
||||||
|
return
|
||||||
|
async with stream.condition:
|
||||||
|
if stream.status == StreamStatus.ACTIVE:
|
||||||
|
stream.awaiting_input = True
|
||||||
|
stream.awaiting_since = time.monotonic()
|
||||||
|
logger.debug("Stream %s marked as awaiting input", run_id)
|
||||||
|
|
||||||
|
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||||
|
if delay > 0:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
await self._do_cleanup(run_id, "manual")
|
||||||
|
|
||||||
|
async def _do_cleanup(self, run_id: str, reason: str) -> None:
|
||||||
|
async with self._registry_lock:
|
||||||
|
stream = self._streams.pop(run_id, None)
|
||||||
|
if stream is not None:
|
||||||
|
async with stream.condition:
|
||||||
|
stream.status = StreamStatus.CLOSED
|
||||||
|
stream.condition.notify_all()
|
||||||
|
logger.debug("Cleaned up stream %s (reason: %s)", run_id, reason)
|
||||||
|
|
||||||
|
async def _mark_dead_letter(self, run_id: str, reason: str) -> None:
|
||||||
|
stream = self._streams.get(run_id)
|
||||||
|
if stream is None:
|
||||||
|
return
|
||||||
|
async with stream.condition:
|
||||||
|
if stream.status != StreamStatus.ACTIVE:
|
||||||
|
return
|
||||||
|
entry = StreamEvent(
|
||||||
|
id=self._next_id(stream),
|
||||||
|
event="dead_letter",
|
||||||
|
data={"reason": reason, "timestamp": time.time()},
|
||||||
|
)
|
||||||
|
absolute_offset = stream.start_offset + len(stream.events)
|
||||||
|
stream.events.append(entry)
|
||||||
|
stream.id_to_offset[entry.id] = absolute_offset
|
||||||
|
stream.current_bytes += self._estimate_size(entry)
|
||||||
|
stream.status = StreamStatus.ERRORED
|
||||||
|
stream.ended_at = time.monotonic()
|
||||||
|
stream.condition.notify_all()
|
||||||
|
logger.warning("Stream %s marked as dead letter: %s", run_id, reason)
|
||||||
|
|
||||||
|
async def _cleanup_loop(self) -> None:
|
||||||
|
while not self._closed:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self._cleanup_interval)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
to_cleanup: list[tuple[str, str]] = []
|
||||||
|
to_mark_dead: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
async with self._registry_lock:
|
||||||
|
for run_id, stream in list(self._streams.items()):
|
||||||
|
if now - stream.created_at > self._max_age:
|
||||||
|
to_cleanup.append((run_id, "max_age_exceeded"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if stream.status == StreamStatus.ACTIVE:
|
||||||
|
timeout = self._hitl_timeout if stream.awaiting_input else self._active_timeout
|
||||||
|
last_activity = stream.last_publish_at or stream.created_at
|
||||||
|
if now - last_activity > timeout:
|
||||||
|
to_mark_dead.append((run_id, "no_publish_timeout"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if stream.status in TERMINAL_STATES and stream.ended_at:
|
||||||
|
if stream.subscriber_count > 0:
|
||||||
|
continue
|
||||||
|
last_sub = stream.last_subscribe_at or stream.ended_at
|
||||||
|
if now - last_sub > self._orphan_timeout:
|
||||||
|
to_cleanup.append((run_id, "orphan"))
|
||||||
|
continue
|
||||||
|
if now - stream.ended_at > self._terminal_ttl:
|
||||||
|
to_cleanup.append((run_id, "ttl_expired"))
|
||||||
|
|
||||||
|
for run_id, reason in to_mark_dead:
|
||||||
|
await self._mark_dead_letter(run_id, reason)
|
||||||
|
for run_id, reason in to_cleanup:
|
||||||
|
await self._do_cleanup(run_id, reason)
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
|
||||||
|
active = sum(1 for s in self._streams.values() if s.status == StreamStatus.ACTIVE)
|
||||||
|
terminal = sum(1 for s in self._streams.values() if s.status in TERMINAL_STATES)
|
||||||
|
total_events = sum(len(s.events) for s in self._streams.values())
|
||||||
|
total_bytes = sum(s.current_bytes for s in self._streams.values())
|
||||||
|
total_subs = sum(s.subscriber_count for s in self._streams.values())
|
||||||
|
return {
|
||||||
|
"total_streams": len(self._streams),
|
||||||
|
"active_streams": active,
|
||||||
|
"terminal_streams": terminal,
|
||||||
|
"total_events": total_events,
|
||||||
|
"total_bytes": total_bytes,
|
||||||
|
"total_subscribers": total_subs,
|
||||||
|
"closed": self._closed,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _resolve_resume_point(
|
||||||
|
self,
|
||||||
|
stream: _RunStream,
|
||||||
|
last_event_id: str | None,
|
||||||
|
) -> ResumeResult:
|
||||||
|
if last_event_id is None:
|
||||||
|
return ResumeResult(next_offset=stream.start_offset, status="fresh")
|
||||||
|
if last_event_id in stream.id_to_offset:
|
||||||
|
return ResumeResult(
|
||||||
|
next_offset=stream.id_to_offset[last_event_id] + 1,
|
||||||
|
status="resumed",
|
||||||
|
)
|
||||||
|
|
||||||
|
parts = last_event_id.split("-")
|
||||||
|
if len(parts) != 2:
|
||||||
|
return ResumeResult(next_offset=stream.start_offset, status="invalid")
|
||||||
|
try:
|
||||||
|
event_ts = int(parts[0])
|
||||||
|
_event_seq = int(parts[1])
|
||||||
|
except ValueError:
|
||||||
|
return ResumeResult(next_offset=stream.start_offset, status="invalid")
|
||||||
|
|
||||||
|
if stream.events:
|
||||||
|
try:
|
||||||
|
oldest_parts = stream.events[0].id.split("-")
|
||||||
|
oldest_ts = int(oldest_parts[0])
|
||||||
|
if event_ts < oldest_ts:
|
||||||
|
return ResumeResult(
|
||||||
|
next_offset=stream.start_offset,
|
||||||
|
status="evicted",
|
||||||
|
gap_count=stream.start_offset,
|
||||||
|
)
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return ResumeResult(next_offset=stream.start_offset, status="unknown")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["MemoryStreamBridge"]
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""Redis-backed stream bridge placeholder owned by the app layer."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from deerflow.runtime.stream_bridge import StreamBridge, StreamEvent
|
||||||
|
|
||||||
|
|
||||||
|
class RedisStreamBridge(StreamBridge):
|
||||||
|
"""Reserved app-owned Redis implementation.
|
||||||
|
|
||||||
|
Phase 1 intentionally keeps Redis out of the harness package. The concrete
|
||||||
|
implementation will live here once cross-process streaming is introduced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, redis_url: str) -> None:
|
||||||
|
self._redis_url = redis_url
|
||||||
|
|
||||||
|
async def publish(self, run_id: str, event: str, data: Any) -> str:
|
||||||
|
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||||
|
|
||||||
|
async def publish_end(self, run_id: str) -> str:
|
||||||
|
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||||
|
|
||||||
|
def subscribe(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
last_event_id: str | None = None,
|
||||||
|
heartbeat_interval: float = 15.0,
|
||||||
|
) -> AsyncIterator[StreamEvent]:
|
||||||
|
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||||
|
|
||||||
|
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||||
|
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
"""App-owned stream bridge factory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
|
|
||||||
|
from deerflow.config.stream_bridge_config import get_stream_bridge_config
|
||||||
|
from deerflow.runtime.stream_bridge import StreamBridge
|
||||||
|
|
||||||
|
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def build_stream_bridge(config=None) -> AbstractAsyncContextManager[StreamBridge]:
|
||||||
|
"""Build the configured app-owned stream bridge."""
|
||||||
|
return _build_stream_bridge_impl(config)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _build_stream_bridge_impl(config=None) -> AsyncIterator[StreamBridge]:
|
||||||
|
if config is None:
|
||||||
|
config = get_stream_bridge_config()
|
||||||
|
|
||||||
|
if config is None or config.type == "memory":
|
||||||
|
maxsize = config.queue_maxsize if config is not None else 256
|
||||||
|
bridge = MemoryStreamBridge(queue_maxsize=maxsize)
|
||||||
|
await bridge.start()
|
||||||
|
logger.info("Stream bridge initialised: memory (queue_maxsize=%d)", maxsize)
|
||||||
|
try:
|
||||||
|
yield bridge
|
||||||
|
finally:
|
||||||
|
await bridge.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
if config.type == "redis":
|
||||||
|
if not config.redis_url:
|
||||||
|
raise ValueError("Redis stream bridge requires redis_url")
|
||||||
|
bridge = RedisStreamBridge(redis_url=config.redis_url)
|
||||||
|
await bridge.start()
|
||||||
|
logger.info("Stream bridge initialised: redis (%s)", config.redis_url)
|
||||||
|
try:
|
||||||
|
yield bridge
|
||||||
|
finally:
|
||||||
|
await bridge.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown stream bridge type: {config.type!r}")
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
"""Entry point for running the Gateway API via `python app/main.py`.
|
||||||
|
|
||||||
|
Useful for IDE debugging (e.g., PyCharm / VS Code debug configurations).
|
||||||
|
Equivalent to: PYTHONPATH=. uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(
|
||||||
|
"app.gateway.app:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=8001,
|
||||||
|
reload=True,
|
||||||
|
)
|
||||||
@@ -0,0 +1,314 @@
|
|||||||
|
# app.plugins Design Overview
|
||||||
|
|
||||||
|
This document describes the current role of `backend/app/plugins`, its plugin design contract, dependency boundaries, and how the current `auth` plugin provides services with minimal intrusion into the host application.
|
||||||
|
|
||||||
|
## 1. Overall Role
|
||||||
|
|
||||||
|
`app.plugins` is the application-side plugin boundary.
|
||||||
|
|
||||||
|
Its purpose is not to implement a generic plugin marketplace. Instead, it provides a clear boundary inside `app` for separable business capabilities, so that a capability can:
|
||||||
|
|
||||||
|
1. carry its own domain model, runtime state, and adapters inside the plugin
|
||||||
|
2. interact with the host application only through a limited set of seams
|
||||||
|
3. remain replaceable, removable, and extensible over time
|
||||||
|
|
||||||
|
The only real plugin currently implemented under this directory is [`auth`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth).
|
||||||
|
|
||||||
|
The current direction is not “put all logic into app”. It is:
|
||||||
|
|
||||||
|
1. the host application owns unified bootstrap, shared infrastructure, and top-level router assembly
|
||||||
|
2. each plugin owns its own business contract, persistence definitions, runtime state, and outward-facing adapters
|
||||||
|
|
||||||
|
## 2. Plugin Design Contract
|
||||||
|
|
||||||
|
### 2.1 A plugin should carry its own implementation
|
||||||
|
|
||||||
|
The primary contract visible in the current codebase is:
|
||||||
|
|
||||||
|
A plugin’s own ORM, runtime, domain, and adapters should be implemented inside the plugin itself. Core business behavior should not be scattered into unrelated external modules.
|
||||||
|
|
||||||
|
The `auth` plugin already follows that pattern with a fairly complete internal structure:
|
||||||
|
|
||||||
|
1. `domain`
|
||||||
|
- config, errors, JWT, password logic, domain models, service
|
||||||
|
2. `storage`
|
||||||
|
- plugin-owned ORM models, repository contracts, and repository implementations
|
||||||
|
3. `runtime`
|
||||||
|
- plugin-owned runtime config state
|
||||||
|
4. `api`
|
||||||
|
- plugin-owned HTTP router and schemas
|
||||||
|
5. `security`
|
||||||
|
- plugin-owned middleware, dependencies, CSRF logic, and LangGraph adapter
|
||||||
|
6. `authorization`
|
||||||
|
- plugin-owned permission model, policy resolution, and hooks
|
||||||
|
7. `injection`
|
||||||
|
- plugin-owned route-policy loading, injection, and validation
|
||||||
|
|
||||||
|
In other words, a plugin should be a self-contained capability module, not a bag of helpers.
|
||||||
|
|
||||||
|
### 2.2 The host app should provide shared infrastructure, not plugin internals
|
||||||
|
|
||||||
|
The current contract is not that every plugin must be fully infrastructure-independent.
|
||||||
|
|
||||||
|
It is:
|
||||||
|
|
||||||
|
1. a plugin may reuse the application’s shared `engine`, `session_factory`, FastAPI app, and router tree
|
||||||
|
2. but the plugin must still own its table definitions, repositories, runtime config, and business/auth behavior
|
||||||
|
|
||||||
|
This is stated explicitly in [`auth/plugin.toml`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/plugin.toml):
|
||||||
|
|
||||||
|
1. `storage.mode = "shared_infrastructure"`
|
||||||
|
2. the plugin owns its storage definitions and repositories
|
||||||
|
3. but it reuses the application’s shared persistence infrastructure
|
||||||
|
|
||||||
|
So the real rule is not “never reuse infrastructure”. The real rule is “do not outsource plugin business semantics to the rest of the app”.
|
||||||
|
|
||||||
|
### 2.3 Dependencies should remain one-way
|
||||||
|
|
||||||
|
The intended dependency direction in the current design is:
|
||||||
|
|
||||||
|
```text
|
||||||
|
gateway / app bootstrap
|
||||||
|
-> plugin public adapters
|
||||||
|
-> plugin domain / storage / runtime
|
||||||
|
```
|
||||||
|
|
||||||
|
Not:
|
||||||
|
|
||||||
|
```text
|
||||||
|
plugin domain
|
||||||
|
-> depends on app business modules
|
||||||
|
```
|
||||||
|
|
||||||
|
A plugin may depend on:
|
||||||
|
|
||||||
|
1. shared persistence infrastructure
|
||||||
|
2. `app.state` provided by the host application
|
||||||
|
3. generic framework capabilities such as FastAPI / Starlette
|
||||||
|
|
||||||
|
But its core business rules should not depend on unrelated app business modules, otherwise hot-swappability becomes unrealistic.
|
||||||
|
|
||||||
|
## 3. The Current auth Plugin Structure
|
||||||
|
|
||||||
|
The current `auth` plugin is effectively a self-contained authentication and authorization package with its own models, services, and adapters.
|
||||||
|
|
||||||
|
### 3.1 domain
|
||||||
|
|
||||||
|
[`auth/domain`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/domain) owns:
|
||||||
|
|
||||||
|
1. `config.py`
|
||||||
|
- auth-related configuration definition and loading
|
||||||
|
2. `errors.py`
|
||||||
|
- error codes and response contracts
|
||||||
|
3. `jwt.py`
|
||||||
|
- token encoding and decoding
|
||||||
|
4. `password.py`
|
||||||
|
- password hashing and verification
|
||||||
|
5. `models.py`
|
||||||
|
- auth domain models
|
||||||
|
6. `service.py`
|
||||||
|
- `AuthService` as the core business service
|
||||||
|
|
||||||
|
`AuthService` depends only on the plugin’s own `DbUserRepository` plus the shared session factory. The auth business logic is not reimplemented in `gateway`.
|
||||||
|
|
||||||
|
### 3.2 storage
|
||||||
|
|
||||||
|
[`auth/storage`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage) clearly shows the “ORM is owned by the plugin” contract:
|
||||||
|
|
||||||
|
1. [`models.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/models.py)
|
||||||
|
- defines the plugin-owned `users` table model
|
||||||
|
2. [`contracts.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/contracts.py)
|
||||||
|
- defines `User`, `UserCreate`, and `UserRepositoryProtocol`
|
||||||
|
3. [`repositories.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/repositories.py)
|
||||||
|
- implements `DbUserRepository`
|
||||||
|
|
||||||
|
The key point is:
|
||||||
|
|
||||||
|
1. the plugin defines its own ORM model
|
||||||
|
2. the plugin defines its own repository protocol
|
||||||
|
3. the plugin implements its own repository
|
||||||
|
4. external code only needs to provide a session or session factory
|
||||||
|
|
||||||
|
That is the minimal shared seam the boundary should preserve.
|
||||||
|
|
||||||
|
### 3.3 runtime
|
||||||
|
|
||||||
|
[`auth/runtime/config_state.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/runtime/config_state.py) keeps plugin-owned runtime config state:
|
||||||
|
|
||||||
|
1. `get_auth_config()`
|
||||||
|
2. `set_auth_config()`
|
||||||
|
3. `reset_auth_config()`
|
||||||
|
|
||||||
|
This matters because runtime state is also part of the plugin boundary. If future plugins need their own caches, state holders, or feature flags, they should follow the same pattern and keep them inside the plugin.
|
||||||
|
|
||||||
|
### 3.4 adapters
|
||||||
|
|
||||||
|
The `auth` plugin exposes capability through four main adapter groups:
|
||||||
|
|
||||||
|
1. `api/router.py`
|
||||||
|
- HTTP endpoints
|
||||||
|
2. `security/*`
|
||||||
|
- middleware, dependencies, request-user resolution, actor-context bridge
|
||||||
|
3. `authorization/*`
|
||||||
|
- capabilities, policy evaluators, auth hooks
|
||||||
|
4. `injection/*`
|
||||||
|
- route-policy registry, guard injection, startup validation
|
||||||
|
|
||||||
|
These adapters all follow the same rule:
|
||||||
|
|
||||||
|
1. entry-point behavior is defined inside the plugin
|
||||||
|
2. the host app only assembles and wires it
|
||||||
|
|
||||||
|
## 4. How a Plugin Interacts with the Host App
|
||||||
|
|
||||||
|
### 4.1 The top-level router only includes plugin routers
|
||||||
|
|
||||||
|
[`app/gateway/router.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/router.py) simply:
|
||||||
|
|
||||||
|
1. imports `app.plugins.auth.api.router`
|
||||||
|
2. calls `include_router(auth_router)`
|
||||||
|
|
||||||
|
That means the host app integrates auth HTTP behavior by assembly, not by duplicating login/register logic in `gateway`.
|
||||||
|
|
||||||
|
### 4.2 registrar performs wiring, not takeover
|
||||||
|
|
||||||
|
In [`app/gateway/registrar.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/registrar.py), the host app mainly does this:
|
||||||
|
|
||||||
|
1. `app.state.authz_hooks = build_authz_hooks()`
|
||||||
|
2. loads and validates the route-policy registry
|
||||||
|
3. calls `install_route_guards(app)`
|
||||||
|
4. calls `app.add_middleware(CSRFMiddleware)`
|
||||||
|
5. calls `app.add_middleware(AuthMiddleware)`
|
||||||
|
|
||||||
|
So the host app only wires the plugin in:
|
||||||
|
|
||||||
|
1. register middleware
|
||||||
|
2. install route guards
|
||||||
|
3. expose hooks and registries through `app.state`
|
||||||
|
|
||||||
|
The actual auth logic, authz logic, and route-policy semantics still live inside the plugin.
|
||||||
|
|
||||||
|
### 4.3 The plugin reuses shared sessions, but still owns business repositories
|
||||||
|
|
||||||
|
In [`auth/security/dependencies.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/security/dependencies.py):
|
||||||
|
|
||||||
|
1. the plugin reads the shared session factory from `request.app.state.persistence.session_factory`
|
||||||
|
2. constructs `DbUserRepository` itself
|
||||||
|
3. constructs `AuthService` itself
|
||||||
|
|
||||||
|
This is a good low-intrusion seam:
|
||||||
|
|
||||||
|
1. the outside world provides only shared infrastructure handles
|
||||||
|
2. the plugin decides how to instantiate its internal dependencies
|
||||||
|
|
||||||
|
## 5. Hot-Swappability and Low-Intrusion Principles
|
||||||
|
|
||||||
|
### 5.1 If a plugin serves other modules, it should minimize intrusion
|
||||||
|
|
||||||
|
When a plugin provides services to the rest of the app, the preferred patterns are:
|
||||||
|
|
||||||
|
1. expose a router
|
||||||
|
2. expose middleware or dependencies
|
||||||
|
3. expose hooks or protocols
|
||||||
|
4. inject a small number of shared objects through `app.state`
|
||||||
|
5. use config-driven route policies or capabilities instead of hardcoding checks inside business routes
|
||||||
|
|
||||||
|
Patterns to avoid:
|
||||||
|
|
||||||
|
1. large plugin-specific branches spread across `gateway`
|
||||||
|
2. unrelated business modules importing plugin ORM internals and rebuilding plugin logic themselves
|
||||||
|
3. plugin state being maintained across many global modules
|
||||||
|
|
||||||
|
### 5.2 Low-intrusion seams already visible in auth
|
||||||
|
|
||||||
|
The current `auth` plugin already uses four important low-intrusion seams:
|
||||||
|
|
||||||
|
1. router integration
|
||||||
|
- `gateway.router` only calls `include_router`
|
||||||
|
2. middleware integration
|
||||||
|
- `registrar` only registers `AuthMiddleware` and `CSRFMiddleware`
|
||||||
|
3. policy injection
|
||||||
|
- `install_route_guards(app)` appends `Depends(enforce_route_policy)` uniformly to routes
|
||||||
|
4. hook seam
|
||||||
|
- `authz_hooks` is exposed via `app.state`, so permission providers and policy builders can be replaced
|
||||||
|
|
||||||
|
This structure has three practical benefits:
|
||||||
|
|
||||||
|
1. host-app changes stay concentrated in the assembly layer
|
||||||
|
2. plugin core logic stays concentrated inside the plugin directory
|
||||||
|
3. swapping implementations does not require editing business routes one by one
|
||||||
|
|
||||||
|
### 5.3 Route policy is a key low-intrusion mechanism
|
||||||
|
|
||||||
|
[`auth/injection/registry_loader.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/registry_loader.py), [`validation.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/validation.py), and [`route_injector.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/route_injector.py) together form an important contract:
|
||||||
|
|
||||||
|
1. route policies live in the plugin-owned `route_policies.yaml`
|
||||||
|
2. startup validates that policy entries and real routes stay aligned
|
||||||
|
3. guards are attached by uniform injection instead of manual per-endpoint code
|
||||||
|
|
||||||
|
That allows the plugin to:
|
||||||
|
|
||||||
|
1. describe which routes are public, which capabilities are required, and which owner policies apply
|
||||||
|
2. avoid large invasive changes to the host routing layer
|
||||||
|
3. remain easier to replace or trim down later
|
||||||
|
|
||||||
|
## 6. What “ORM and runtime are implemented inside the plugin” Should Mean
|
||||||
|
|
||||||
|
That contract should be read as three concrete rules:
|
||||||
|
|
||||||
|
1. data models belong to the plugin
|
||||||
|
- the plugin’s own tables, Pydantic contracts, repository protocols, and repository implementations stay inside the plugin directory
|
||||||
|
2. runtime state belongs to the plugin
|
||||||
|
- plugin-owned config caches, context bridges, and plugin-level hooks stay inside the plugin
|
||||||
|
3. the outside world exposes infrastructure, not plugin semantics
|
||||||
|
- for example shared `session_factory`, FastAPI app, and `app.state`
|
||||||
|
|
||||||
|
Using `auth` as the example:
|
||||||
|
|
||||||
|
1. the `users` table is defined inside the plugin, not in `app.infra`
|
||||||
|
2. `AuthService` is implemented inside the plugin, not in `gateway`
|
||||||
|
3. `get_auth_config()` is maintained inside the plugin, not cached elsewhere
|
||||||
|
4. `AuthMiddleware`, `route_guard`, and `AuthzHooks` are all provided by the plugin itself
|
||||||
|
|
||||||
|
This is the structural prerequisite for meaningful pluginization later.
|
||||||
|
|
||||||
|
## 7. Current Scope and Non-Goals
|
||||||
|
|
||||||
|
At the current stage, the role of `app.plugins` is mainly:
|
||||||
|
|
||||||
|
1. to create module boundaries for separable application-side capabilities
|
||||||
|
2. to let each plugin own its own domain/storage/runtime/adapters
|
||||||
|
3. to connect plugins to the host app through assembly-oriented seams
|
||||||
|
|
||||||
|
The current non-goals are also clear:
|
||||||
|
|
||||||
|
1. this is not yet a full generic plugin discovery/installation system
|
||||||
|
2. plugins are not dynamically enabled or disabled at runtime
|
||||||
|
3. shared infrastructure is not being duplicated into every plugin
|
||||||
|
|
||||||
|
So at this stage, “hot-swappable” should be interpreted more precisely as:
|
||||||
|
|
||||||
|
1. plugin boundaries stay as independent as possible
|
||||||
|
2. integration points stay concentrated in the assembly layer
|
||||||
|
3. replacing or removing a plugin should mostly affect a small number of places such as `registrar`, router includes, and `app.state` hooks
|
||||||
|
|
||||||
|
## 8. Suggested Evolution Rules
|
||||||
|
|
||||||
|
If `app.plugins` is going to become a more stable plugin boundary, the codebase should keep following these rules:
|
||||||
|
|
||||||
|
1. each plugin directory should keep a `domain` / `storage` / `runtime` / `adapter` split
|
||||||
|
2. plugin-owned ORM and repositories should not drift into shared business directories
|
||||||
|
3. when a plugin serves the rest of the app, it should prefer exposing protocols, hooks, routers, and middleware over forcing external code to import internal implementation details
|
||||||
|
4. seams between a plugin and the host app should stay mostly limited to:
|
||||||
|
- `router.include_router(...)`
|
||||||
|
- `app.add_middleware(...)`
|
||||||
|
- `app.state.*`
|
||||||
|
- lifespan/bootstrap wiring
|
||||||
|
5. config-driven integration should be preferred over scattered hardcoded integration
|
||||||
|
6. startup validation should be preferred over implicit runtime failure
|
||||||
|
|
||||||
|
## 9. Summary
|
||||||
|
|
||||||
|
The current `app.plugins` contract can be summarized in one sentence:
|
||||||
|
|
||||||
|
Each plugin owns its own business implementation, ORM, and runtime; the host application provides shared infrastructure and assembly seams; and services should be integrated through low-intrusion, replaceable boundaries so the system can evolve toward real hot-swappability.
|
||||||
@@ -0,0 +1,310 @@
|
|||||||
|
# app.plugins 设计说明
|
||||||
|
|
||||||
|
本文基于当前代码实现,说明 `backend/app/plugins` 的定位、插件设计契约、依赖边界,以及当前 `auth` 插件是如何在尽量少侵入宿主应用的前提下提供服务的。
|
||||||
|
|
||||||
|
## 1. 总体定位
|
||||||
|
|
||||||
|
`app.plugins` 是应用侧插件边界。它的目标不是做一个通用插件市场,而是在 `app` 这一层给可拆分的业务能力预留清晰边界,使某一类能力可以:
|
||||||
|
|
||||||
|
1. 在插件内部自带领域模型、运行时状态和适配器
|
||||||
|
2. 只通过有限的接缝与宿主应用交互
|
||||||
|
3. 在未来保持“可替换、可裁剪、可扩展”
|
||||||
|
|
||||||
|
当前目录下实际落地的插件是 [`auth`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth)。
|
||||||
|
|
||||||
|
从当前实现看,`app.plugins` 的方向不是“所有逻辑都塞进 app”,而是:
|
||||||
|
|
||||||
|
1. 宿主应用负责统一启动、共享基础设施和总路由装配
|
||||||
|
2. 插件负责自己的业务契约、持久化定义、运行时状态和外部适配器
|
||||||
|
|
||||||
|
## 2. 插件设计契约
|
||||||
|
|
||||||
|
### 2.1 插件内部要自带完整能力
|
||||||
|
|
||||||
|
当前代码体现出的首要契约是:
|
||||||
|
|
||||||
|
插件自己的 ORM、runtime、domain、adapter,原则上都应由插件内部实现,不要把核心业务依赖散落到外部模块。
|
||||||
|
|
||||||
|
以 `auth` 插件为例,它内部已经自带了完整分层:
|
||||||
|
|
||||||
|
1. `domain`
|
||||||
|
- 配置、错误、JWT、密码、领域模型、服务
|
||||||
|
2. `storage`
|
||||||
|
- 插件自己的 ORM 模型、仓储契约和仓储实现
|
||||||
|
3. `runtime`
|
||||||
|
- 插件自己的运行时配置状态
|
||||||
|
4. `api`
|
||||||
|
- 插件自己的 HTTP router 和 schema
|
||||||
|
5. `security`
|
||||||
|
- 插件自己的 middleware、dependency、csrf、LangGraph 适配
|
||||||
|
6. `authorization`
|
||||||
|
- 插件自己的权限模型、policy 解析和 hook
|
||||||
|
7. `injection`
|
||||||
|
- 插件自己的路由策略注册、注入和校验逻辑
|
||||||
|
|
||||||
|
换句话说,插件不是一组零散 helper,而应该是一个自闭合的功能模块。
|
||||||
|
|
||||||
|
### 2.2 宿主应用只提供共享基础设施,不承接插件内部逻辑
|
||||||
|
|
||||||
|
当前约束不是“插件完全独立进程”,而是:
|
||||||
|
|
||||||
|
1. 插件可以复用应用共享的 `engine`、`session_factory`、FastAPI app、路由树
|
||||||
|
2. 但插件自己的表结构、仓储、运行时配置、鉴权逻辑,仍然应由插件自己拥有
|
||||||
|
|
||||||
|
这一点在 [`auth/plugin.toml`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/plugin.toml) 里写得很明确:
|
||||||
|
|
||||||
|
1. `storage.mode = "shared_infrastructure"`
|
||||||
|
2. 说明插件拥有自己的 storage definitions 和 repositories
|
||||||
|
3. 但复用应用共享的 persistence infrastructure
|
||||||
|
|
||||||
|
所以这里的契约不是“禁止复用基础设施”,而是“不要把插件内部业务实现外包给 app 其他模块”。
|
||||||
|
|
||||||
|
### 2.3 依赖方向要单向
|
||||||
|
|
||||||
|
按当前实现,比较理想的依赖方向是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
gateway / app bootstrap
|
||||||
|
-> plugin public adapters
|
||||||
|
-> plugin domain / storage / runtime
|
||||||
|
```
|
||||||
|
|
||||||
|
而不是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
plugin domain
|
||||||
|
-> 依赖 app 里的业务模块
|
||||||
|
```
|
||||||
|
|
||||||
|
插件可以使用:
|
||||||
|
|
||||||
|
1. 共享持久化基础设施
|
||||||
|
2. 宿主应用提供的 `app.state`
|
||||||
|
3. FastAPI / Starlette 等通用框架能力
|
||||||
|
|
||||||
|
但不应该把自己的核心业务规则建立在别的业务模块之上,否则后续无法热插拔。
|
||||||
|
|
||||||
|
## 3. 当前 auth 插件的实际结构
|
||||||
|
|
||||||
|
当前 `auth` 插件可以概括为一套“自带模型、自带服务、自带适配器”的认证授权包。
|
||||||
|
|
||||||
|
### 3.1 domain
|
||||||
|
|
||||||
|
[`auth/domain`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/domain) 负责:
|
||||||
|
|
||||||
|
1. `config.py`
|
||||||
|
- 认证相关配置定义与加载
|
||||||
|
2. `errors.py`
|
||||||
|
- 错误码和错误响应契约
|
||||||
|
3. `jwt.py`
|
||||||
|
- token 编解码
|
||||||
|
4. `password.py`
|
||||||
|
- 密码哈希和校验
|
||||||
|
5. `models.py`
|
||||||
|
- auth 域模型
|
||||||
|
6. `service.py`
|
||||||
|
- `AuthService`,作为核心业务服务
|
||||||
|
|
||||||
|
`AuthService` 本身只依赖插件内部的 `DbUserRepository` 和共享 session factory,没有把认证逻辑散到 `gateway`。
|
||||||
|
|
||||||
|
### 3.2 storage
|
||||||
|
|
||||||
|
[`auth/storage`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage) 明确体现了“ORM 由插件自己内部实现”的契约:
|
||||||
|
|
||||||
|
1. [`models.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/models.py)
|
||||||
|
- 定义插件自己的 `users` 表模型
|
||||||
|
2. [`contracts.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/contracts.py)
|
||||||
|
- 定义 `User`、`UserCreate` 和 `UserRepositoryProtocol`
|
||||||
|
3. [`repositories.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/storage/repositories.py)
|
||||||
|
- 实现 `DbUserRepository`
|
||||||
|
|
||||||
|
这里的关键点是:
|
||||||
|
|
||||||
|
1. 插件自己定义 ORM model
|
||||||
|
2. 插件自己定义 repository protocol
|
||||||
|
3. 插件自己实现 repository
|
||||||
|
4. 外部只需要给它 session / session_factory
|
||||||
|
|
||||||
|
这就是插件边界应该保持的最小共享面。
|
||||||
|
|
||||||
|
### 3.3 runtime
|
||||||
|
|
||||||
|
[`auth/runtime/config_state.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/runtime/config_state.py) 维护插件自己的 runtime config state:
|
||||||
|
|
||||||
|
1. `get_auth_config()`
|
||||||
|
2. `set_auth_config()`
|
||||||
|
3. `reset_auth_config()`
|
||||||
|
|
||||||
|
这说明运行时配置状态也属于插件内部,而不是由外部模块代持。后续如果别的插件需要自己的缓存、状态机、feature flag,也应沿这个模式内聚在插件内部。
|
||||||
|
|
||||||
|
### 3.4 adapters
|
||||||
|
|
||||||
|
`auth` 插件对外暴露能力主要通过四类 adapter:
|
||||||
|
|
||||||
|
1. `api/router.py`
|
||||||
|
- HTTP 接口
|
||||||
|
2. `security/*`
|
||||||
|
- middleware、dependency、request user 解析、actor context bridge
|
||||||
|
3. `authorization/*`
|
||||||
|
- capability、policy evaluator、auth hooks
|
||||||
|
4. `injection/*`
|
||||||
|
- route policy registry、guard 注入、启动校验
|
||||||
|
|
||||||
|
这类 adapter 的共同特征是:
|
||||||
|
|
||||||
|
1. 入口能力在插件内定义
|
||||||
|
2. 宿主应用只负责调用和装配
|
||||||
|
|
||||||
|
## 4. 插件如何与宿主应用交互
|
||||||
|
|
||||||
|
### 4.1 总路由只 include,不重写插件逻辑
|
||||||
|
|
||||||
|
[`app/gateway/router.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/router.py) 只是:
|
||||||
|
|
||||||
|
1. 引入 `app.plugins.auth.api.router`
|
||||||
|
2. `include_router(auth_router)`
|
||||||
|
|
||||||
|
这说明宿主应用对 auth HTTP 能力的接入是装配式的,而不是在 `gateway` 里重写一套登录/注册逻辑。
|
||||||
|
|
||||||
|
### 4.2 registrar 负责启动装配,不负责接管插件实现
|
||||||
|
|
||||||
|
[`app/gateway/registrar.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/gateway/registrar.py) 里,宿主应用做的事情主要是:
|
||||||
|
|
||||||
|
1. `app.state.authz_hooks = build_authz_hooks()`
|
||||||
|
2. 加载并校验 route policy registry
|
||||||
|
3. `install_route_guards(app)`
|
||||||
|
4. `app.add_middleware(CSRFMiddleware)`
|
||||||
|
5. `app.add_middleware(AuthMiddleware)`
|
||||||
|
|
||||||
|
也就是说,宿主应用只负责把插件接进来:
|
||||||
|
|
||||||
|
1. 注册 middleware
|
||||||
|
2. 安装 route guard
|
||||||
|
3. 把 hooks 和 registry 放到 `app.state`
|
||||||
|
|
||||||
|
真正的鉴权逻辑、认证逻辑、路由策略语义仍然在插件内部。
|
||||||
|
|
||||||
|
### 4.3 共享会话工厂,但业务仓储仍归插件
|
||||||
|
|
||||||
|
在 [`auth/security/dependencies.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/security/dependencies.py) 中:
|
||||||
|
|
||||||
|
1. 插件从 `request.app.state.persistence.session_factory` 取得共享 session factory
|
||||||
|
2. 然后自己构造 `DbUserRepository`
|
||||||
|
3. 再自己构造 `AuthService`
|
||||||
|
|
||||||
|
这就是一个很典型的低侵入接缝:
|
||||||
|
|
||||||
|
1. 外部只提供共享基础设施句柄
|
||||||
|
2. 插件自己决定如何实例化内部依赖
|
||||||
|
|
||||||
|
## 5. 热插拔与低侵入原则
|
||||||
|
|
||||||
|
### 5.1 如果要向其他模块提供服务,应尽量减少入侵
|
||||||
|
|
||||||
|
插件给其他模块提供服务时,优先选下面这些方式:
|
||||||
|
|
||||||
|
1. 暴露 router
|
||||||
|
2. 暴露 middleware / dependency
|
||||||
|
3. 暴露 hook 或 protocol
|
||||||
|
4. 通过 `app.state` 注入少量共享对象
|
||||||
|
5. 使用配置驱动的 route policy / capability,而不是把判断逻辑硬编码进业务路由
|
||||||
|
|
||||||
|
不推荐的方式是:
|
||||||
|
|
||||||
|
1. 在 `gateway` 大量写插件特定分支
|
||||||
|
2. 让别的业务模块直接 import 插件内部 ORM 细节后自行拼逻辑
|
||||||
|
3. 把插件状态散落到全局多个模块中共同维护
|
||||||
|
|
||||||
|
### 5.2 当前 auth 插件已经体现出的低侵入点
|
||||||
|
|
||||||
|
当前 `auth` 插件的低侵入接入点主要有四个:
|
||||||
|
|
||||||
|
1. 路由接入
|
||||||
|
- `gateway.router` 只 `include_router`
|
||||||
|
2. 中间件接入
|
||||||
|
- `registrar` 只注册 `AuthMiddleware` / `CSRFMiddleware`
|
||||||
|
3. 策略注入
|
||||||
|
- `install_route_guards(app)` 给路由统一追加 `Depends(enforce_route_policy)`
|
||||||
|
4. hook 接缝
|
||||||
|
- `authz_hooks` 通过 `app.state` 暴露,策略构建和权限提供器可以替换
|
||||||
|
|
||||||
|
这套结构的好处是:
|
||||||
|
|
||||||
|
1. 宿主应用改动面集中在装配层
|
||||||
|
2. 插件核心实现集中在插件目录内部
|
||||||
|
3. 替换实现时,不需要在业务路由里逐个修改
|
||||||
|
|
||||||
|
### 5.3 route policy 是低侵入的关键机制
|
||||||
|
|
||||||
|
[`auth/injection/registry_loader.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/registry_loader.py)、[`validation.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/validation.py) 和 [`route_injector.py`](/Users/rayhpeng/workspace/open-source/deer-flow/backend/app/plugins/auth/injection/route_injector.py) 共同形成了一套很关键的契约:
|
||||||
|
|
||||||
|
1. 路由策略写在插件自己的 `route_policies.yaml`
|
||||||
|
2. 启动时会校验策略表和真实路由是否一致
|
||||||
|
3. guard 通过统一注入附着到路由,而不是每个 endpoint 手写一遍
|
||||||
|
|
||||||
|
这使得插件能够:
|
||||||
|
|
||||||
|
1. 用配置描述“哪些路由公开、需要哪些 capability、需要哪些 owner policy”
|
||||||
|
2. 避免对宿主路由层做大规模侵入
|
||||||
|
3. 在未来更容易替换或裁剪某个插件
|
||||||
|
|
||||||
|
## 6. 关于“ORM、runtime 都由自己内部实现”的具体说明
|
||||||
|
|
||||||
|
这条契约建议明确理解为以下三点:
|
||||||
|
|
||||||
|
1. 数据模型归插件
|
||||||
|
- 插件自己的表、Pydantic contract、repository protocol、repository implementation 都放在插件目录内
|
||||||
|
2. 运行时状态归插件
|
||||||
|
- 插件自己的配置缓存、上下文桥、插件级 hooks 都在插件内部维护
|
||||||
|
3. 外部只暴露基础设施,不接管插件语义
|
||||||
|
- 例如共享 `session_factory`、FastAPI app、`app.state`
|
||||||
|
|
||||||
|
拿 `auth` 举例:
|
||||||
|
|
||||||
|
1. `users` 表在插件里定义,不在 `app.infra` 定义
|
||||||
|
2. `AuthService` 在插件里实现,不在 `gateway` 实现
|
||||||
|
3. `get_auth_config()` 在插件里维护,不由别的模块缓存
|
||||||
|
4. `AuthMiddleware`、`route_guard`、`AuthzHooks` 都由插件自己提供
|
||||||
|
|
||||||
|
这是后续做插件化时最重要的结构前提。
|
||||||
|
|
||||||
|
## 7. 当前作用范围与非目标
|
||||||
|
|
||||||
|
就当前实现而言,`app.plugins` 的作用范围主要是:
|
||||||
|
|
||||||
|
1. 为应用侧可拆分能力建立模块边界
|
||||||
|
2. 让插件拥有自己的 domain/storage/runtime/adapter
|
||||||
|
3. 通过装配式接缝接入宿主应用
|
||||||
|
|
||||||
|
当前非目标也很明确:
|
||||||
|
|
||||||
|
1. 还不是一个完整的通用插件发现/安装系统
|
||||||
|
2. 还没有做到运行时动态启停插件
|
||||||
|
3. 也不是把共享基础设施完全复制进每个插件
|
||||||
|
|
||||||
|
所以“热插拔”在当前阶段更准确的含义是:
|
||||||
|
|
||||||
|
1. 插件边界尽量独立
|
||||||
|
2. 接入点尽量集中在装配层
|
||||||
|
3. 替换或移除时,改动尽量局限在 `registrar`、`router include`、`app.state` hooks 这些少数位置
|
||||||
|
|
||||||
|
## 8. 后续演进建议
|
||||||
|
|
||||||
|
如果后续要继续把 `app.plugins` 做成更稳定的插件边界,建议保持这些规则:
|
||||||
|
|
||||||
|
1. 每个插件目录内部都保持 `domain` / `storage` / `runtime` / `adapter` 分层
|
||||||
|
2. 插件自己的 ORM 与 repository 不要下沉到共享业务目录
|
||||||
|
3. 插件向外提供服务时优先暴露 protocol、hook、router、middleware,而不是要求外部 import 内部实现细节
|
||||||
|
4. 插件与宿主应用的接缝尽量限制在:
|
||||||
|
- `router.include_router(...)`
|
||||||
|
- `app.add_middleware(...)`
|
||||||
|
- `app.state.*`
|
||||||
|
- 生命周期装配
|
||||||
|
5. 配置驱动优先于散落的硬编码接入
|
||||||
|
6. 启动期校验优先于运行时隐式失败
|
||||||
|
|
||||||
|
## 9. 设计总结
|
||||||
|
|
||||||
|
可以把当前 `app.plugins` 的契约总结为一句话:
|
||||||
|
|
||||||
|
插件内部拥有自己的业务实现、ORM 和 runtime;宿主应用只提供共享基础设施和装配接缝;对外服务时尽量通过低侵入、可替换的方式接入,以便后续做到真正的热插拔和边界演进。
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Application plugin packages."""
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
# Auth Plugin
|
||||||
|
|
||||||
|
This package is the future Level 2 auth plugin boundary for DeerFlow.
|
||||||
|
|
||||||
|
Scope:
|
||||||
|
|
||||||
|
- Auth domain logic: config, errors, models, JWT, password hashing, service
|
||||||
|
- Auth adapters: HTTP router, FastAPI dependencies, middleware, LangGraph adapter
|
||||||
|
- Auth storage: user/account models and repositories
|
||||||
|
|
||||||
|
Non-scope:
|
||||||
|
|
||||||
|
- Shared app/container bootstrap
|
||||||
|
- Shared persistence engine/session lifecycle
|
||||||
|
- Generic plugin discovery/registration framework
|
||||||
|
|
||||||
|
Target architecture:
|
||||||
|
|
||||||
|
- The plugin owns its storage definitions and business logic
|
||||||
|
- The plugin reuses the application's shared persistence infrastructure
|
||||||
|
- The gateway only assembles the plugin instead of owning auth logic directly
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""Auth plugin package.
|
||||||
|
|
||||||
|
Level 2 plugin goal:
|
||||||
|
- Own auth domain logic
|
||||||
|
- Own auth adapters (router, dependencies, middleware, LangGraph adapter)
|
||||||
|
- Own auth storage definitions
|
||||||
|
- Reuse the application's shared persistence/session infrastructure
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.hooks import build_authz_hooks
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_authz_hooks",
|
||||||
|
]
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""HTTP API layer for the auth plugin."""
|
||||||
|
|
||||||
|
from app.plugins.auth.api.router import (
|
||||||
|
ChangePasswordRequest,
|
||||||
|
LoginResponse,
|
||||||
|
MessageResponse,
|
||||||
|
RegisterRequest,
|
||||||
|
router,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChangePasswordRequest",
|
||||||
|
"LoginResponse",
|
||||||
|
"MessageResponse",
|
||||||
|
"RegisterRequest",
|
||||||
|
"router",
|
||||||
|
]
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
"""Authentication endpoints for the auth plugin."""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
from app.plugins.auth.api.schemas import (
|
||||||
|
ChangePasswordRequest,
|
||||||
|
InitializeAdminRequest,
|
||||||
|
LoginResponse,
|
||||||
|
MessageResponse,
|
||||||
|
RegisterRequest,
|
||||||
|
_check_rate_limit,
|
||||||
|
_get_client_ip,
|
||||||
|
_login_attempts,
|
||||||
|
_record_login_failure,
|
||||||
|
_record_login_success,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.domain.errors import AuthErrorResponse
|
||||||
|
from app.plugins.auth.domain.jwt import create_access_token
|
||||||
|
from app.plugins.auth.domain.models import UserResponse
|
||||||
|
from app.plugins.auth.domain.service import AuthServiceError
|
||||||
|
from app.plugins.auth.runtime.config_state import get_auth_config
|
||||||
|
from app.plugins.auth.security.csrf import is_secure_request
|
||||||
|
from app.plugins.auth.security.dependencies import CurrentAuthService, get_current_user_from_request
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
||||||
|
config = get_auth_config()
|
||||||
|
is_https = is_secure_request(request)
|
||||||
|
response.set_cookie(
|
||||||
|
key="access_token",
|
||||||
|
value=token,
|
||||||
|
httponly=True,
|
||||||
|
secure=is_https,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login/local", response_model=LoginResponse)
|
||||||
|
async def login_local(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
auth_service: CurrentAuthService,
|
||||||
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
):
|
||||||
|
client_ip = _get_client_ip(request)
|
||||||
|
_check_rate_limit(client_ip)
|
||||||
|
try:
|
||||||
|
user = await auth_service.login_local(form_data.username, form_data.password)
|
||||||
|
except AuthServiceError as exc:
|
||||||
|
_record_login_failure(client_ip)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
_record_login_success(client_ip)
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
return LoginResponse(
|
||||||
|
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
||||||
|
needs_setup=user.needs_setup,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(request: Request, response: Response, body: RegisterRequest, auth_service: CurrentAuthService):
|
||||||
|
try:
|
||||||
|
user = await auth_service.register(body.email, body.password)
|
||||||
|
except AuthServiceError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout", response_model=MessageResponse)
|
||||||
|
async def logout(request: Request, response: Response):
|
||||||
|
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
||||||
|
return MessageResponse(message="Successfully logged out")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/change-password", response_model=MessageResponse)
|
||||||
|
async def change_password(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
body: ChangePasswordRequest,
|
||||||
|
auth_service: CurrentAuthService,
|
||||||
|
):
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
try:
|
||||||
|
user = await auth_service.change_password(
|
||||||
|
user,
|
||||||
|
current_password=body.current_password,
|
||||||
|
new_password=body.new_password,
|
||||||
|
new_email=body.new_email,
|
||||||
|
)
|
||||||
|
except AuthServiceError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
return MessageResponse(message="Password changed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_me(request: Request):
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/setup-status")
|
||||||
|
async def setup_status(auth_service: CurrentAuthService):
|
||||||
|
return {"needs_setup": await auth_service.get_setup_status()}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def initialize_admin(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
body: InitializeAdminRequest,
|
||||||
|
auth_service: CurrentAuthService,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
user = await auth_service.initialize_admin(body.email, body.password)
|
||||||
|
except AuthServiceError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/oauth/{provider}")
|
||||||
|
async def oauth_login(provider: str):
|
||||||
|
if provider not in ["github", "google"]:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported OAuth provider: {provider}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="OAuth login not yet implemented")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/callback/{provider}")
|
||||||
|
async def oauth_callback(provider: str, code: str, state: str):
|
||||||
|
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="OAuth callback not yet implemented")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChangePasswordRequest",
|
||||||
|
"InitializeAdminRequest",
|
||||||
|
"LoginResponse",
|
||||||
|
"MessageResponse",
|
||||||
|
"RegisterRequest",
|
||||||
|
"_check_rate_limit",
|
||||||
|
"_get_client_ip",
|
||||||
|
"_login_attempts",
|
||||||
|
"_record_login_failure",
|
||||||
|
"_record_login_success",
|
||||||
|
"router",
|
||||||
|
]
|
||||||
@@ -0,0 +1,176 @@
|
|||||||
|
"""HTTP schemas and request helpers for the auth plugin API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from ipaddress import ip_address, ip_network
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||||
|
|
||||||
|
_COMMON_PASSWORDS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"password",
|
||||||
|
"password1",
|
||||||
|
"password12",
|
||||||
|
"password123",
|
||||||
|
"password1234",
|
||||||
|
"12345678",
|
||||||
|
"123456789",
|
||||||
|
"1234567890",
|
||||||
|
"qwerty12",
|
||||||
|
"qwertyui",
|
||||||
|
"qwerty123",
|
||||||
|
"abc12345",
|
||||||
|
"abcd1234",
|
||||||
|
"iloveyou",
|
||||||
|
"letmein1",
|
||||||
|
"welcome1",
|
||||||
|
"welcome123",
|
||||||
|
"admin123",
|
||||||
|
"administrator",
|
||||||
|
"passw0rd",
|
||||||
|
"p@ssw0rd",
|
||||||
|
"monkey12",
|
||||||
|
"trustno1",
|
||||||
|
"sunshine",
|
||||||
|
"princess",
|
||||||
|
"football",
|
||||||
|
"baseball",
|
||||||
|
"superman",
|
||||||
|
"batman123",
|
||||||
|
"starwars",
|
||||||
|
"dragon123",
|
||||||
|
"master123",
|
||||||
|
"shadow12",
|
||||||
|
"michael1",
|
||||||
|
"jennifer",
|
||||||
|
"computer",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
_MAX_LOGIN_ATTEMPTS = 5
|
||||||
|
_LOCKOUT_SECONDS = 300
|
||||||
|
_MAX_TRACKED_IPS = 10000
|
||||||
|
_login_attempts: dict[str, tuple[int, float]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class LoginResponse(BaseModel):
|
||||||
|
expires_in: int
|
||||||
|
needs_setup: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str = Field(..., min_length=8)
|
||||||
|
|
||||||
|
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||||
|
|
||||||
|
|
||||||
|
class ChangePasswordRequest(BaseModel):
|
||||||
|
current_password: str
|
||||||
|
new_password: str = Field(..., min_length=8)
|
||||||
|
new_email: EmailStr | None = None
|
||||||
|
|
||||||
|
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||||
|
|
||||||
|
|
||||||
|
class MessageResponse(BaseModel):
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class InitializeAdminRequest(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str = Field(..., min_length=8)
|
||||||
|
|
||||||
|
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||||
|
|
||||||
|
|
||||||
|
def _password_is_common(password: str) -> bool:
|
||||||
|
return password.lower() in _COMMON_PASSWORDS
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_strong_password(value: str) -> str:
|
||||||
|
if _password_is_common(value):
|
||||||
|
raise ValueError("Password is too common; choose a stronger password.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _trusted_proxies() -> list:
|
||||||
|
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
|
||||||
|
if not raw:
|
||||||
|
return []
|
||||||
|
nets = []
|
||||||
|
for entry in raw.split(","):
|
||||||
|
entry = entry.strip()
|
||||||
|
if not entry:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
nets.append(ip_network(entry, strict=False))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return nets
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client_ip(request: Request) -> str:
|
||||||
|
peer_host = request.client.host if request.client else None
|
||||||
|
trusted = _trusted_proxies()
|
||||||
|
if trusted and peer_host:
|
||||||
|
try:
|
||||||
|
peer_ip = ip_address(peer_host)
|
||||||
|
if any(peer_ip in net for net in trusted):
|
||||||
|
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||||
|
if real_ip:
|
||||||
|
return real_ip
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return peer_host or "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_rate_limit(ip: str) -> None:
|
||||||
|
record = _login_attempts.get(ip)
|
||||||
|
if record is None:
|
||||||
|
return
|
||||||
|
fail_count, lock_until = record
|
||||||
|
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
||||||
|
if time.time() < lock_until:
|
||||||
|
raise HTTPException(status_code=429, detail="Too many login attempts. Try again later.")
|
||||||
|
del _login_attempts[ip]
|
||||||
|
|
||||||
|
|
||||||
|
def _record_login_failure(ip: str) -> None:
|
||||||
|
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||||
|
now = time.time()
|
||||||
|
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
||||||
|
for key in expired:
|
||||||
|
del _login_attempts[key]
|
||||||
|
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||||
|
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
||||||
|
for key, _ in by_time[: len(by_time) // 2]:
|
||||||
|
del _login_attempts[key]
|
||||||
|
|
||||||
|
record = _login_attempts.get(ip)
|
||||||
|
if record is None:
|
||||||
|
_login_attempts[ip] = (1, 0.0)
|
||||||
|
else:
|
||||||
|
new_count = record[0] + 1
|
||||||
|
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
||||||
|
_login_attempts[ip] = (new_count, lock_until)
|
||||||
|
|
||||||
|
|
||||||
|
def _record_login_success(ip: str) -> None:
|
||||||
|
_login_attempts.pop(ip, None)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChangePasswordRequest",
|
||||||
|
"InitializeAdminRequest",
|
||||||
|
"LoginResponse",
|
||||||
|
"MessageResponse",
|
||||||
|
"RegisterRequest",
|
||||||
|
"_check_rate_limit",
|
||||||
|
"_get_client_ip",
|
||||||
|
"_login_attempts",
|
||||||
|
"_record_login_failure",
|
||||||
|
"_record_login_success",
|
||||||
|
]
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
"""Authorization layer for the auth plugin."""
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.authentication import get_auth_context
|
||||||
|
from app.plugins.auth.authorization.hooks import (
|
||||||
|
AuthzHooks,
|
||||||
|
build_authz_hooks,
|
||||||
|
build_permission_provider,
|
||||||
|
build_policy_chain_builder,
|
||||||
|
get_authz_hooks,
|
||||||
|
get_default_authz_hooks,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.authorization.types import (
|
||||||
|
AuthContext,
|
||||||
|
Permissions,
|
||||||
|
ALL_PERMISSIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ALL_PERMISSIONS = ALL_PERMISSIONS
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuthContext",
|
||||||
|
"AuthzHooks",
|
||||||
|
"Permissions",
|
||||||
|
"_ALL_PERMISSIONS",
|
||||||
|
"build_authz_hooks",
|
||||||
|
"build_permission_provider",
|
||||||
|
"build_policy_chain_builder",
|
||||||
|
"get_auth_context",
|
||||||
|
"get_authz_hooks",
|
||||||
|
"get_default_authz_hooks",
|
||||||
|
]
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
"""Authentication helpers used by auth-plugin authorization decorators."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.providers import PermissionProvider, default_permission_provider
|
||||||
|
from app.plugins.auth.authorization.types import AuthContext
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_context(request: Request) -> AuthContext | None:
|
||||||
|
"""Get AuthContext, preferring Starlette-style request.auth."""
|
||||||
|
|
||||||
|
auth = request.scope.get("auth")
|
||||||
|
if isinstance(auth, AuthContext):
|
||||||
|
return auth
|
||||||
|
return getattr(request.state, "auth", None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_auth_context(request: Request, auth_context: AuthContext) -> AuthContext:
|
||||||
|
"""Persist AuthContext on the standard request surfaces."""
|
||||||
|
|
||||||
|
request.scope["auth"] = auth_context
|
||||||
|
request.state.auth = auth_context
|
||||||
|
return auth_context
|
||||||
|
|
||||||
|
|
||||||
|
async def authenticate_request(
|
||||||
|
request: Request,
|
||||||
|
*,
|
||||||
|
permission_provider: PermissionProvider = default_permission_provider,
|
||||||
|
) -> AuthContext:
|
||||||
|
"""Authenticate request and build AuthContext."""
|
||||||
|
|
||||||
|
from app.plugins.auth.security.dependencies import get_optional_user_from_request
|
||||||
|
|
||||||
|
user = await get_optional_user_from_request(request)
|
||||||
|
if user is None:
|
||||||
|
return AuthContext(user=None, permissions=[])
|
||||||
|
return AuthContext(user=user, permissions=permission_provider(user))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["authenticate_request", "get_auth_context", "set_auth_context"]
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""Authorization requirement and policy evaluation helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable, Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.policies import require_thread_owner
|
||||||
|
from app.plugins.auth.authorization.types import AuthContext
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PermissionRequirement:
|
||||||
|
"""Authorization requirement for a single route action."""
|
||||||
|
|
||||||
|
resource: str
|
||||||
|
action: str
|
||||||
|
owner_check: bool = False
|
||||||
|
require_existing: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def permission(self) -> str:
|
||||||
|
return f"{self.resource}:{self.action}"
|
||||||
|
|
||||||
|
|
||||||
|
PolicyEvaluator = Callable[[Request, AuthContext, PermissionRequirement, Mapping[str, Any]], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_authenticated(auth: AuthContext) -> None:
|
||||||
|
if not auth.is_authenticated:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_capability(auth: AuthContext, requirement: PermissionRequirement) -> None:
|
||||||
|
if not auth.has_permission(requirement.resource, requirement.action):
|
||||||
|
raise HTTPException(status_code=403, detail=f"Permission denied: {requirement.permission}")
|
||||||
|
|
||||||
|
|
||||||
|
async def evaluate_owner_policy(
|
||||||
|
request: Request,
|
||||||
|
auth: AuthContext,
|
||||||
|
requirement: PermissionRequirement,
|
||||||
|
route_params: Mapping[str, Any],
|
||||||
|
) -> None:
|
||||||
|
if not requirement.owner_check:
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_id = route_params.get("thread_id")
|
||||||
|
if thread_id is None:
|
||||||
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
|
|
||||||
|
await require_thread_owner(
|
||||||
|
request,
|
||||||
|
auth,
|
||||||
|
thread_id=thread_id,
|
||||||
|
require_existing=requirement.require_existing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def evaluate_requirement(
|
||||||
|
request: Request,
|
||||||
|
auth: AuthContext,
|
||||||
|
requirement: PermissionRequirement,
|
||||||
|
route_params: Mapping[str, Any],
|
||||||
|
*,
|
||||||
|
policy_evaluators: tuple[PolicyEvaluator, ...],
|
||||||
|
) -> None:
|
||||||
|
ensure_authenticated(auth)
|
||||||
|
ensure_capability(auth, requirement)
|
||||||
|
for evaluator in policy_evaluators:
|
||||||
|
await evaluator(request, auth, requirement, route_params)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PermissionRequirement",
|
||||||
|
"PolicyEvaluator",
|
||||||
|
"ensure_authenticated",
|
||||||
|
"ensure_capability",
|
||||||
|
"evaluate_owner_policy",
|
||||||
|
"evaluate_requirement",
|
||||||
|
]
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
"""Auth-plugin authz extension hooks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.providers import PermissionProvider, default_permission_provider
|
||||||
|
from app.plugins.auth.authorization.registry import PolicyChainBuilder, build_default_policy_evaluators
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AuthzHooks:
|
||||||
|
"""Extension hooks for permission and policy resolution."""
|
||||||
|
|
||||||
|
permission_provider: PermissionProvider = default_permission_provider
|
||||||
|
policy_chain_builder: PolicyChainBuilder = build_default_policy_evaluators
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_AUTHZ_HOOKS = AuthzHooks()
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_authz_hooks() -> AuthzHooks:
|
||||||
|
return DEFAULT_AUTHZ_HOOKS
|
||||||
|
|
||||||
|
|
||||||
|
def get_authz_hooks(request: Request | Any | None = None) -> AuthzHooks:
|
||||||
|
if request is not None:
|
||||||
|
app = getattr(request, "app", None)
|
||||||
|
state = getattr(app, "state", None)
|
||||||
|
hooks = getattr(state, "authz_hooks", None)
|
||||||
|
if isinstance(hooks, AuthzHooks):
|
||||||
|
return hooks
|
||||||
|
return DEFAULT_AUTHZ_HOOKS
|
||||||
|
|
||||||
|
|
||||||
|
def build_permission_provider() -> PermissionProvider:
|
||||||
|
return default_permission_provider
|
||||||
|
|
||||||
|
|
||||||
|
def build_policy_chain_builder() -> PolicyChainBuilder:
|
||||||
|
return build_default_policy_evaluators
|
||||||
|
|
||||||
|
|
||||||
|
def build_authz_hooks() -> AuthzHooks:
|
||||||
|
return AuthzHooks(
|
||||||
|
permission_provider=build_permission_provider(),
|
||||||
|
policy_chain_builder=build_policy_chain_builder(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuthzHooks",
|
||||||
|
"DEFAULT_AUTHZ_HOOKS",
|
||||||
|
"build_authz_hooks",
|
||||||
|
"build_permission_provider",
|
||||||
|
"build_policy_chain_builder",
|
||||||
|
"get_authz_hooks",
|
||||||
|
"get_default_authz_hooks",
|
||||||
|
]
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
"""Authorization policies for resource ownership and access checks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.types import AuthContext
|
||||||
|
|
||||||
|
|
||||||
|
def _get_thread_owner_id(thread_meta: Any) -> str | None:
|
||||||
|
owner_id = getattr(thread_meta, "user_id", None)
|
||||||
|
if owner_id is not None:
|
||||||
|
return str(owner_id)
|
||||||
|
|
||||||
|
metadata = getattr(thread_meta, "metadata", None) or {}
|
||||||
|
metadata_owner_id = metadata.get("user_id")
|
||||||
|
if metadata_owner_id is not None:
|
||||||
|
return str(metadata_owner_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _thread_exists_via_legacy_sources(request: Request, auth: AuthContext, *, thread_id: str) -> bool:
|
||||||
|
from app.gateway.dependencies.repositories import get_run_repository
|
||||||
|
|
||||||
|
principal_id = auth.principal_id
|
||||||
|
run_store = get_run_repository(request)
|
||||||
|
runs = await run_store.list_by_thread(
|
||||||
|
thread_id,
|
||||||
|
limit=1,
|
||||||
|
user_id=principal_id,
|
||||||
|
)
|
||||||
|
if runs:
|
||||||
|
return True
|
||||||
|
|
||||||
|
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||||
|
if checkpointer is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
checkpoint_tuple = await checkpointer.aget_tuple(
|
||||||
|
{"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
|
)
|
||||||
|
return checkpoint_tuple is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def require_thread_owner(
|
||||||
|
request: Request,
|
||||||
|
auth: AuthContext,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
require_existing: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure the current user owns the thread referenced by ``thread_id``."""
|
||||||
|
|
||||||
|
from app.gateway.dependencies.repositories import get_thread_meta_repository
|
||||||
|
|
||||||
|
thread_repo = get_thread_meta_repository(request)
|
||||||
|
thread_meta = await thread_repo.get_thread_meta(thread_id)
|
||||||
|
if thread_meta is None:
|
||||||
|
allowed = not require_existing
|
||||||
|
if not allowed:
|
||||||
|
allowed = await _thread_exists_via_legacy_sources(request, auth, thread_id=thread_id)
|
||||||
|
else:
|
||||||
|
owner_id = _get_thread_owner_id(thread_meta)
|
||||||
|
allowed = owner_id in (None, str(auth.user.id))
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Thread {thread_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_run_owner(
|
||||||
|
request: Request,
|
||||||
|
auth: AuthContext,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
require_existing: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure the current user owns the run referenced by ``run_id``."""
|
||||||
|
|
||||||
|
from app.gateway.dependencies import get_run_repository
|
||||||
|
|
||||||
|
run_store = get_run_repository(request)
|
||||||
|
run = await run_store.get(run_id)
|
||||||
|
if run is None:
|
||||||
|
allowed = not require_existing
|
||||||
|
else:
|
||||||
|
allowed = run.get("thread_id") == thread_id
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Run {run_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["require_run_owner", "require_thread_owner"]
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""Default permission provider hooks for auth-plugin authorization."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.types import ALL_PERMISSIONS
|
||||||
|
|
||||||
|
PermissionProvider = Callable[[object], list[str]]
|
||||||
|
|
||||||
|
|
||||||
|
def default_permission_provider(user: object) -> list[str]:
|
||||||
|
"""Return the current static permission set for an authenticated user."""
|
||||||
|
|
||||||
|
return list(ALL_PERMISSIONS)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["PermissionProvider", "default_permission_provider"]
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""Registry/build helpers for default authorization evaluators."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.plugins.auth.authorization.authorization import PolicyEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
PolicyChainBuilder = Callable[[], tuple["PolicyEvaluator", ...]]
|
||||||
|
|
||||||
|
|
||||||
|
def build_default_policy_evaluators() -> tuple["PolicyEvaluator", ...]:
|
||||||
|
"""Return the default policy chain for auth-plugin authorization."""
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.authorization import evaluate_owner_policy
|
||||||
|
|
||||||
|
return (evaluate_owner_policy,)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["PolicyChainBuilder", "build_default_policy_evaluators"]
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
"""Authorization context and capability constants for the auth plugin."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.plugins.auth.domain.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class Permissions:
|
||||||
|
"""Permission constants for resource:action format."""
|
||||||
|
|
||||||
|
THREADS_READ = "threads:read"
|
||||||
|
THREADS_WRITE = "threads:write"
|
||||||
|
THREADS_DELETE = "threads:delete"
|
||||||
|
|
||||||
|
RUNS_CREATE = "runs:create"
|
||||||
|
RUNS_READ = "runs:read"
|
||||||
|
RUNS_CANCEL = "runs:cancel"
|
||||||
|
|
||||||
|
|
||||||
|
class AuthContext:
|
||||||
|
"""Authentication context for the current request."""
|
||||||
|
|
||||||
|
__slots__ = ("user", "permissions")
|
||||||
|
|
||||||
|
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
|
||||||
|
self.user = user
|
||||||
|
self.permissions = permissions or []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_authenticated(self) -> bool:
|
||||||
|
return self.user is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def principal_id(self) -> str | None:
|
||||||
|
if self.user is None:
|
||||||
|
return None
|
||||||
|
return str(self.user.id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capabilities(self) -> tuple[str, ...]:
|
||||||
|
return tuple(self.permissions)
|
||||||
|
|
||||||
|
def has_permission(self, resource: str, action: str) -> bool:
|
||||||
|
return f"{resource}:{action}" in self.permissions
|
||||||
|
|
||||||
|
def require_user(self) -> User:
|
||||||
|
if not self.user:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
return self.user
|
||||||
|
|
||||||
|
|
||||||
|
ALL_PERMISSIONS: list[str] = [
|
||||||
|
Permissions.THREADS_READ,
|
||||||
|
Permissions.THREADS_WRITE,
|
||||||
|
Permissions.THREADS_DELETE,
|
||||||
|
Permissions.RUNS_CREATE,
|
||||||
|
Permissions.RUNS_READ,
|
||||||
|
Permissions.RUNS_CANCEL,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ALL_PERMISSIONS", "AuthContext", "Permissions"]
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
"""Domain layer for the auth plugin."""
|
||||||
|
|
||||||
|
from app.plugins.auth.domain.config import AuthConfig, load_auth_config_from_env
|
||||||
|
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||||
|
from app.plugins.auth.domain.jwt import TokenPayload, create_access_token, decode_token
|
||||||
|
from app.plugins.auth.domain.models import User, UserResponse
|
||||||
|
from app.plugins.auth.domain.password import hash_password, hash_password_async, verify_password, verify_password_async
|
||||||
|
from app.plugins.auth.domain.service import AuthService, AuthServiceError
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuthConfig",
|
||||||
|
"AuthErrorCode",
|
||||||
|
"AuthErrorResponse",
|
||||||
|
"AuthService",
|
||||||
|
"AuthServiceError",
|
||||||
|
"TokenError",
|
||||||
|
"TokenPayload",
|
||||||
|
"User",
|
||||||
|
"UserResponse",
|
||||||
|
"create_access_token",
|
||||||
|
"decode_token",
|
||||||
|
"hash_password",
|
||||||
|
"hash_password_async",
|
||||||
|
"load_auth_config_from_env",
|
||||||
|
"token_error_to_code",
|
||||||
|
"verify_password",
|
||||||
|
"verify_password_async",
|
||||||
|
]
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
"""Auth configuration schema and environment loader."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthConfig(BaseModel):
|
||||||
|
"""JWT and auth-related configuration."""
|
||||||
|
|
||||||
|
jwt_secret: str = Field(..., description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.")
|
||||||
|
token_expiry_days: int = Field(default=7, ge=1, le=30)
|
||||||
|
oauth_github_client_id: str | None = Field(default=None)
|
||||||
|
oauth_github_client_secret: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def load_auth_config_from_env() -> AuthConfig:
|
||||||
|
"""Build an auth config from environment variables."""
|
||||||
|
|
||||||
|
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||||
|
if not jwt_secret:
|
||||||
|
jwt_secret = secrets.token_urlsafe(32)
|
||||||
|
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||||
|
logger.warning(
|
||||||
|
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||||
|
"Sessions will be invalidated on restart. "
|
||||||
|
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||||
|
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||||
|
)
|
||||||
|
return AuthConfig(jwt_secret=jwt_secret)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["AuthConfig", "load_auth_config_from_env"]
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
"""Typed error definitions for auth plugin."""
|
||||||
|
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AuthErrorCode(StrEnum):
|
||||||
|
INVALID_CREDENTIALS = "invalid_credentials"
|
||||||
|
TOKEN_EXPIRED = "token_expired"
|
||||||
|
TOKEN_INVALID = "token_invalid"
|
||||||
|
USER_NOT_FOUND = "user_not_found"
|
||||||
|
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||||
|
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||||
|
NOT_AUTHENTICATED = "not_authenticated"
|
||||||
|
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
||||||
|
|
||||||
|
|
||||||
|
class TokenError(StrEnum):
|
||||||
|
EXPIRED = "expired"
|
||||||
|
INVALID_SIGNATURE = "invalid_signature"
|
||||||
|
MALFORMED = "malformed"
|
||||||
|
|
||||||
|
|
||||||
|
class AuthErrorResponse(BaseModel):
|
||||||
|
code: AuthErrorCode
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
def token_error_to_code(err: TokenError) -> AuthErrorCode:
|
||||||
|
if err == TokenError.EXPIRED:
|
||||||
|
return AuthErrorCode.TOKEN_EXPIRED
|
||||||
|
return AuthErrorCode.TOKEN_INVALID
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""JWT token creation and verification."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.plugins.auth.domain.errors import TokenError
|
||||||
|
from app.plugins.auth.runtime.config_state import get_auth_config
|
||||||
|
|
||||||
|
|
||||||
|
class TokenPayload(BaseModel):
|
||||||
|
sub: str
|
||||||
|
exp: datetime
|
||||||
|
iat: datetime | None = None
|
||||||
|
ver: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
|
||||||
|
config = get_auth_config()
|
||||||
|
expiry = expires_delta or timedelta(days=config.token_expiry_days)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
|
||||||
|
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token: str) -> TokenPayload | TokenError:
|
||||||
|
config = get_auth_config()
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
|
||||||
|
return TokenPayload(**payload)
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
return TokenError.EXPIRED
|
||||||
|
except jwt.InvalidSignatureError:
|
||||||
|
return TokenError.INVALID_SIGNATURE
|
||||||
|
except jwt.PyJWTError:
|
||||||
|
return TokenError.MALFORMED
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
"""User Pydantic models for the auth plugin."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Literal
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||||
|
|
||||||
|
|
||||||
|
def _utc_now() -> datetime:
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: UUID = Field(default_factory=uuid4, description="Primary key")
|
||||||
|
email: EmailStr = Field(..., description="Unique email address")
|
||||||
|
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
|
||||||
|
system_role: Literal["admin", "user"] = Field(default="user")
|
||||||
|
created_at: datetime = Field(default_factory=_utc_now)
|
||||||
|
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
|
||||||
|
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
||||||
|
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
|
||||||
|
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
system_role: Literal["admin", "user"]
|
||||||
|
needs_setup: bool = False
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
"""Password hashing utilities using bcrypt directly."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
async def hash_password_async(password: str) -> str:
|
||||||
|
return await asyncio.to_thread(hash_password, password)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from http import HTTPStatus
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from app.plugins.auth.domain.errors import AuthErrorCode
|
||||||
|
from app.plugins.auth.domain.models import User
|
||||||
|
from app.plugins.auth.domain.password import hash_password_async, verify_password_async
|
||||||
|
from app.plugins.auth.storage import DbUserRepository, UserCreate
|
||||||
|
from app.plugins.auth.storage.contracts import User as StoreUser
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AuthServiceError(Exception):
|
||||||
|
code: AuthErrorCode
|
||||||
|
message: str
|
||||||
|
status_code: int
|
||||||
|
|
||||||
|
|
||||||
|
def _to_auth_user(user: StoreUser) -> User:
|
||||||
|
return User(
|
||||||
|
id=UUID(user.id),
|
||||||
|
email=user.email,
|
||||||
|
password_hash=user.password_hash,
|
||||||
|
system_role=user.system_role, # type: ignore[arg-type]
|
||||||
|
created_at=user.created_time,
|
||||||
|
oauth_provider=user.oauth_provider,
|
||||||
|
oauth_id=user.oauth_id,
|
||||||
|
needs_setup=user.needs_setup,
|
||||||
|
token_version=user.token_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_store_user(user: User) -> StoreUser:
|
||||||
|
return StoreUser(
|
||||||
|
id=str(user.id),
|
||||||
|
email=user.email,
|
||||||
|
password_hash=user.password_hash,
|
||||||
|
system_role=user.system_role,
|
||||||
|
oauth_provider=user.oauth_provider,
|
||||||
|
oauth_id=user.oauth_id,
|
||||||
|
needs_setup=user.needs_setup,
|
||||||
|
token_version=user.token_version,
|
||||||
|
created_time=user.created_at,
|
||||||
|
updated_time=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthService:
|
||||||
|
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
|
||||||
|
async def login_local(self, email: str, password: str) -> User:
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
user = await repo.get_user_by_email(email)
|
||||||
|
if user is None or user.password_hash is None:
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="Incorrect email or password",
|
||||||
|
status_code=HTTPStatus.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
if not await verify_password_async(password, user.password_hash):
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="Incorrect email or password",
|
||||||
|
status_code=HTTPStatus.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
return _to_auth_user(user)
|
||||||
|
|
||||||
|
async def register(self, email: str, password: str) -> User:
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
try:
|
||||||
|
user = await repo.create_user(
|
||||||
|
UserCreate(
|
||||||
|
email=email,
|
||||||
|
password_hash=await hash_password_async(password),
|
||||||
|
system_role="user",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
except ValueError as exc:
|
||||||
|
await session.rollback()
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.EMAIL_ALREADY_EXISTS,
|
||||||
|
message="Email already registered",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
) from exc
|
||||||
|
return _to_auth_user(user)
|
||||||
|
|
||||||
|
async def change_password(
|
||||||
|
self,
|
||||||
|
user: User | StoreUser,
|
||||||
|
*,
|
||||||
|
current_password: str,
|
||||||
|
new_password: str,
|
||||||
|
new_email: str | None = None,
|
||||||
|
) -> User:
|
||||||
|
if user.password_hash is None:
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="OAuth users cannot change password",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
)
|
||||||
|
if not await verify_password_async(current_password, user.password_hash):
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="Current password is incorrect",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
updated_email = user.email
|
||||||
|
if new_email is not None:
|
||||||
|
existing = await repo.get_user_by_email(new_email)
|
||||||
|
if existing and existing.id != str(user.id):
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.EMAIL_ALREADY_EXISTS,
|
||||||
|
message="Email already in use",
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
)
|
||||||
|
updated_email = new_email
|
||||||
|
|
||||||
|
updated_user = user.model_copy(
|
||||||
|
update={
|
||||||
|
"email": updated_email,
|
||||||
|
"password_hash": await hash_password_async(new_password),
|
||||||
|
"token_version": user.token_version + 1,
|
||||||
|
"needs_setup": False if user.needs_setup and new_email is not None else user.needs_setup,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
updated = await repo.update_user(_to_store_user(_to_auth_user(updated_user) if isinstance(updated_user, StoreUser) else updated_user))
|
||||||
|
await session.commit()
|
||||||
|
return _to_auth_user(updated)
|
||||||
|
|
||||||
|
async def get_setup_status(self) -> bool:
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
admin_count = await repo.count_admin_users()
|
||||||
|
return admin_count == 0
|
||||||
|
|
||||||
|
async def initialize_admin(self, email: str, password: str) -> User:
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
admin_count = await repo.count_admin_users()
|
||||||
|
if admin_count > 0:
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED,
|
||||||
|
message="System already initialized",
|
||||||
|
status_code=HTTPStatus.CONFLICT,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
user = await repo.create_user(
|
||||||
|
UserCreate(
|
||||||
|
email=email,
|
||||||
|
password_hash=await hash_password_async(password),
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
except ValueError as exc:
|
||||||
|
await session.rollback()
|
||||||
|
raise AuthServiceError(
|
||||||
|
code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED,
|
||||||
|
message="System already initialized",
|
||||||
|
status_code=HTTPStatus.CONFLICT,
|
||||||
|
) from exc
|
||||||
|
return _to_auth_user(user)
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""Config-driven route authorization injection for the auth plugin."""
|
||||||
|
|
||||||
|
from app.plugins.auth.injection.registry_loader import (
|
||||||
|
RoutePolicyRegistry,
|
||||||
|
RoutePolicySpec,
|
||||||
|
load_route_policy_registry,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.injection.route_injector import install_route_guards
|
||||||
|
from app.plugins.auth.injection.validation import validate_route_policy_registry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RoutePolicyRegistry",
|
||||||
|
"RoutePolicySpec",
|
||||||
|
"install_route_guards",
|
||||||
|
"load_route_policy_registry",
|
||||||
|
"validate_route_policy_registry",
|
||||||
|
]
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
"""Load auth route policies from the plugin's YAML registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from starlette.routing import compile_path
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
_POLICY_FILE = Path(__file__).resolve().parents[1] / "route_policies.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RoutePolicySpec:
|
||||||
|
public: bool = False
|
||||||
|
capability: str | None = None
|
||||||
|
policies: tuple[str, ...] = ()
|
||||||
|
require_existing: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RoutePolicyEntry:
|
||||||
|
method: str
|
||||||
|
path: str
|
||||||
|
spec: RoutePolicySpec
|
||||||
|
path_regex: object = field(repr=False)
|
||||||
|
|
||||||
|
def matches_request(self, method: str, path: str) -> bool:
|
||||||
|
if self.method != method.upper():
|
||||||
|
return False
|
||||||
|
return self.path_regex.match(path) is not None
|
||||||
|
|
||||||
|
|
||||||
|
class RoutePolicyRegistry:
|
||||||
|
def __init__(self, entries: list[RoutePolicyEntry]) -> None:
|
||||||
|
self._entries = entries
|
||||||
|
self._specs = {(entry.method, entry.path): entry.spec for entry in entries}
|
||||||
|
|
||||||
|
def get(self, method: str, path_template: str) -> RoutePolicySpec | None:
|
||||||
|
return self._specs.get((method.upper(), path_template))
|
||||||
|
|
||||||
|
def has(self, method: str, path_template: str) -> bool:
|
||||||
|
return (method.upper(), path_template) in self._specs
|
||||||
|
|
||||||
|
def match_request(self, method: str, path: str) -> RoutePolicySpec | None:
|
||||||
|
normalized_method = method.upper()
|
||||||
|
for entry in self._entries:
|
||||||
|
if entry.matches_request(normalized_method, path):
|
||||||
|
return entry.spec
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_public_request(self, method: str, path: str) -> bool:
|
||||||
|
spec = self.match_request(method, path)
|
||||||
|
return bool(spec and spec.public)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self) -> set[tuple[str, str]]:
|
||||||
|
return set(self._specs)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_methods(item: dict) -> tuple[str, ...]:
|
||||||
|
methods = item.get("methods")
|
||||||
|
if methods is None:
|
||||||
|
methods = [item["method"]]
|
||||||
|
if isinstance(methods, str):
|
||||||
|
methods = [methods]
|
||||||
|
return tuple(str(method).upper() for method in methods)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_spec(item: dict) -> RoutePolicySpec:
|
||||||
|
return RoutePolicySpec(
|
||||||
|
public=bool(item.get("public", False)),
|
||||||
|
capability=item.get("capability"),
|
||||||
|
policies=tuple(item.get("policies", [])),
|
||||||
|
require_existing=bool(item.get("require_existing", True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_route_policy_registry() -> RoutePolicyRegistry:
|
||||||
|
payload = yaml.safe_load(_POLICY_FILE.read_text(encoding="utf-8")) or {}
|
||||||
|
raw_routes: list[dict] = []
|
||||||
|
for section, entries in payload.items():
|
||||||
|
if section == "routes":
|
||||||
|
if isinstance(entries, list):
|
||||||
|
raw_routes.extend(entries)
|
||||||
|
continue
|
||||||
|
if not isinstance(entries, list):
|
||||||
|
continue
|
||||||
|
for item in entries:
|
||||||
|
normalized = dict(item)
|
||||||
|
if section == "public":
|
||||||
|
normalized["public"] = True
|
||||||
|
raw_routes.append(normalized)
|
||||||
|
entries: list[RoutePolicyEntry] = []
|
||||||
|
for item in raw_routes:
|
||||||
|
path = str(item["path"])
|
||||||
|
spec = _build_spec(item)
|
||||||
|
path_regex, _, _ = compile_path(path)
|
||||||
|
for method in _normalize_methods(item):
|
||||||
|
entries.append(
|
||||||
|
RoutePolicyEntry(
|
||||||
|
method=method,
|
||||||
|
path=path,
|
||||||
|
spec=spec,
|
||||||
|
path_regex=path_regex,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return RoutePolicyRegistry(entries)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["RoutePolicyRegistry", "RoutePolicySpec", "load_route_policy_registry"]
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
"""Runtime route guard backed by the auth plugin's route policy registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization.authentication import (
|
||||||
|
authenticate_request,
|
||||||
|
get_auth_context,
|
||||||
|
set_auth_context,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.authorization.authorization import ensure_authenticated
|
||||||
|
from app.plugins.auth.authorization.hooks import get_authz_hooks
|
||||||
|
from app.plugins.auth.authorization.policies import require_run_owner, require_thread_owner
|
||||||
|
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry, RoutePolicySpec
|
||||||
|
|
||||||
|
PolicyGuard = Callable[[Request, RoutePolicySpec], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_capability(request: Request, spec: RoutePolicySpec) -> None:
|
||||||
|
if not spec.capability:
|
||||||
|
return
|
||||||
|
|
||||||
|
auth = get_auth_context(request)
|
||||||
|
if auth is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Missing auth context")
|
||||||
|
|
||||||
|
if ":" not in spec.capability:
|
||||||
|
raise RuntimeError(f"Invalid capability format: {spec.capability}")
|
||||||
|
resource, action = spec.capability.split(":", 1)
|
||||||
|
if not auth.has_permission(resource, action):
|
||||||
|
raise HTTPException(status_code=403, detail=f"Permission denied: {spec.capability}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _guard_thread_owner(request: Request, spec: RoutePolicySpec) -> None:
|
||||||
|
auth = get_auth_context(request)
|
||||||
|
if auth is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Missing auth context")
|
||||||
|
thread_id = request.path_params.get("thread_id")
|
||||||
|
if not isinstance(thread_id, str):
|
||||||
|
raise RuntimeError("owner:thread policy requires thread_id path parameter")
|
||||||
|
await require_thread_owner(request, auth, thread_id=thread_id, require_existing=spec.require_existing)
|
||||||
|
|
||||||
|
|
||||||
|
async def _guard_run_owner(request: Request, spec: RoutePolicySpec) -> None:
|
||||||
|
auth = get_auth_context(request)
|
||||||
|
if auth is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Missing auth context")
|
||||||
|
thread_id = request.path_params.get("thread_id")
|
||||||
|
run_id = request.path_params.get("run_id")
|
||||||
|
if not isinstance(thread_id, str) or not isinstance(run_id, str):
|
||||||
|
raise RuntimeError("owner:run policy requires thread_id and run_id path parameters")
|
||||||
|
await require_run_owner(
|
||||||
|
request,
|
||||||
|
auth,
|
||||||
|
thread_id=thread_id,
|
||||||
|
run_id=run_id,
|
||||||
|
require_existing=spec.require_existing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_POLICY_GUARDS: dict[str, PolicyGuard] = {
|
||||||
|
"owner:thread": _guard_thread_owner,
|
||||||
|
"owner:run": _guard_run_owner,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def enforce_route_policy(request: Request) -> None:
|
||||||
|
registry = getattr(request.app.state, "auth_route_policy_registry", None)
|
||||||
|
if not isinstance(registry, RoutePolicyRegistry):
|
||||||
|
raise RuntimeError("Auth route policy registry is not configured")
|
||||||
|
|
||||||
|
route = request.scope.get("route")
|
||||||
|
path_template = getattr(route, "path", None)
|
||||||
|
if not isinstance(path_template, str):
|
||||||
|
raise RuntimeError("Unable to resolve route path for authorization")
|
||||||
|
|
||||||
|
spec = registry.get(request.method, path_template)
|
||||||
|
if spec is None:
|
||||||
|
raise RuntimeError(f"Missing auth route policy for {request.method} {path_template}")
|
||||||
|
if spec.public:
|
||||||
|
return
|
||||||
|
|
||||||
|
auth = get_auth_context(request)
|
||||||
|
if auth is None:
|
||||||
|
hooks = get_authz_hooks(request)
|
||||||
|
auth = await authenticate_request(request, permission_provider=hooks.permission_provider)
|
||||||
|
set_auth_context(request, auth)
|
||||||
|
|
||||||
|
ensure_authenticated(auth)
|
||||||
|
await _check_capability(request, spec)
|
||||||
|
|
||||||
|
for policy_name in spec.policies:
|
||||||
|
guard = _POLICY_GUARDS.get(policy_name)
|
||||||
|
if guard is None:
|
||||||
|
raise RuntimeError(f"Unknown route policy guard: {policy_name}")
|
||||||
|
await guard(request, spec)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["enforce_route_policy"]
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
"""Inject config-driven auth guards into FastAPI routes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from fastapi.dependencies.utils import get_dependant, get_flat_dependant, get_parameterless_sub_dependant
|
||||||
|
from fastapi.routing import APIRoute, _should_embed_body_fields, get_body_field, request_response
|
||||||
|
|
||||||
|
from app.plugins.auth.injection.route_guard import enforce_route_policy
|
||||||
|
|
||||||
|
|
||||||
|
def _rebuild_route(route: APIRoute) -> None:
|
||||||
|
route.dependant = get_dependant(path=route.path_format, call=route.endpoint, scope="function")
|
||||||
|
for depends in route.dependencies[::-1]:
|
||||||
|
route.dependant.dependencies.insert(
|
||||||
|
0,
|
||||||
|
get_parameterless_sub_dependant(depends=depends, path=route.path_format),
|
||||||
|
)
|
||||||
|
route._flat_dependant = get_flat_dependant(route.dependant)
|
||||||
|
route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
|
||||||
|
route.body_field = get_body_field(
|
||||||
|
flat_dependant=route._flat_dependant,
|
||||||
|
name=route.unique_id,
|
||||||
|
embed_body_fields=route._embed_body_fields,
|
||||||
|
)
|
||||||
|
route.app = request_response(route.get_route_handler())
|
||||||
|
|
||||||
|
|
||||||
|
def install_route_guards(app: FastAPI) -> None:
|
||||||
|
for route in app.routes:
|
||||||
|
if not isinstance(route, APIRoute):
|
||||||
|
continue
|
||||||
|
if any(getattr(dependency, "dependency", None) is enforce_route_policy for dependency in route.dependencies):
|
||||||
|
continue
|
||||||
|
route.dependencies.append(Depends(enforce_route_policy))
|
||||||
|
_rebuild_route(route)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["install_route_guards"]
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
"""Validation helpers for config-driven auth route policies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
|
|
||||||
|
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry
|
||||||
|
|
||||||
|
_IGNORED_METHODS = frozenset({"HEAD", "OPTIONS"})
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_route_keys(app: FastAPI) -> set[tuple[str, str]]:
|
||||||
|
keys: set[tuple[str, str]] = set()
|
||||||
|
for route in app.routes:
|
||||||
|
if not isinstance(route, APIRoute):
|
||||||
|
continue
|
||||||
|
for method in route.methods:
|
||||||
|
if method in _IGNORED_METHODS:
|
||||||
|
continue
|
||||||
|
keys.add((method, route.path))
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
def validate_route_policy_registry(app: FastAPI, registry: RoutePolicyRegistry) -> None:
|
||||||
|
route_keys = _iter_route_keys(app)
|
||||||
|
missing = sorted(route_keys - registry.keys)
|
||||||
|
extra = sorted(registry.keys - route_keys)
|
||||||
|
problems: list[str] = []
|
||||||
|
if missing:
|
||||||
|
problems.append("Missing route policy entries:\n" + "\n".join(f" - {method} {path}" for method, path in missing))
|
||||||
|
if extra:
|
||||||
|
problems.append("Unknown route policy entries:\n" + "\n".join(f" - {method} {path}" for method, path in extra))
|
||||||
|
if problems:
|
||||||
|
raise RuntimeError("\n\n".join(problems))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["validate_route_policy_registry"]
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Operational tooling for the auth plugin."""
|
||||||
|
|
||||||
|
from app.plugins.auth.ops.credential_file import write_initial_credentials
|
||||||
|
from app.plugins.auth.ops.reset_admin import main
|
||||||
|
|
||||||
|
__all__ = ["main", "write_initial_credentials"]
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
"""Write initial admin credentials to a restricted file instead of logs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from deerflow.config.paths import get_paths
|
||||||
|
|
||||||
|
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
|
||||||
|
|
||||||
|
|
||||||
|
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
|
||||||
|
target = get_paths().base_dir / _CREDENTIAL_FILENAME
|
||||||
|
target.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
content = (
|
||||||
|
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||||
|
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||||
|
fh.write(content)
|
||||||
|
|
||||||
|
return target.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["write_initial_credentials"]
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
"""CLI tool to reset an admin password."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import secrets
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.plugins.auth.domain.password import hash_password
|
||||||
|
from app.plugins.auth.ops.credential_file import write_initial_credentials
|
||||||
|
from app.plugins.auth.storage import DbUserRepository
|
||||||
|
from app.plugins.auth.storage.models import User as UserModel
|
||||||
|
|
||||||
|
|
||||||
|
async def _run(email: str | None) -> int:
|
||||||
|
from store.persistence import create_persistence
|
||||||
|
|
||||||
|
app_persistence = await create_persistence()
|
||||||
|
await app_persistence.setup()
|
||||||
|
try:
|
||||||
|
if email:
|
||||||
|
async with app_persistence.session_factory() as session:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
user = await repo.get_user_by_email(email)
|
||||||
|
else:
|
||||||
|
async with app_persistence.session_factory() as session:
|
||||||
|
stmt = select(UserModel).where(UserModel.system_role == "admin").limit(1)
|
||||||
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
user = None
|
||||||
|
else:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
user = await repo.get_user_by_id(row.id)
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
print(f"Error: user '{email}' not found." if email else "Error: no admin user found.", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
new_password = secrets.token_urlsafe(16)
|
||||||
|
updated_user = user.model_copy(
|
||||||
|
update={
|
||||||
|
"password_hash": hash_password(new_password),
|
||||||
|
"token_version": user.token_version + 1,
|
||||||
|
"needs_setup": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
async with app_persistence.session_factory() as session:
|
||||||
|
repo = DbUserRepository(session)
|
||||||
|
await repo.update_user(updated_user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
cred_path = write_initial_credentials(user.email, new_password, label="reset")
|
||||||
|
print(f"Password reset for: {user.email}")
|
||||||
|
print(f"Credentials written to: {cred_path} (mode 0600)")
|
||||||
|
print("Next login will require setup (new email + password).")
|
||||||
|
return 0
|
||||||
|
finally:
|
||||||
|
await app_persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="Reset admin password")
|
||||||
|
parser.add_argument("--email", help="Admin email (default: first admin found)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
exit_code = asyncio.run(_run(args.email))
|
||||||
|
sys.exit(exit_code)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
[plugin]
|
||||||
|
name = "auth"
|
||||||
|
summary = "Cookie-based authentication and authorization"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Owns DeerFlow authentication, authorization adapters, and auth storage definitions while reusing shared persistence infrastructure."
|
||||||
|
author = "DeerFlow"
|
||||||
|
tags = ["auth", "gateway", "session"]
|
||||||
|
|
||||||
|
[capabilities]
|
||||||
|
router = true
|
||||||
|
middleware = true
|
||||||
|
dependencies = true
|
||||||
|
langgraph_adapter = true
|
||||||
|
storage = true
|
||||||
|
|
||||||
|
[storage]
|
||||||
|
mode = "shared_infrastructure"
|
||||||
|
notes = "This plugin owns its storage definitions and repositories but uses the application's shared engine and session factory."
|
||||||
@@ -0,0 +1,204 @@
|
|||||||
|
public:
|
||||||
|
- method: POST
|
||||||
|
path: /api/v1/auth/login/local
|
||||||
|
- method: POST
|
||||||
|
path: /api/v1/auth/register
|
||||||
|
- method: POST
|
||||||
|
path: /api/v1/auth/logout
|
||||||
|
- method: GET
|
||||||
|
path: /api/v1/auth/setup-status
|
||||||
|
- method: POST
|
||||||
|
path: /api/v1/auth/initialize
|
||||||
|
- method: GET
|
||||||
|
path: /api/v1/auth/oauth/{provider}
|
||||||
|
- method: GET
|
||||||
|
path: /api/v1/auth/callback/{provider}
|
||||||
|
- method: GET
|
||||||
|
path: /docs
|
||||||
|
|
||||||
|
auth:
|
||||||
|
- method: POST
|
||||||
|
path: /api/v1/auth/change-password
|
||||||
|
- method: GET
|
||||||
|
path: /api/v1/auth/me
|
||||||
|
|
||||||
|
threads:
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads
|
||||||
|
capability: threads:write
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/search
|
||||||
|
capability: threads:read
|
||||||
|
- method: DELETE
|
||||||
|
path: /api/threads/{thread_id}
|
||||||
|
capability: threads:delete
|
||||||
|
policies: [owner:thread]
|
||||||
|
require_existing: false
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/state
|
||||||
|
capability: threads:read
|
||||||
|
policies: [owner:thread]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/state
|
||||||
|
capability: threads:write
|
||||||
|
policies: [owner:thread]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/history
|
||||||
|
capability: threads:read
|
||||||
|
policies: [owner:thread]
|
||||||
|
|
||||||
|
runs:
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/runs
|
||||||
|
capability: runs:read
|
||||||
|
policies: [owner:thread]
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}
|
||||||
|
capability: runs:read
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/messages
|
||||||
|
capability: runs:read
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/runs
|
||||||
|
capability: runs:create
|
||||||
|
policies: [owner:thread]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/runs/stream
|
||||||
|
capability: runs:create
|
||||||
|
policies: [owner:thread]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/runs/wait
|
||||||
|
capability: runs:create
|
||||||
|
policies: [owner:thread]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/runs
|
||||||
|
capability: runs:create
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/runs/stream
|
||||||
|
capability: runs:create
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/runs/wait
|
||||||
|
capability: runs:create
|
||||||
|
- methods: [GET, POST]
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/stream
|
||||||
|
capability: runs:read
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/join
|
||||||
|
capability: runs:read
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/cancel
|
||||||
|
capability: runs:cancel
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: DELETE
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}
|
||||||
|
capability: runs:cancel
|
||||||
|
policies: [owner:run]
|
||||||
|
|
||||||
|
feedback:
|
||||||
|
- method: PUT
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/feedback/stats
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: DELETE
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||||
|
policies: [owner:run]
|
||||||
|
- method: DELETE
|
||||||
|
path: /api/threads/{thread_id}/runs/{run_id}/feedback/{feedback_id}
|
||||||
|
policies: [owner:run]
|
||||||
|
|
||||||
|
suggestions:
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/suggestions
|
||||||
|
capability: threads:read
|
||||||
|
policies: [owner:thread]
|
||||||
|
|
||||||
|
uploads:
|
||||||
|
- method: POST
|
||||||
|
path: /api/threads/{thread_id}/uploads
|
||||||
|
capability: threads:write
|
||||||
|
policies: [owner:thread]
|
||||||
|
require_existing: false
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/uploads/list
|
||||||
|
capability: threads:read
|
||||||
|
policies: [owner:thread]
|
||||||
|
- method: DELETE
|
||||||
|
path: /api/threads/{thread_id}/uploads/{filename}
|
||||||
|
capability: threads:delete
|
||||||
|
policies: [owner:thread]
|
||||||
|
|
||||||
|
artifacts:
|
||||||
|
- method: GET
|
||||||
|
path: /api/threads/{thread_id}/artifacts/{path:path}
|
||||||
|
capability: threads:read
|
||||||
|
policies: [owner:thread]
|
||||||
|
|
||||||
|
agents:
|
||||||
|
- method: GET
|
||||||
|
path: /api/agents
|
||||||
|
- method: GET
|
||||||
|
path: /api/agents/check
|
||||||
|
- method: GET
|
||||||
|
path: /api/agents/{name}
|
||||||
|
- method: POST
|
||||||
|
path: /api/agents
|
||||||
|
- method: PUT
|
||||||
|
path: /api/agents/{name}
|
||||||
|
- method: GET
|
||||||
|
path: /api/user-profile
|
||||||
|
- method: PUT
|
||||||
|
path: /api/user-profile
|
||||||
|
- method: DELETE
|
||||||
|
path: /api/agents/{name}
|
||||||
|
|
||||||
|
channels:
|
||||||
|
- method: GET
|
||||||
|
path: /api/channels/
|
||||||
|
- method: POST
|
||||||
|
path: /api/channels/{name}/restart
|
||||||
|
|
||||||
|
mcp:
|
||||||
|
- method: GET
|
||||||
|
path: /api/mcp/config
|
||||||
|
- method: PUT
|
||||||
|
path: /api/mcp/config
|
||||||
|
|
||||||
|
models:
|
||||||
|
- method: GET
|
||||||
|
path: /api/models
|
||||||
|
- method: GET
|
||||||
|
path: /api/models/{model_name}
|
||||||
|
|
||||||
|
skills:
|
||||||
|
- method: GET
|
||||||
|
path: /api/skills
|
||||||
|
- method: POST
|
||||||
|
path: /api/skills/install
|
||||||
|
- method: GET
|
||||||
|
path: /api/skills/custom
|
||||||
|
- method: GET
|
||||||
|
path: /api/skills/custom/{skill_name}
|
||||||
|
- method: PUT
|
||||||
|
path: /api/skills/custom/{skill_name}
|
||||||
|
- method: DELETE
|
||||||
|
path: /api/skills/custom/{skill_name}
|
||||||
|
- method: GET
|
||||||
|
path: /api/skills/custom/{skill_name}/history
|
||||||
|
- method: POST
|
||||||
|
path: /api/skills/custom/{skill_name}/rollback
|
||||||
|
- method: GET
|
||||||
|
path: /api/skills/{skill_name}
|
||||||
|
- method: PUT
|
||||||
|
path: /api/skills/{skill_name}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""Runtime state utilities for the auth plugin."""
|
||||||
|
|
||||||
|
from app.plugins.auth.runtime.config_state import get_auth_config, reset_auth_config, set_auth_config
|
||||||
|
|
||||||
|
__all__ = ["get_auth_config", "reset_auth_config", "set_auth_config"]
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
"""Runtime state holder for auth configuration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.plugins.auth.domain.config import AuthConfig, load_auth_config_from_env
|
||||||
|
|
||||||
|
_auth_config: AuthConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_config() -> AuthConfig:
|
||||||
|
global _auth_config
|
||||||
|
if _auth_config is None:
|
||||||
|
_auth_config = load_auth_config_from_env()
|
||||||
|
return _auth_config
|
||||||
|
|
||||||
|
|
||||||
|
def set_auth_config(config: AuthConfig) -> None:
|
||||||
|
global _auth_config
|
||||||
|
_auth_config = config
|
||||||
|
|
||||||
|
|
||||||
|
def reset_auth_config() -> None:
|
||||||
|
global _auth_config
|
||||||
|
_auth_config = None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["get_auth_config", "reset_auth_config", "set_auth_config"]
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""Security layer for the auth plugin."""
|
||||||
|
|
||||||
|
from app.plugins.auth.security.actor_context import (
|
||||||
|
bind_request_actor_context,
|
||||||
|
bind_user_actor_context,
|
||||||
|
resolve_request_user_id,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.security.csrf import (
|
||||||
|
CSRF_COOKIE_NAME,
|
||||||
|
CSRF_HEADER_NAME,
|
||||||
|
CSRFMiddleware,
|
||||||
|
get_csrf_token,
|
||||||
|
is_secure_request,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.security.dependencies import (
|
||||||
|
CurrentAuthService,
|
||||||
|
CurrentUserRepository,
|
||||||
|
get_auth_service,
|
||||||
|
get_current_user_from_request,
|
||||||
|
get_current_user_id,
|
||||||
|
get_optional_user_from_request,
|
||||||
|
get_user_repository,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.security.langgraph import add_owner_filter, auth, authenticate
|
||||||
|
from app.plugins.auth.security.middleware import AuthMiddleware
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CSRF_COOKIE_NAME",
|
||||||
|
"CSRF_HEADER_NAME",
|
||||||
|
"CSRFMiddleware",
|
||||||
|
"AuthMiddleware",
|
||||||
|
"CurrentAuthService",
|
||||||
|
"CurrentUserRepository",
|
||||||
|
"add_owner_filter",
|
||||||
|
"auth",
|
||||||
|
"authenticate",
|
||||||
|
"bind_request_actor_context",
|
||||||
|
"bind_user_actor_context",
|
||||||
|
"get_auth_service",
|
||||||
|
"get_csrf_token",
|
||||||
|
"get_current_user_from_request",
|
||||||
|
"get_current_user_id",
|
||||||
|
"get_optional_user_from_request",
|
||||||
|
"get_user_repository",
|
||||||
|
"is_secure_request",
|
||||||
|
"resolve_request_user_id",
|
||||||
|
]
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
"""Auth-plugin bridge from request user to runtime actor context."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from deerflow.runtime.actor_context import ActorContext, bind_actor_context, reset_actor_context
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_request_user_id(request: Request) -> str | None:
|
||||||
|
scope = getattr(request, "scope", None)
|
||||||
|
user = scope.get("user") if isinstance(scope, dict) else None
|
||||||
|
if user is None:
|
||||||
|
state = getattr(request, "state", None)
|
||||||
|
state_vars = vars(state) if state is not None and hasattr(state, "__dict__") else {}
|
||||||
|
user = state_vars.get("user")
|
||||||
|
user_id = getattr(user, "id", None)
|
||||||
|
if user_id is None:
|
||||||
|
return None
|
||||||
|
return str(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def bind_request_actor_context(request: Request):
|
||||||
|
token = bind_actor_context(ActorContext(user_id=resolve_request_user_id(request)))
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
reset_actor_context(token)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def bind_user_actor_context(user_id: str | None):
|
||||||
|
token = bind_actor_context(ActorContext(user_id=str(user_id) if user_id is not None else None))
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
reset_actor_context(token)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["bind_request_actor_context", "bind_user_actor_context", "resolve_request_user_id"]
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
"""CSRF protection middleware and helpers for cookie-based auth flows."""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
CSRF_COOKIE_NAME = "csrf_token"
|
||||||
|
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||||
|
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||||
|
|
||||||
|
|
||||||
|
def is_secure_request(request: Request) -> bool:
|
||||||
|
"""Detect whether the original client request was made over HTTPS."""
|
||||||
|
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_csrf_token() -> str:
|
||||||
|
"""Generate a secure random CSRF token."""
|
||||||
|
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||||
|
|
||||||
|
|
||||||
|
def should_check_csrf(request: Request) -> bool:
|
||||||
|
"""Determine if a request needs CSRF validation."""
|
||||||
|
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
path = request.url.path.rstrip("/")
|
||||||
|
if path == "/api/v1/auth/me":
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
"/api/v1/auth/logout",
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_endpoint(request: Request) -> bool:
|
||||||
|
"""Check if the request is to an auth endpoint."""
|
||||||
|
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||||
|
|
||||||
|
|
||||||
|
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Implement CSRF protection using the double-submit cookie pattern."""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
_is_auth = is_auth_endpoint(request)
|
||||||
|
|
||||||
|
if should_check_csrf(request) and not _is_auth:
|
||||||
|
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
|
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||||
|
|
||||||
|
if not cookie_token or not header_token:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not secrets.compare_digest(cookie_token, header_token):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "CSRF token mismatch."},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
if _is_auth and request.method == "POST":
|
||||||
|
csrf_token = generate_csrf_token()
|
||||||
|
response.set_cookie(
|
||||||
|
key=CSRF_COOKIE_NAME,
|
||||||
|
value=csrf_token,
|
||||||
|
httponly=False,
|
||||||
|
secure=is_secure_request(request),
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def get_csrf_token(request: Request) -> str | None:
|
||||||
|
"""Get the CSRF token from the current request's cookies."""
|
||||||
|
return request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CSRF_COOKIE_NAME",
|
||||||
|
"CSRF_HEADER_NAME",
|
||||||
|
"CSRFMiddleware",
|
||||||
|
"generate_csrf_token",
|
||||||
|
"get_csrf_token",
|
||||||
|
"is_auth_endpoint",
|
||||||
|
"is_secure_request",
|
||||||
|
"should_check_csrf",
|
||||||
|
]
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
"""Security dependency helpers for the auth plugin."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from app.plugins.auth.domain.errors import (
|
||||||
|
AuthErrorCode,
|
||||||
|
AuthErrorResponse,
|
||||||
|
TokenError,
|
||||||
|
token_error_to_code,
|
||||||
|
)
|
||||||
|
from app.plugins.auth.domain.jwt import decode_token
|
||||||
|
from app.plugins.auth.domain.service import AuthService
|
||||||
|
from app.plugins.auth.storage import DbUserRepository, UserRepositoryProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_factory(request: Request) -> async_sessionmaker[AsyncSession] | None:
|
||||||
|
persistence = getattr(request.app.state, "persistence", None)
|
||||||
|
if persistence is None:
|
||||||
|
return None
|
||||||
|
return getattr(persistence, "session_factory", None)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _auth_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||||
|
injected = getattr(request.state, "_auth_session", None)
|
||||||
|
if injected is not None:
|
||||||
|
yield injected
|
||||||
|
return
|
||||||
|
|
||||||
|
session_factory = _get_session_factory(request)
|
||||||
|
if session_factory is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Auth session not available")
|
||||||
|
|
||||||
|
async with session_factory() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_repository(request: Request) -> UserRepositoryProtocol:
|
||||||
|
async with _auth_session(request) as session:
|
||||||
|
return DbUserRepository(session)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_service(request: Request) -> AuthService:
|
||||||
|
session_factory = _get_session_factory(request)
|
||||||
|
if session_factory is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Auth session factory not available")
|
||||||
|
return AuthService(session_factory)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_from_request(request: Request):
|
||||||
|
access_token = request.cookies.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = decode_token(access_token)
|
||||||
|
if isinstance(payload, TokenError):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(
|
||||||
|
code=token_error_to_code(payload),
|
||||||
|
message=f"Token error: {payload.value}",
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async with _auth_session(request) as session:
|
||||||
|
user_repo = DbUserRepository(session)
|
||||||
|
user = await user_repo.get_user_by_id(payload.sub)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if user.token_version != payload.ver:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(
|
||||||
|
code=AuthErrorCode.TOKEN_INVALID,
|
||||||
|
message="Token revoked (password changed)",
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_optional_user_from_request(request: Request):
|
||||||
|
try:
|
||||||
|
return await get_current_user_from_request(request)
|
||||||
|
except HTTPException:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_id(request: Request) -> str | None:
|
||||||
|
user = await get_optional_user_from_request(request)
|
||||||
|
return user.id if user else None
|
||||||
|
|
||||||
|
|
||||||
|
CurrentUserRepository = Annotated[UserRepositoryProtocol, Depends(get_user_repository)]
|
||||||
|
CurrentAuthService = Annotated[AuthService, Depends(get_auth_service)]
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CurrentAuthService",
|
||||||
|
"CurrentUserRepository",
|
||||||
|
"get_auth_service",
|
||||||
|
"get_current_user_from_request",
|
||||||
|
"get_current_user_id",
|
||||||
|
"get_optional_user_from_request",
|
||||||
|
"get_user_repository",
|
||||||
|
]
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
"""LangGraph auth adapter for the auth plugin."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from langgraph_sdk import Auth
|
||||||
|
|
||||||
|
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||||
|
|
||||||
|
auth = Auth()
|
||||||
|
|
||||||
|
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||||
|
|
||||||
|
|
||||||
|
def _check_csrf(request) -> None:
|
||||||
|
method = getattr(request, "method", "") or ""
|
||||||
|
if method.upper() not in _CSRF_METHODS:
|
||||||
|
return
|
||||||
|
|
||||||
|
cookie_token = request.cookies.get("csrf_token")
|
||||||
|
header_token = request.headers.get("x-csrf-token")
|
||||||
|
|
||||||
|
if not cookie_token or not header_token:
|
||||||
|
raise Auth.exceptions.HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="CSRF token missing. Include X-CSRF-Token header.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not secrets.compare_digest(cookie_token, header_token):
|
||||||
|
raise Auth.exceptions.HTTPException(status_code=403, detail="CSRF token mismatch.")
|
||||||
|
|
||||||
|
|
||||||
|
@auth.authenticate
|
||||||
|
async def authenticate(request):
|
||||||
|
_check_csrf(request)
|
||||||
|
resolver_request = SimpleNamespace(
|
||||||
|
cookies=getattr(request, "cookies", {}),
|
||||||
|
state=SimpleNamespace(_auth_session=getattr(request, "_auth_session", None)),
|
||||||
|
app=SimpleNamespace(state=SimpleNamespace(persistence=getattr(request, "_persistence", None))),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await get_current_user_from_request(resolver_request)
|
||||||
|
except Exception as exc:
|
||||||
|
status_code = getattr(exc, "status_code", None)
|
||||||
|
if status_code is None:
|
||||||
|
raise
|
||||||
|
detail = getattr(exc, "detail", "Not authenticated")
|
||||||
|
message = detail.get("message") if isinstance(detail, dict) else str(detail)
|
||||||
|
raise Auth.exceptions.HTTPException(status_code=status_code, detail=message) from exc
|
||||||
|
|
||||||
|
return user.id
|
||||||
|
|
||||||
|
|
||||||
|
@auth.on
|
||||||
|
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||||
|
metadata = value.setdefault("metadata", {})
|
||||||
|
metadata["user_id"] = ctx.user.identity
|
||||||
|
return {"user_id": ctx.user.identity}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["add_owner_filter", "auth", "authenticate"]
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
"""Global authentication middleware for the auth plugin."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.plugins.auth.authorization import _ALL_PERMISSIONS, AuthContext
|
||||||
|
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse
|
||||||
|
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry
|
||||||
|
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||||
|
from deerflow.runtime.actor_context import ActorContext, bind_actor_context, reset_actor_context
|
||||||
|
|
||||||
|
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = ("/health", "/docs", "/redoc", "/openapi.json")
|
||||||
|
|
||||||
|
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/logout",
|
||||||
|
"/api/v1/auth/setup-status",
|
||||||
|
"/api/v1/auth/initialize",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_public(path: str) -> bool:
|
||||||
|
stripped = path.rstrip("/")
|
||||||
|
if stripped in _PUBLIC_EXACT_PATHS:
|
||||||
|
return True
|
||||||
|
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
registry = getattr(request.app.state, "auth_route_policy_registry", None)
|
||||||
|
is_public = False
|
||||||
|
if isinstance(registry, RoutePolicyRegistry):
|
||||||
|
is_public = registry.is_public_request(request.method, request.url.path)
|
||||||
|
if is_public or _is_public(request.url.path):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
if not request.cookies.get("access_token"):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={
|
||||||
|
"detail": AuthErrorResponse(
|
||||||
|
code=AuthErrorCode.NOT_AUTHENTICATED,
|
||||||
|
message="Authentication required",
|
||||||
|
).model_dump()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
except HTTPException as exc:
|
||||||
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
|
|
||||||
|
auth_context = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||||
|
request.scope["user"] = user
|
||||||
|
request.scope["auth"] = auth_context
|
||||||
|
request.state.user = user
|
||||||
|
request.state.auth = auth_context
|
||||||
|
token = bind_actor_context(ActorContext(user_id=str(user.id)))
|
||||||
|
try:
|
||||||
|
return await call_next(request)
|
||||||
|
finally:
|
||||||
|
reset_actor_context(token)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["AuthMiddleware", "_is_public"]
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""Auth plugin storage package.
|
||||||
|
|
||||||
|
This package owns auth-specific ORM models and repositories while
|
||||||
|
continuing to use the application's shared persistence infrastructure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.plugins.auth.storage.contracts import User, UserCreate, UserRepositoryProtocol
|
||||||
|
from app.plugins.auth.storage.models import User as UserModel
|
||||||
|
from app.plugins.auth.storage.repositories import DbUserRepository
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DbUserRepository",
|
||||||
|
"User",
|
||||||
|
"UserCreate",
|
||||||
|
"UserModel",
|
||||||
|
"UserRepositoryProtocol",
|
||||||
|
]
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user