mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
feat(auth): authentication module with multi-tenant isolation (RFC-001)
Introduce an always-on auth layer with auto-created admin on first boot, multi-tenant isolation for threads/stores, and a full setup/login flow. Backend - JWT access tokens with `ver` field for stale-token rejection; bump on password/email change - Password hashing, HttpOnly+Secure cookies (Secure derived from request scheme at runtime) - CSRF middleware covering both REST and LangGraph routes - IP-based login rate limiting (5 attempts / 5-min lockout) with bounded dict growth and X-Forwarded-For bypass fix - Multi-worker-safe admin auto-creation (single DB write, WAL once) - needs_setup + token_version on User model; SQLite schema migration - Thread/store isolation by owner; orphan thread migration on first admin registration - thread_id validated as UUID to prevent log injection - CLI tool to reset admin password - Decorator-based authz module extracted from auth core Frontend - Login and setup pages with SSR guard for needs_setup flow - Account settings page (change password / email) - AuthProvider + route guards; skips redirect when no users registered - i18n (en-US / zh-CN) for auth surfaces - Typed auth API client; parseAuthError unwraps FastAPI detail envelope Infra & tooling - Unified `serve.sh` with gateway mode + auto dep install - Public PyPI uv.toml pin for CI compatibility - Regenerated uv.lock with public index Tests - HTTP vs HTTPS cookie security tests - Auth middleware, rate limiter, CSRF, setup flow coverage
This commit is contained in:
@@ -13,17 +13,26 @@ matching the LangGraph Platform wire format expected by the
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import APIRouter, HTTPException, Path, Request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_auth, require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_store
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread ID validation (prevents log-injection via control characters)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
||||
ThreadId = Annotated[str, Path(description="Thread UUID", pattern=_UUID_RE.pattern)]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Store namespace
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -65,6 +74,13 @@ class ThreadCreateRequest(BaseModel):
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
@field_validator("thread_id")
|
||||
@classmethod
|
||||
def _validate_uuid(cls, v: str | None) -> str | None:
|
||||
if v is not None and not _UUID_RE.match(v):
|
||||
raise ValueError("thread_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
|
||||
class ThreadSearchRequest(BaseModel):
|
||||
"""Request body for searching threads."""
|
||||
@@ -215,17 +231,23 @@ def _derive_thread_status(checkpoint_tuple) -> str:
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "delete", owner_check=True)
|
||||
async def delete_thread_data(thread_id: ThreadId, request: Request) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread.
|
||||
|
||||
Cleans DeerFlow-managed thread directories, removes checkpoint data,
|
||||
and removes the thread record from the Store.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can delete their thread.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
# Clean local filesystem
|
||||
response = _delete_thread_data(thread_id)
|
||||
|
||||
# Remove from Store (best-effort)
|
||||
store = get_store(request)
|
||||
if store is not None:
|
||||
try:
|
||||
await store.adelete(THREADS_NS, thread_id)
|
||||
@@ -233,7 +255,6 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
||||
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
if hasattr(checkpointer, "adelete_thread"):
|
||||
@@ -251,12 +272,23 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
The thread record is written to the Store (for fast listing) and an
|
||||
empty checkpoint is written to the checkpointer (for state reads).
|
||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||
|
||||
If authenticated, the user's ID is injected into the thread metadata
|
||||
for multi-tenant isolation.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
|
||||
thread_metadata = dict(body.metadata)
|
||||
if user:
|
||||
thread_metadata["user_id"] = str(user.id)
|
||||
|
||||
# Idempotency: return existing record from Store when already present
|
||||
if store is not None:
|
||||
existing_record = await _store_get(store, thread_id)
|
||||
@@ -279,7 +311,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
"status": "idle",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"metadata": body.metadata,
|
||||
"metadata": thread_metadata,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
@@ -296,7 +328,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
"source": "input",
|
||||
"writes": None,
|
||||
"parents": {},
|
||||
**body.metadata,
|
||||
**thread_metadata,
|
||||
"created_at": now,
|
||||
}
|
||||
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
|
||||
@@ -304,13 +336,13 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
logger.exception("Failed to create checkpoint for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
logger.info("Thread created: %s", thread_id)
|
||||
logger.info("Thread created: %s (user_id=%s)", thread_id, thread_metadata.get("user_id"))
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status="idle",
|
||||
created_at=str(now),
|
||||
updated_at=str(now),
|
||||
metadata=body.metadata,
|
||||
metadata=thread_metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -330,10 +362,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
newly found thread is immediately written to the Store so that the next
|
||||
search skips Phase 2 for that thread — the Store converges to a full
|
||||
index over time without a one-shot migration job.
|
||||
|
||||
If authenticated, only threads belonging to the current user are returned
|
||||
(enforced by user_id metadata filter for multi-tenant isolation).
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
from app.gateway.deps import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
user_id = str(user.id) if user else None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 1: Store
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -409,6 +449,10 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
# -----------------------------------------------------------------------
|
||||
results = list(merged.values())
|
||||
|
||||
# Multi-tenant isolation: filter by user_id if authenticated
|
||||
if user_id:
|
||||
results = [r for r in results if r.metadata.get("user_id") == user_id]
|
||||
|
||||
if body.metadata:
|
||||
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
|
||||
|
||||
@@ -420,13 +464,20 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
||||
|
||||
|
||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record."""
|
||||
@require_auth
|
||||
@require_permission("threads", "write", owner_check=True, inject_record=True)
|
||||
async def patch_thread(thread_id: ThreadId, request: Request, body: ThreadPatchRequest, thread_record: dict = None) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can patch their thread.
|
||||
"""
|
||||
store = get_store(request)
|
||||
if store is None:
|
||||
raise HTTPException(status_code=503, detail="Store not available")
|
||||
|
||||
record = await _store_get(store, thread_id)
|
||||
record = thread_record
|
||||
if record is None:
|
||||
record = await _store_get(store, thread_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
@@ -451,12 +502,17 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
|
||||
|
||||
|
||||
@router.get("/{thread_id}", response_model=ThreadResponse)
|
||||
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread(thread_id: ThreadId, request: Request) -> ThreadResponse:
|
||||
"""Get thread info.
|
||||
|
||||
Reads metadata from the Store and derives the accurate execution
|
||||
status from the checkpointer. Falls back to the checkpointer alone
|
||||
for threads that pre-date Store adoption (backward compat).
|
||||
|
||||
Multi-tenant isolation: returns 404 if the thread does not belong to
|
||||
the authenticated user.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
@@ -488,26 +544,33 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||
}
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle") # type: ignore[union-attr]
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=status,
|
||||
created_at=str(record.get("created_at", "")), # type: ignore[union-attr]
|
||||
updated_at=str(record.get("updated_at", "")), # type: ignore[union-attr]
|
||||
metadata=record.get("metadata", {}), # type: ignore[union-attr]
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_state(thread_id: ThreadId, request: Request) -> ThreadStateResponse:
|
||||
"""Get the latest state snapshot for a thread.
|
||||
|
||||
Channel values are serialized to ensure LangChain message objects
|
||||
are converted to JSON-safe dicts.
|
||||
|
||||
Multi-tenant isolation: returns 404 if thread does not belong to user.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
@@ -552,12 +615,16 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
||||
|
||||
|
||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
||||
@require_auth
|
||||
@require_permission("threads", "write", owner_check=True)
|
||||
async def update_thread_state(thread_id: ThreadId, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
||||
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
|
||||
|
||||
Writes a new checkpoint that merges *body.values* into the latest
|
||||
channel values, then syncs any updated ``title`` field back to the Store
|
||||
so that ``/threads/search`` reflects the change immediately.
|
||||
|
||||
Multi-tenant isolation: only the thread owner can update their thread.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
store = get_store(request)
|
||||
@@ -635,8 +702,13 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
|
||||
|
||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread."""
|
||||
@require_auth
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def get_thread_history(thread_id: ThreadId, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread.
|
||||
|
||||
Multi-tenant isolation: returns 404 if thread does not belong to user.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
Reference in New Issue
Block a user