feat(auth): authentication module with multi-tenant isolation (RFC-001)

Introduce an always-on auth layer with auto-created admin on first boot,
multi-tenant isolation for threads/stores, and a full setup/login flow.

Backend
- JWT access tokens with `ver` field for stale-token rejection; bump on
  password/email change
- Password hashing, HttpOnly+Secure cookies (Secure derived from request
  scheme at runtime)
- CSRF middleware covering both REST and LangGraph routes
- IP-based login rate limiting (5 attempts / 5-min lockout) with bounded
  dict growth and X-Forwarded-For bypass fix
- Multi-worker-safe admin auto-creation (single DB write, WAL once)
- needs_setup + token_version on User model; SQLite schema migration
- Thread/store isolation by owner; orphan thread migration on first admin
  registration
- thread_id validated as UUID to prevent log injection
- CLI tool to reset admin password
- Decorator-based authz module extracted from auth core

Frontend
- Login and setup pages with SSR guard for needs_setup flow
- Account settings page (change password / email)
- AuthProvider + route guards; skips redirect when no users registered
- i18n (en-US / zh-CN) for auth surfaces
- Typed auth API client; parseAuthError unwraps FastAPI detail envelope

Infra & tooling
- Unified `serve.sh` with gateway mode + auto dep install
- Public PyPI uv.toml pin for CI compatibility
- Regenerated uv.lock with public index

Tests
- HTTP vs HTTPS cookie security tests
- Auth middleware, rate limiter, CSRF, setup flow coverage
This commit is contained in:
greatmengqi
2026-04-08 00:31:43 +08:00
parent 636053fb6d
commit 27b66d6753
214 changed files with 18830 additions and 1065 deletions
+2 -2
View File
@@ -1,3 +1,3 @@
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
from . import artifacts, assistants_compat, auth, mcp, models, skills, suggestions, thread_runs, threads, uploads
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
__all__ = ["artifacts", "assistants_compat", "auth", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
+4 -4
View File
@@ -24,7 +24,7 @@ class AgentResponse(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
soul: str | None = Field(default=None, description="SOUL.md content (included on GET /{name})")
soul: str | None = Field(default=None, description="SOUL.md content")
class AgentsListResponse(BaseModel):
@@ -92,17 +92,17 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
"/agents",
response_model=AgentsListResponse,
summary="List Custom Agents",
description="List all custom agents available in the agents directory.",
description="List all custom agents available in the agents directory, including their soul content.",
)
async def list_agents() -> AgentsListResponse:
"""List all custom agents.
Returns:
List of all custom agents with their metadata (without soul content).
List of all custom agents with their metadata and soul content.
"""
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a) for a in agents])
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
except Exception as e:
logger.error(f"Failed to list agents: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
+303
View File
@@ -0,0 +1,303 @@
"""Authentication endpoints."""
import logging
import time
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field
from app.gateway.auth import (
UserResponse,
create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ── Request/Response Models ──────────────────────────────────────────────
class LoginResponse(BaseModel):
"""Response model for login — token only lives in HttpOnly cookie."""
expires_in: int # seconds
needs_setup: bool = False
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
class ChangePasswordRequest(BaseModel):
"""Request model for password change (also handles setup flow)."""
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
"""Set the access_token HttpOnly cookie on the response."""
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
)
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}
def _get_client_ip(request: Request) -> str:
"""Extract the real client IP for rate limiting.
Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
$remote_addr``). Nginx unconditionally overwrites any client-supplied
``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
that nginx observed — it cannot be spoofed by the client.
``request.client.host`` is NOT reliable because uvicorn's default
``proxy_headers=True`` replaces it with the *first* entry from
``X-Forwarded-For``, which IS client-spoofable.
``X-Forwarded-For`` is intentionally NOT used for the same reason.
"""
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
# Fallback: direct connection without nginx (e.g. unit tests, dev).
return request.client.host if request.client else "unknown"
def _check_rate_limit(ip: str) -> None:
"""Raise 429 if the IP is currently locked out."""
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(
status_code=429,
detail="Too many login attempts. Try again later.",
)
del _login_attempts[ip]
_MAX_TRACKED_IPS = 10000
def _record_login_failure(ip: str) -> None:
"""Record a failed login attempt for the given IP."""
# Evict expired lockouts when dict grows too large
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for k in expired:
del _login_attempts[k]
# If still too large, evict cheapest-to-lose half: below-threshold
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for k, _ in by_time[: len(by_time) // 2]:
del _login_attempts[k]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
"""Clear failure counter for the given IP on successful login."""
_login_attempts.pop(ip, None)
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""Local email/password login."""
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
if user is None:
_record_login_failure(client_ip)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
)
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
"""Logout current user by clearing the cookie."""
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
"""Change password for the currently authenticated user.
Also handles the first-boot setup flow:
- If new_email is provided, updates email (checks uniqueness)
- If user.needs_setup is True and new_email is given, clears needs_setup
- Always increments token_version to invalidate old sessions
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
user = await get_current_user_from_request(request)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
if not await verify_password_async(body.current_password, user.password_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
provider = get_local_provider()
# Update email if provided
if body.new_email is not None:
existing = await provider.get_user_by_email(body.new_email)
if existing and str(existing.id) != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
user.email = body.new_email
# Update password + bump version
user.password_hash = await hash_password_async(body.new_password)
user.token_version += 1
# Clear setup flag if this is the setup flow
if user.needs_setup and body.new_email is not None:
user.needs_setup = False
await provider.update_user(user)
# Re-issue cookie with new token_version
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
"""Get current authenticated user info."""
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
@router.get("/setup-status")
async def setup_status():
"""Check if admin account exists. Always False after first boot."""
user_count = await get_local_provider().count_users()
return {"needs_setup": user_count == 0}
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
"""Initiate OAuth login flow.
Redirects to the OAuth provider's authorization URL.
Currently a placeholder - requires OAuth provider implementation.
"""
if provider not in ["github", "google"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider}",
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth login not yet implemented",
)
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
"""OAuth callback endpoint.
Handles the OAuth provider's callback after user authorization.
Currently a placeholder.
"""
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth callback not yet implemented",
)
+6 -6
View File
@@ -2,6 +2,7 @@ import json
import logging
from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from deerflow.models import create_chat_model
@@ -106,22 +107,21 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
if not conversation:
return SuggestionsResponse(suggestions=[])
prompt = (
system_instruction = (
"You are generating follow-up questions to help the user continue the conversation.\n"
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
"Requirements:\n"
"- Questions must be relevant to the conversation.\n"
"- Questions must be relevant to the preceding conversation.\n"
"- Questions must be written in the same language as the user.\n"
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
"- Do NOT include numbering, markdown, or any extra text.\n"
"- Output MUST be a JSON array of strings only.\n\n"
"Conversation:\n"
f"{conversation}\n"
"- Output MUST be a JSON array of strings only.\n"
)
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = model.invoke(prompt)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
+39 -5
View File
@@ -19,6 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.authz import require_auth, require_permission
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
@@ -92,19 +93,28 @@ def _record_to_response(record: RunRecord) -> RunResponse:
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_auth
@require_permission("runs", "create", owner_check=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
"""Create a background run (returns immediately).
Multi-tenant isolation: only the thread owner can create runs.
"""
record = await start_run(body, thread_id, request)
return _record_to_response(record)
@router.post("/{thread_id}/runs/stream")
@require_auth
@require_permission("runs", "create", owner_check=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
The response includes a ``Content-Location`` header with the run's
resource URL, matching the LangGraph Platform protocol. The
``useStream`` React hook uses this to extract run metadata.
Multi-tenant isolation: only the thread owner can stream runs.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
@@ -125,8 +135,13 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_auth
@require_permission("runs", "create", owner_check=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
"""Create a run and block until it completes, returning the final state.
Multi-tenant isolation: only the thread owner can wait for runs.
"""
record = await start_run(body, thread_id, request)
if record.task is not None:
@@ -150,16 +165,26 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_auth
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
"""List all runs for a thread.
Multi-tenant isolation: only the thread owner can list runs.
"""
run_mgr = get_run_manager(request)
records = await run_mgr.list_by_thread(thread_id)
return [_record_to_response(r) for r in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_auth
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
"""Get details of a specific run.
Multi-tenant isolation: only the thread owner can get runs.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
@@ -168,6 +193,8 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_auth
@require_permission("runs", "cancel", owner_check=True)
async def cancel_run(
thread_id: str,
run_id: str,
@@ -181,6 +208,8 @@ async def cancel_run(
- action=rollback: Stop execution, revert to pre-run checkpoint state
- wait=true: Block until the run fully stops, return 204
- wait=false: Return immediately with 202
Multi-tenant isolation: only the thread owner can cancel runs.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
@@ -205,8 +234,13 @@ async def cancel_run(
@router.get("/{thread_id}/runs/{run_id}/join")
@require_auth
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
"""Join an existing run's SSE stream.
Multi-tenant isolation: only the thread owner can join runs.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
+94 -22
View File
@@ -13,17 +13,26 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations
import logging
import re
import time
import uuid
from typing import Any
from typing import Annotated, Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from fastapi import APIRouter, HTTPException, Path, Request
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_auth, require_permission
from app.gateway.deps import get_checkpointer, get_store
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
# ---------------------------------------------------------------------------
# Thread ID validation (prevents log-injection via control characters)
# ---------------------------------------------------------------------------
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
ThreadId = Annotated[str, Path(description="Thread UUID", pattern=_UUID_RE.pattern)]
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
@@ -65,6 +74,13 @@ class ThreadCreateRequest(BaseModel):
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
@field_validator("thread_id")
@classmethod
def _validate_uuid(cls, v: str | None) -> str | None:
if v is not None and not _UUID_RE.match(v):
raise ValueError("thread_id must be a valid UUID")
return v
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
@@ -215,17 +231,23 @@ def _derive_thread_status(checkpoint_tuple) -> str:
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
@require_auth
@require_permission("threads", "delete", owner_check=True)
async def delete_thread_data(thread_id: ThreadId, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread record from the Store.
Multi-tenant isolation: only the thread owner can delete their thread.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
# Clean local filesystem
response = _delete_thread_data(thread_id)
# Remove from Store (best-effort)
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
@@ -233,7 +255,6 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None:
try:
if hasattr(checkpointer, "adelete_thread"):
@@ -251,12 +272,23 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
The thread record is written to the Store (for fast listing) and an
empty checkpoint is written to the checkpointer (for state reads).
Idempotent: returns the existing record when ``thread_id`` already exists.
If authenticated, the user's ID is injected into the thread metadata
for multi-tenant isolation.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
thread_metadata = dict(body.metadata)
if user:
thread_metadata["user_id"] = str(user.id)
# Idempotency: return existing record from Store when already present
if store is not None:
existing_record = await _store_get(store, thread_id)
@@ -279,7 +311,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": body.metadata,
"metadata": thread_metadata,
},
)
except Exception:
@@ -296,7 +328,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
"source": "input",
"writes": None,
"parents": {},
**body.metadata,
**thread_metadata,
"created_at": now,
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
@@ -304,13 +336,13 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
logger.exception("Failed to create checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", thread_id)
logger.info("Thread created: %s (user_id=%s)", thread_id, thread_metadata.get("user_id"))
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
metadata=thread_metadata,
)
@@ -330,10 +362,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
newly found thread is immediately written to the Store so that the next
search skips Phase 2 for that thread — the Store converges to a full
index over time without a one-shot migration job.
If authenticated, only threads belonging to the current user are returned
(enforced by user_id metadata filter for multi-tenant isolation).
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
user_id = str(user.id) if user else None
# -----------------------------------------------------------------------
# Phase 1: Store
# -----------------------------------------------------------------------
@@ -409,6 +449,10 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
# -----------------------------------------------------------------------
results = list(merged.values())
# Multi-tenant isolation: filter by user_id if authenticated
if user_id:
results = [r for r in results if r.metadata.get("user_id") == user_id]
if body.metadata:
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
@@ -420,13 +464,20 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
@router.patch("/{thread_id}", response_model=ThreadResponse)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
@require_auth
@require_permission("threads", "write", owner_check=True, inject_record=True)
async def patch_thread(thread_id: ThreadId, request: Request, body: ThreadPatchRequest, thread_record: dict = None) -> ThreadResponse:
"""Merge metadata into a thread record.
Multi-tenant isolation: only the thread owner can patch their thread.
"""
store = get_store(request)
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
record = await _store_get(store, thread_id)
record = thread_record
if record is None:
record = await _store_get(store, thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
@@ -451,12 +502,17 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
@router.get("/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: ThreadId, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the Store and derives the accurate execution
status from the checkpointer. Falls back to the checkpointer alone
for threads that pre-date Store adoption (backward compat).
Multi-tenant isolation: returns 404 if the thread does not belong to
the authenticated user.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
@@ -488,26 +544,33 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle") # type: ignore[union-attr]
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
channel_values = checkpoint.get("channel_values", {})
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=str(record.get("created_at", "")), # type: ignore[union-attr]
updated_at=str(record.get("updated_at", "")), # type: ignore[union-attr]
metadata=record.get("metadata", {}), # type: ignore[union-attr]
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: ThreadId, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
Channel values are serialized to ensure LangChain message objects
are converted to JSON-safe dicts.
Multi-tenant isolation: returns 404 if thread does not belong to user.
"""
checkpointer = get_checkpointer(request)
@@ -552,12 +615,16 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
@require_auth
@require_permission("threads", "write", owner_check=True)
async def update_thread_state(thread_id: ThreadId, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field back to the Store
so that ``/threads/search`` reflects the change immediately.
Multi-tenant isolation: only the thread owner can update their thread.
"""
checkpointer = get_checkpointer(request)
store = get_store(request)
@@ -635,8 +702,13 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: ThreadId, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread.
Multi-tenant isolation: returns 404 if thread does not belong to user.
"""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}